diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index fc3cde9c..33ce2007 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -94,7 +94,7 @@ jobs: install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe rocm-version: '6.2' flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' - runner_dir: '' + runner_dir: 'rocm' runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release env: @@ -163,6 +163,7 @@ jobs: cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}" cmake --build --parallel --preset "${{ matrix.preset }}" cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8 + Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue env: CMAKE_GENERATOR: Ninja - uses: actions/upload-artifact@v4 @@ -175,19 +176,19 @@ jobs: matrix: os: [windows] arch: [amd64, arm64] + include: + - os: windows + arch: amd64 + llvmarch: x86_64 + - os: windows + arch: arm64 + llvmarch: aarch64 runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release needs: [setup-environment] env: GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }} steps: - - name: Install AMD64 system dependencies - if: matrix.arch == 'amd64' - run: | - $ErrorActionPreference = "Stop" - Start-Process "C:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait - echo "C:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - name: Install ARM64 system dependencies if: matrix.arch == 'arm64' run: | @@ -199,15 +200,29 @@ jobs: choco install -y --no-progress git gzip echo "C:\Program Files\Git\cmd" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - - Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-aarch64.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip" - Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip -DestinationPath "C:\Program Files\" - $installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt-aarch64").path - echo $installPath\bin | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: Install clang and gcc-compat + run: | + $ErrorActionPreference = "Stop" + Set-ExecutionPolicy Bypass -Scope Process -Force + Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-${{ matrix.llvmarch }}.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt.zip" + Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt.zip -DestinationPath "C:\Program Files\" + $installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt*").path + echo "$installPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version-file: go.mod + - name: Verify gcc is actually clang + run: | + $ErrorActionPreference='Continue' + $version=& gcc -v 2>&1 + $version=$version -join "`n" + echo "gcc is $version" + if ($version -notmatch 'clang') { + echo "ERROR: GCC must be clang for proper utf16 handling" + exit 1 + } + $ErrorActionPreference='Stop' - run: | go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ . - uses: actions/upload-artifact@v4 @@ -222,13 +237,13 @@ jobs: include: - os: linux arch: amd64 - target: archive + target: archive_novulkan - os: linux arch: amd64 target: rocm - os: linux arch: arm64 - target: archive + target: archive_novulkan runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release needs: setup-environment @@ -284,12 +299,14 @@ jobs: include: - os: linux arch: arm64 + target: novulkan build-args: | CGO_CFLAGS CGO_CXXFLAGS GOFLAGS - os: linux arch: amd64 + target: novulkan build-args: | CGO_CFLAGS CGO_CXXFLAGS @@ -302,6 +319,14 @@ jobs: CGO_CXXFLAGS GOFLAGS FLAVOR=rocm + - os: linux + arch: amd64 + suffix: '-vulkan' + target: default + build-args: | + CGO_CFLAGS + CGO_CXXFLAGS + GOFLAGS runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release needs: setup-environment @@ -319,6 +344,7 @@ jobs: with: context: . platforms: ${{ matrix.os }}/${{ matrix.arch }} + target: ${{ matrix.target }} build-args: ${{ matrix.build-args }} outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e470540a..12ee7135 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -52,6 +52,12 @@ jobs: container: rocm/dev-ubuntu-22.04:6.1.2 extra-packages: rocm-libs flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm' + - preset: Vulkan + container: ubuntu:22.04 + extra-packages: > + mesa-vulkan-drivers vulkan-tools + libvulkan1 libvulkan-dev + vulkan-sdk cmake ccache g++ make runs-on: linux container: ${{ matrix.container }} steps: @@ -59,7 +65,19 @@ jobs: - run: | [ -n "${{ matrix.container }}" ] || sudo=sudo $sudo apt-get update + # Add LunarG Vulkan SDK apt repo for Ubuntu 22.04 + if [ "${{ matrix.preset }}" = "Vulkan" ]; then + $sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common + wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg + # Use signed-by to bind the repo to the installed keyring to avoid NO_PUBKEY + echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313 jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list > /dev/null + $sudo apt-get update + fi $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} + # Export VULKAN_SDK if provided by LunarG package (defensive) + if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then + echo "VULKAN_SDK=/usr" >> $GITHUB_ENV + fi env: DEBIAN_FRONTEND: noninteractive - uses: actions/cache@v4 @@ -92,18 +110,21 @@ jobs: - preset: ROCm install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' + - preset: Vulkan + install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe runs-on: windows steps: - run: | choco install -y --no-progress ccache ninja ccache -o cache_dir=${{ github.workspace }}\.ccache - - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' + - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' id: cache-install uses: actions/cache/restore@v4 with: path: | C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm + C:\VulkanSDK key: ${{ matrix.install }} - if: matrix.preset == 'CUDA' name: Install CUDA ${{ matrix.cuda-version }} @@ -133,6 +154,18 @@ jobs: echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append + - if: matrix.preset == 'Vulkan' + name: Install Vulkan ${{ matrix.rocm-version }} + run: | + $ErrorActionPreference = "Stop" + if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { + Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" + Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait + } + + $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path + echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} uses: actions/cache/save@v4 with: diff --git a/CMakeLists.txt b/CMakeLists.txt index 59c1ceac..707bc603 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,6 +109,7 @@ if(CMAKE_HIP_COMPILER) endif() if(AMDGPU_TARGETS) + find_package(hip REQUIRED) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip) if (WIN32) @@ -117,7 +118,6 @@ if(CMAKE_HIP_COMPILER) target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM) - set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm) install(TARGETS ggml-hip RUNTIME_DEPENDENCY_SET rocm RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP @@ -128,15 +128,27 @@ if(CMAKE_HIP_COMPILER) PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf PRE_EXCLUDE_REGEXES ".*" POST_EXCLUDE_REGEXES "system32" - RUNTIME DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP - LIBRARY DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP + RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP + LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP ) foreach(HIP_LIB_BIN_INSTALL_DIR IN ITEMS ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}) if(EXISTS ${HIP_LIB_BIN_INSTALL_DIR}/rocblas) - install(DIRECTORY ${HIP_LIB_BIN_INSTALL_DIR}/rocblas DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP) + install(DIRECTORY ${HIP_LIB_BIN_INSTALL_DIR}/rocblas DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP) break() endif() endforeach() endif() endif() + +find_package(Vulkan) +if(Vulkan_FOUND) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) + install(TARGETS ggml-vulkan + RUNTIME_DEPENDENCIES + PRE_INCLUDE_REGEXES vulkan + PRE_EXCLUDE_REGEXES ".*" + RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan + LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan + ) +endif() diff --git a/CMakePresets.json b/CMakePresets.json index bbeab76f..72417ade 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -30,7 +30,7 @@ "name": "CUDA 12", "inherits": [ "CUDA" ], "cacheVariables": { - "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120", + "CMAKE_CUDA_ARCHITECTURES": "50;52;60;61;70;75;80;86;89;90;90a;120", "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2" } }, @@ -38,7 +38,7 @@ "name": "CUDA 13", "inherits": [ "CUDA" ], "cacheVariables": { - "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual", + "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual", "CMAKE_CUDA_FLAGS": "-t 2" } }, @@ -68,8 +68,12 @@ "inherits": [ "ROCm" ], "cacheVariables": { "CMAKE_HIP_FLAGS": "-parallel-jobs=4", - "AMDGPU_TARGETS": "gfx803;gfx902;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1011:xnack-;gfx1012:xnack-;" + "AMDGPU_TARGETS": "gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" } + }, + { + "name": "Vulkan", + "inherits": [ "Default" ] } ], "buildPresets": [ @@ -122,6 +126,11 @@ "name": "ROCm 6", "inherits": [ "ROCm" ], "configurePreset": "ROCm 6" + }, + { + "name": "Vulkan", + "targets": [ "ggml-vulkan" ], + "configurePreset": "Vulkan" } ] } diff --git a/Dockerfile b/Dockerfile index ffaa31a5..dbc9207e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,7 @@ ARG ROCMVERSION=6.3.3 ARG JETPACK5VERSION=r35.4.1 ARG JETPACK6VERSION=r36.4.0 ARG CMAKEVERSION=3.31.2 +ARG VULKANVERSION=1.4.321.1 # We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 @@ -17,6 +18,16 @@ RUN yum install -y yum-utils \ && dnf install -y ccache \ && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH +ARG VULKANVERSION +RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ + && tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ + && dnf -y install ninja-build \ + && ln -s /usr/bin/python3 /usr/bin/python \ + && /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \ + && /${VULKANVERSION}/vulkansdk -j 8 shaderc +RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \ + && cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib +ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH FROM --platform=linux/arm64 almalinux:8 AS base-arm64 # install epel-release for ccache @@ -77,9 +88,10 @@ FROM base AS rocm-6 ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'ROCm 6' \ + cmake --preset 'ROCm 6' -DOLLAMA_RUNNER_DIR="rocm" \ && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \ && cmake --install build --component HIP --strip --parallel ${PARALLEL} +RUN rm -f dist/lib/ollama/rocm/rocblas/library/*gfx90[06]* FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5 ARG CMAKEVERSION @@ -89,7 +101,7 @@ COPY CMakeLists.txt CMakePresets.json . COPY ml/backend/ggml/ggml ml/backend/ggml/ggml ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'JetPack 5' \ + cmake --preset 'JetPack 5' -DOLLAMA_RUNNER_DIR="cuda_jetpack5" \ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \ && cmake --install build --component CUDA --strip --parallel ${PARALLEL} @@ -101,10 +113,17 @@ COPY CMakeLists.txt CMakePresets.json . COPY ml/backend/ggml/ggml ml/backend/ggml/ggml ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'JetPack 6' \ + cmake --preset 'JetPack 6' -DOLLAMA_RUNNER_DIR="cuda_jetpack6" \ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \ && cmake --install build --component CUDA --strip --parallel ${PARALLEL} +FROM base AS vulkan +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'Vulkan' -DOLLAMA_RUNNER_DIR="vulkan" \ + && cmake --build --parallel --preset 'Vulkan' \ + && cmake --install build --component Vulkan --strip --parallel 8 + + FROM base AS build WORKDIR /go/src/github.com/ollama/ollama COPY go.mod go.sum . @@ -122,27 +141,54 @@ RUN --mount=type=cache,target=/root/.cache/go-build \ FROM --platform=linux/amd64 scratch AS amd64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ -COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ +COPY --from=cuda-13 dist/lib/ollama /lib/ollama/ +COPY --from=vulkan dist/lib/ollama /lib/ollama/ FROM --platform=linux/arm64 scratch AS arm64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ -COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5 -COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6 +COPY --from=jetpack-5 dist/lib/ollama/ /lib/ollama/ +COPY --from=jetpack-6 dist/lib/ollama/ /lib/ollama/ FROM scratch AS rocm COPY --from=rocm-6 dist/lib/ollama /lib/ollama FROM ${FLAVOR} AS archive +ARG VULKANVERSION COPY --from=cpu dist/lib/ollama /lib/ollama COPY --from=build /bin/ollama /bin/ollama -FROM ubuntu:24.04 +# Temporary opt-out stages for Vulkan +FROM --platform=linux/amd64 scratch AS amd64_novulkan +# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ +COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ +COPY --from=cuda-13 dist/lib/ollama /lib/ollama/ +FROM arm64 AS arm64_novulkan +FROM ${FLAVOR}_novulkan AS archive_novulkan +COPY --from=cpu dist/lib/ollama /lib/ollama +COPY --from=build /bin/ollama /bin/ollama +FROM ubuntu:24.04 AS novulkan RUN apt-get update \ && apt-get install -y ca-certificates \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* +COPY --from=archive_novulkan /bin /usr/bin +ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +COPY --from=archive_novulkan /lib/ollama /usr/lib/ollama +ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility +ENV NVIDIA_VISIBLE_DEVICES=all +ENV OLLAMA_HOST=0.0.0.0:11434 +EXPOSE 11434 +ENTRYPOINT ["/bin/ollama"] +CMD ["serve"] + +FROM ubuntu:24.04 AS default +RUN apt-get update \ + && apt-get install -y ca-certificates libvulkan1 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* COPY --from=archive /bin /usr/bin ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin COPY --from=archive /lib/ollama /usr/lib/ollama diff --git a/Makefile.sync b/Makefile.sync index 0bfb70e0..2e99c7fb 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -1,6 +1,6 @@ UPSTREAM=https://github.com/ggml-org/llama.cpp.git WORKDIR=llama/vendor -FETCH_HEAD=e54d41befcc1575f4c898c5ff4ef43970cead75f +FETCH_HEAD=7049736b2dd9011bf819e298b844ebbc4b5afdc9 .PHONY: help help: diff --git a/README.md b/README.md index 0c79970a..e773236b 100644 --- a/README.md +++ b/README.md @@ -566,6 +566,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/)) - [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/)) - [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama) +- [achatbot-go](https://github.com/ai-bot-pro/achatbot-go) a multimodal(text/audio/image) chatbot. ### Mobile diff --git a/api/types.go b/api/types.go index 8cc7752c..1483c844 100644 --- a/api/types.go +++ b/api/types.go @@ -106,6 +106,14 @@ type GenerateRequest struct { // before this option was introduced) Think *ThinkValue `json:"think,omitempty"` + // Truncate is a boolean that, when set to true, truncates the chat history messages + // if the rendered prompt exceeds the context length limit. + Truncate *bool `json:"truncate,omitempty"` + + // Shift is a boolean that, when set to true, shifts the chat history + // when hitting the context length limit instead of erroring. + Shift *bool `json:"shift,omitempty"` + // DebugRenderOnly is a debug option that, when set to true, returns the rendered // template instead of calling the model. DebugRenderOnly bool `json:"_debug_render_only,omitempty"` @@ -140,6 +148,14 @@ type ChatRequest struct { // for supported models. Think *ThinkValue `json:"think,omitempty"` + // Truncate is a boolean that, when set to true, truncates the chat history messages + // if the rendered prompt exceeds the context length limit. + Truncate *bool `json:"truncate,omitempty"` + + // Shift is a boolean that, when set to true, shifts the chat history + // when hitting the context length limit instead of erroring. + Shift *bool `json:"shift,omitempty"` + // DebugRenderOnly is a debug option that, when set to true, returns the rendered // template instead of calling the model. DebugRenderOnly bool `json:"_debug_render_only,omitempty"` @@ -188,7 +204,7 @@ type ToolCall struct { } type ToolCallFunction struct { - Index int `json:"index,omitempty"` + Index int `json:"index"` Name string `json:"name"` Arguments ToolCallFunctionArguments `json:"arguments"` } @@ -250,9 +266,9 @@ func (pt PropertyType) String() string { type ToolProperty struct { AnyOf []ToolProperty `json:"anyOf,omitempty"` - Type PropertyType `json:"type"` + Type PropertyType `json:"type,omitempty"` Items any `json:"items,omitempty"` - Description string `json:"description"` + Description string `json:"description,omitempty"` Enum []any `json:"enum,omitempty"` } @@ -316,7 +332,7 @@ func (t *ToolFunctionParameters) String() string { type ToolFunction struct { Name string `json:"name"` - Description string `json:"description"` + Description string `json:"description,omitempty"` Parameters ToolFunctionParameters `json:"parameters"` } @@ -936,7 +952,7 @@ func (t *ThinkValue) UnmarshalJSON(data []byte) error { return nil } - return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\")") + return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)") } // MarshalJSON implements json.Marshaler diff --git a/api/types_test.go b/api/types_test.go index 5393b462..5053c162 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -298,6 +298,30 @@ func TestToolFunction_UnmarshalJSON(t *testing.T) { } } +func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) { + fn := ToolCallFunction{ + Name: "echo", + Arguments: ToolCallFunctionArguments{"message": "hi"}, + } + + data, err := json.Marshal(fn) + require.NoError(t, err) + + raw := map[string]any{} + require.NoError(t, json.Unmarshal(data, &raw)) + require.Contains(t, raw, "index") + assert.Equal(t, float64(0), raw["index"]) + + fn.Index = 3 + data, err = json.Marshal(fn) + require.NoError(t, err) + + raw = map[string]any{} + require.NoError(t, json.Unmarshal(data, &raw)) + require.Contains(t, raw, "index") + assert.Equal(t, float64(3), raw["index"]) +} + func TestPropertyType_UnmarshalJSON(t *testing.T) { tests := []struct { name string diff --git a/convert/convert_gptoss.go b/convert/convert_gptoss.go index 2048b18b..5338df21 100644 --- a/convert/convert_gptoss.go +++ b/convert/convert_gptoss.go @@ -85,6 +85,19 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor { case "scales": mxfp4s[name].scales = t } + } else if strings.HasSuffix(t.Name(), "gate_up_exps.bias") { + // gate_up_exps is interleaved, need to split into gate_exps and up_exps + // e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...] + out = append(out, slices.Collect(splitDim(t, 1, + split{ + Replacer: strings.NewReplacer("gate_up_exps", "gate_exps"), + slices: []tensor.Slice{nil, tensor.S(0, int(t.Shape()[1]), 2)}, + }, + split{ + Replacer: strings.NewReplacer("gate_up_exps", "up_exps"), + slices: []tensor.Slice{nil, tensor.S(1, int(t.Shape()[1]), 2)}, + }, + ))...) } else { out = append(out, &ggml.Tensor{ Name: t.Name(), @@ -97,17 +110,28 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor { for name, mxfp4 := range mxfp4s { dims := mxfp4.blocks.Shape() - - if !strings.HasSuffix(name, ".weight") { - name += ".weight" + if strings.Contains(name, "ffn_down_exps") { + out = append(out, &ggml.Tensor{ + Name: name + ".weight", + Kind: uint32(ggml.TensorTypeMXFP4), + Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2}, + WriterTo: mxfp4, + }) + } else if strings.Contains(name, "ffn_gate_up_exps") { + // gate_up_exps is interleaved, need to split into gate_exps and up_exps + // e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...] + out = append(out, &ggml.Tensor{ + Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight", + Kind: uint32(ggml.TensorTypeMXFP4), + Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2}, + WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2), + }, &ggml.Tensor{ + Name: strings.Replace(name, "gate_up", "up", 1) + ".weight", + Kind: uint32(ggml.TensorTypeMXFP4), + Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2}, + WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2), + }) } - - out = append(out, &ggml.Tensor{ - Name: name, - Kind: uint32(ggml.TensorTypeMXFP4), - Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2}, - WriterTo: mxfp4, - }) } return out @@ -158,9 +182,21 @@ func (m *gptossModel) Replacements() []string { } type mxfp4 struct { + slices []tensor.Slice + blocks, scales Tensor } +func (m *mxfp4) slice(dim, start, end, step int) *mxfp4 { + slice := slices.Repeat([]tensor.Slice{nil}, len(m.blocks.Shape())) + slice[dim] = tensor.S(start, end, step) + return &mxfp4{ + slices: slice, + blocks: m.blocks, + scales: m.scales, + } +} + func (m *mxfp4) WriteTo(w io.Writer) (int64, error) { var b bytes.Buffer if _, err := m.blocks.WriteTo(&b); err != nil { @@ -204,6 +240,13 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) { return 0, err } + if len(m.slices) > 0 { + out, err = out.Slice(m.slices...) + if err != nil { + return 0, err + } + } + out = tensor.Materialize(out) if err := out.Reshape(out.Shape().TotalSize()); err != nil { diff --git a/convert/tensor.go b/convert/tensor.go index c9565ed4..9b8517f1 100644 --- a/convert/tensor.go +++ b/convert/tensor.go @@ -16,7 +16,8 @@ import ( type split struct { *strings.Replacer - dim int + dim int + slices []tensor.Slice // fn is an optional function to apply to the tensor after slicing fn func(tensor.Tensor) (tensor.Tensor, error) @@ -32,9 +33,12 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] { shape := slices.Clone(t.Shape()) shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits))) - slice := slices.Repeat([]tensor.Slice{nil}, len(shape)) - slice[dim] = tensor.S(offset, offset+int(shape[dim])) - offset += int(shape[dim]) + slice := split.slices + if len(slice) == 0 { + slice = slices.Repeat([]tensor.Slice{nil}, len(shape)) + slice[dim] = tensor.S(offset, offset+int(shape[dim])) + offset += int(shape[dim]) + } t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) { dims := make([]int, len(shape)) diff --git a/discover/amd_common.go b/discover/amd_common.go deleted file mode 100644 index 08834b22..00000000 --- a/discover/amd_common.go +++ /dev/null @@ -1,83 +0,0 @@ -//go:build linux || windows - -package discover - -import ( - "errors" - "log/slog" - "os" - "path/filepath" - "runtime" - "strings" -) - -// Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns -func rocmLibUsable(libDir string) bool { - slog.Debug("evaluating potential rocm lib dir " + libDir) - for _, g := range ROCmLibGlobs { - res, _ := filepath.Glob(filepath.Join(libDir, g)) - if len(res) == 0 { - return false - } - } - return true -} - -func GetSupportedGFX(libDir string) ([]string, error) { - var ret []string - files, err := filepath.Glob(filepath.Join(libDir, "rocblas", "library", "TensileLibrary_lazy_gfx*.dat")) - if err != nil { - return nil, err - } - for _, file := range files { - ret = append(ret, strings.TrimSuffix(strings.TrimPrefix(filepath.Base(file), "TensileLibrary_lazy_"), ".dat")) - } - return ret, nil -} - -func commonAMDValidateLibDir() (string, error) { - // Favor our bundled version - - // Installer payload location if we're running the installed binary - rocmTargetDir := filepath.Join(LibOllamaPath, "rocm") - if rocmLibUsable(rocmTargetDir) { - slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir) - return rocmTargetDir, nil - } - - // Prefer explicit HIP env var - hipPath := os.Getenv("HIP_PATH") - if hipPath != "" { - hipLibDir := filepath.Join(hipPath, "bin") - if rocmLibUsable(hipLibDir) { - slog.Debug("detected ROCM via HIP_PATH=" + hipPath) - return hipLibDir, nil - } - } - - // Scan the LD_LIBRARY_PATH or PATH - pathEnv := "LD_LIBRARY_PATH" - if runtime.GOOS == "windows" { - pathEnv = "PATH" - } - - paths := os.Getenv(pathEnv) - for _, path := range filepath.SplitList(paths) { - d, err := filepath.Abs(path) - if err != nil { - continue - } - if rocmLibUsable(d) { - return d, nil - } - } - - // Well known location(s) - for _, path := range RocmStandardLocations { - if rocmLibUsable(path) { - return path, nil - } - } - - return "", errors.New("no suitable rocm found, falling back to CPU") -} diff --git a/discover/amd_hip_windows.go b/discover/amd_hip_windows.go deleted file mode 100644 index bf19ef06..00000000 --- a/discover/amd_hip_windows.go +++ /dev/null @@ -1,147 +0,0 @@ -package discover - -import ( - "errors" - "fmt" - "log/slog" - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - hipSuccess = 0 - hipErrorNoDevice = 100 -) - -type hipDevicePropMinimal struct { - Name [256]byte - unused1 [140]byte - GcnArchName [256]byte // gfx#### - iGPU int // Doesn't seem to actually report correctly - unused2 [128]byte -} - -// Wrap the amdhip64.dll library for GPU discovery -type HipLib struct { - dll windows.Handle - hipGetDeviceCount uintptr - hipGetDeviceProperties uintptr - hipMemGetInfo uintptr - hipSetDevice uintptr - hipDriverGetVersion uintptr -} - -func NewHipLib() (*HipLib, error) { - // At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs - h, err := windows.LoadLibrary("amdhip64_6.dll") - if err != nil { - return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err) - } - hl := &HipLib{} - hl.dll = h - hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount") - if err != nil { - return nil, err - } - hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties") - if err != nil { - return nil, err - } - hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo") - if err != nil { - return nil, err - } - hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice") - if err != nil { - return nil, err - } - hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion") - if err != nil { - return nil, err - } - return hl, nil -} - -// The hip library only evaluates the ROCR_VISIBLE_DEVICES variable at startup -// so we have to unload/reset the library after we do our initial discovery -// to make sure our updates to that variable are processed by llama.cpp -func (hl *HipLib) Release() { - err := windows.FreeLibrary(hl.dll) - if err != nil { - slog.Warn("failed to unload amdhip64.dll", "error", err) - } - hl.dll = 0 -} - -func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) { - if hl.dll == 0 { - return 0, 0, errors.New("dll has been unloaded") - } - var version int - status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version))) - if status != hipSuccess { - return 0, 0, fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err) - } - - slog.Debug("hipDriverGetVersion", "version", version) - driverMajor = version / 10000000 - driverMinor = (version - (driverMajor * 10000000)) / 100000 - - return driverMajor, driverMinor, nil -} - -func (hl *HipLib) HipGetDeviceCount() int { - if hl.dll == 0 { - slog.Error("dll has been unloaded") - return 0 - } - var count int - status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count))) - if status == hipErrorNoDevice { - slog.Info("AMD ROCm reports no devices found") - return 0 - } - if status != hipSuccess { - slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err) - } - return count -} - -func (hl *HipLib) HipSetDevice(device int) error { - if hl.dll == 0 { - return errors.New("dll has been unloaded") - } - status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device)) - if status != hipSuccess { - return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err) - } - return nil -} - -func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) { - if hl.dll == 0 { - return nil, errors.New("dll has been unloaded") - } - var props hipDevicePropMinimal - status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device)) - if status != hipSuccess { - return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err) - } - return &props, nil -} - -// free, total, err -func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) { - if hl.dll == 0 { - return 0, 0, errors.New("dll has been unloaded") - } - var totalMemory uint64 - var freeMemory uint64 - status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory))) - if status != hipSuccess { - return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err) - } - return freeMemory, totalMemory, nil -} diff --git a/discover/amd_linux.go b/discover/amd_linux.go deleted file mode 100644 index 31ad4bb6..00000000 --- a/discover/amd_linux.go +++ /dev/null @@ -1,549 +0,0 @@ -package discover - -import ( - "bufio" - "errors" - "fmt" - "io" - "io/fs" - "log/slog" - "os" - "path/filepath" - "regexp" - "slices" - "sort" - "strconv" - "strings" - - "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/format" -) - -// Discovery logic for AMD/ROCm GPUs - -const ( - DriverVersionFile = "/sys/module/amdgpu/version" - AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/" - GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties" - - // Prefix with the node dir - GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line - - // Direct Rendering Manager sysfs location - DRMDeviceDirGlob = "/sys/class/drm/card*/device" - DRMTotalMemoryFile = "mem_info_vram_total" - DRMUsedMemoryFile = "mem_info_vram_used" - - // In hex; properties file is in decimal - DRMUniqueIDFile = "unique_id" - DRMVendorFile = "vendor" - DRMDeviceFile = "device" -) - -var ( - // Used to validate if the given ROCm lib is usable - ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here... - RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"} -) - -// Gather GPU information from the amdgpu driver if any supported GPUs are detected -// Only called once during bootstrap -func AMDGetGPUInfo() ([]RocmGPUInfo, error) { - resp := []RocmGPUInfo{} - if !AMDDetected() { - return resp, fmt.Errorf("AMD GPUs not detected") - } - - // Opportunistic logging of driver version to aid in troubleshooting - driverMajor, driverMinor, err := AMDDriverVersion() - if err != nil { - // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU - slog.Warn("ollama recommends running the https://www.amd.com/en/support/download/linux-drivers.html", "error", err) - } - - // Determine if the user has already pre-selected which GPUs to look at, then ignore the others - var visibleDevices []string - hipVD := envconfig.HipVisibleDevices() // zero based index only - rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID - gpuDO := envconfig.GpuDeviceOrdinal() // zero based index - switch { - case rocrVD != "": - visibleDevices = strings.Split(rocrVD, ",") - case hipVD != "": - visibleDevices = strings.Split(hipVD, ",") - case gpuDO != "": - visibleDevices = strings.Split(gpuDO, ",") - } - - gfxOverride := envconfig.HsaOverrideGfxVersion() - var supported []string - var libDir string - - // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract - // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU) - matches, _ := filepath.Glob(GPUPropertiesFileGlob) - sort.Slice(matches, func(i, j int) bool { - // /sys/class/kfd/kfd/topology/nodes//properties - a, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[i])), 10, 64) - if err != nil { - slog.Debug("parse err", "error", err, "match", matches[i]) - return false - } - b, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[j])), 10, 64) - if err != nil { - slog.Debug("parse err", "error", err, "match", matches[i]) - return false - } - return a < b - }) - gpuCount := 0 - gpuOrdinalID := 0 - for _, match := range matches { - slog.Debug("evaluating amdgpu node " + match) - fp, err := os.Open(match) - if err != nil { - slog.Debug("failed to open sysfs node", "file", match, "error", err) - continue - } - defer fp.Close() - - scanner := bufio.NewScanner(fp) - isCPU := false - var major, minor, patch uint64 - var vendor, device, uniqueID uint64 - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - // Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs - if strings.HasPrefix(line, "gfx_target_version") { - ver := strings.Fields(line) - - // Detect CPUs - if len(ver) == 2 && ver[1] == "0" { - slog.Debug("detected CPU " + match) - isCPU = true - break - } - - if len(ver) != 2 || len(ver[1]) < 5 { - slog.Warn("malformed "+match, "gfx_target_version", line) - // If this winds up being a CPU, our offsets may be wrong - continue - } - l := len(ver[1]) - var err1, err2, err3 error - patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32) - minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32) - major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32) - if err1 != nil || err2 != nil || err3 != nil { - slog.Debug("malformed int " + line) - continue - } - } else if strings.HasPrefix(line, "vendor_id") { - ver := strings.Fields(line) - if len(ver) != 2 { - slog.Debug("malformed", "vendor_id", line) - continue - } - vendor, err = strconv.ParseUint(ver[1], 10, 64) - if err != nil { - slog.Debug("malformed", "vendor_id", line, "error", err) - } - } else if strings.HasPrefix(line, "device_id") { - ver := strings.Fields(line) - if len(ver) != 2 { - slog.Debug("malformed", "device_id", line) - continue - } - device, err = strconv.ParseUint(ver[1], 10, 64) - if err != nil { - slog.Debug("malformed", "device_id", line, "error", err) - } - } else if strings.HasPrefix(line, "unique_id") { - ver := strings.Fields(line) - if len(ver) != 2 { - slog.Debug("malformed", "unique_id", line) - continue - } - uniqueID, err = strconv.ParseUint(ver[1], 10, 64) - if err != nil { - slog.Debug("malformed", "unique_id", line, "error", err) - } - } - // TODO - any other properties we want to extract and record? - // vendor_id + device_id -> pci lookup for "Name" - // Other metrics that may help us understand relative performance between multiple GPUs - } - - // Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers - // into consideration, so we instead map the device over to the DRM driver sysfs nodes which - // do reliably report VRAM usage. - - if isCPU { - continue - } - - // Skip over any GPUs that are masked - if major == 0 && minor == 0 && patch == 0 { - slog.Debug("skipping gpu with gfx000") - continue - } - - // Look up the memory for the current node - totalMemory := uint64(0) - usedMemory := uint64(0) - var usedFile string - mapping := []struct { - id uint64 - filename string - }{ - {vendor, DRMVendorFile}, - {device, DRMDeviceFile}, - {uniqueID, DRMUniqueIDFile}, // Not all devices will report this - } - slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID) - // Map over to DRM location to find the total/free memory - drmMatches, _ := filepath.Glob(DRMDeviceDirGlob) - for _, devDir := range drmMatches { - matched := true - for _, m := range mapping { - if m.id == 0 { - // Null ID means it didn't populate, so we can't use it to match - continue - } - filename := filepath.Join(devDir, m.filename) - buf, err := os.ReadFile(filename) - if err != nil { - slog.Debug("failed to read sysfs node", "file", filename, "error", err) - matched = false - break - } - // values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu - cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64) - if err != nil { - slog.Debug("failed to parse sysfs node", "file", filename, "error", err) - matched = false - break - } - if cmp != m.id { - matched = false - break - } - } - if !matched { - continue - } - - // Found the matching DRM directory - slog.Debug("matched", "amdgpu", match, "drm", devDir) - totalFile := filepath.Join(devDir, DRMTotalMemoryFile) - buf, err := os.ReadFile(totalFile) - if err != nil { - slog.Debug("failed to read sysfs node", "file", totalFile, "error", err) - break - } - totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64) - if err != nil { - slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err) - break - } - - usedFile = filepath.Join(devDir, DRMUsedMemoryFile) - usedMemory, err = getFreeMemory(usedFile) - if err != nil { - slog.Debug("failed to update used memory", "error", err) - } - break - } - - var name string - // TODO - PCI ID lookup - if vendor > 0 && device > 0 { - name = fmt.Sprintf("%04x:%04x", vendor, device) - } - - // Favor UUIDs if available to reduce possibility of getting the numeric IDs wrong - var ID string - if uniqueID != 0 { - ID = fmt.Sprintf("GPU-%016x", uniqueID) - } else { - ID = strconv.Itoa(gpuOrdinalID) - } - - gpuInfo := RocmGPUInfo{ - GpuInfo: GpuInfo{ - Library: "rocm", - memInfo: memInfo{ - TotalMemory: totalMemory, - FreeMemory: (totalMemory - usedMemory), - }, - ID: ID, - filterID: gpuOrdinalID, - Name: name, - Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), - MinimumMemory: rocmMinimumMemory, - DriverMajor: driverMajor, - DriverMinor: driverMinor, - }, - usedFilepath: usedFile, - index: gpuCount, - } - - // Keep track of numeric IDs based on valid GPUs - gpuCount += 1 - - // If the user wants to filter to a subset of devices, filter out if we aren't a match - if len(visibleDevices) > 0 { - include := false - for _, visible := range visibleDevices { - if (uniqueID != 0 && visible == gpuInfo.ID) || visible == strconv.Itoa(gpuInfo.index) { - include = true - break - } - } - if !include { - reason := "filtering out device per user request" - slog.Info(reason, "id", gpuInfo.ID, "index", gpuInfo.index, "visible_devices", visibleDevices) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - - continue - } - } - - // Ordinal IDs are based on the visible GPUs - gpuOrdinalID += 1 - - // iGPU detection, remove this check once we can support an iGPU variant of the rocm library - if totalMemory < IGPUMemLimit { - reason := "unsupported Radeon iGPU detected skipping" - slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - continue - } - //minVer, err := strconv.Atoi(RocmComputeMajorMin) - //if err != nil { - // slog.Error("invalid RocmComputeMajorMin setting", "value", RocmComputeMajorMin, "error", err) - //} - // if int(major) < minVer { - // reason := fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch) - // slog.Warn(reason, "gpu", gpuID) - // unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - // GpuInfo: gpuInfo.GpuInfo, - // Reason: reason, - // }) - - // continue - //} - - slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) - slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "available", format.HumanBytes2(totalMemory-usedMemory)) - - // Final validation is gfx compatibility - load the library if we haven't already loaded it - // even if the user overrides, we still need to validate the library - if libDir == "" { - libDir, err = AMDValidateLibDir() - if err != nil { - err = fmt.Errorf("unable to verify rocm library: %w", err) - slog.Warn(err.Error()) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: err.Error(), - }) - return nil, err - } - } - gpuInfo.DependencyPath = []string{libDir} - - if gfxOverride == "" { - // Only load supported list once - if len(supported) == 0 { - supported, err = GetSupportedGFX(libDir) - if err != nil { - err = fmt.Errorf("failed to lookup supported GFX types: %w", err) - slog.Warn(err.Error()) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: err.Error(), - }) - return nil, err - } - slog.Debug("rocm supported GPUs", "types", supported) - } - gfx := gpuInfo.Compute - if !slices.Contains[[]string, string](supported, gfx) { - reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported) - slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - - // TODO - consider discrete markdown just for ROCM troubleshooting? - slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage") - continue - } else { - slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx) - } - } else { - slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride) - } - - // Check for env var workarounds - if name == "1002:687f" { // Vega RX 56 - gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, "HSA_ENABLE_SDMA=0") - } - - // The GPU has passed all the verification steps and is supported - resp = append(resp, gpuInfo) - } - if len(resp) == 0 { - err := fmt.Errorf("no compatible amdgpu devices detected") - slog.Info(err.Error()) - return nil, err - } - if err := verifyKFDDriverAccess(); err != nil { - err = fmt.Errorf("amdgpu devices detected but permission problems block access: %w", err) - slog.Error(err.Error()) - return nil, err - } - return resp, nil -} - -// Quick check for AMD driver so we can skip amdgpu discovery if not present -func AMDDetected() bool { - // Some driver versions (older?) don't have a version file, so just lookup the parent dir - sysfsDir := filepath.Dir(DriverVersionFile) - _, err := os.Stat(sysfsDir) - if errors.Is(err, os.ErrNotExist) { - slog.Debug("amdgpu driver not detected " + sysfsDir) - return false - } else if err != nil { - slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err) - return false - } - return true -} - -// Prefer to use host installed ROCm, as long as it meets our minimum requirements -// failing that, tell the user how to download it on their own -func AMDValidateLibDir() (string, error) { - libDir, err := commonAMDValidateLibDir() - if err == nil { - return libDir, nil - } - - // Well known ollama installer path - installedRocmDir := "/usr/share/ollama/lib/rocm" - if rocmLibUsable(installedRocmDir) { - return installedRocmDir, nil - } - - // If we still haven't found a usable rocm, the user will have to install it on their own - slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install") - return "", errors.New("no suitable rocm found, falling back to CPU") -} - -func AMDDriverVersion() (driverMajor, driverMinor int, err error) { - _, err = os.Stat(DriverVersionFile) - if err != nil { - return 0, 0, fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err) - } - fp, err := os.Open(DriverVersionFile) - if err != nil { - return 0, 0, err - } - defer fp.Close() - verString, err := io.ReadAll(fp) - if err != nil { - return 0, 0, err - } - - pattern := `\A(\d+)\.(\d+).*` - regex := regexp.MustCompile(pattern) - match := regex.FindStringSubmatch(string(verString)) - if len(match) < 2 { - return 0, 0, fmt.Errorf("malformed version string %s", string(verString)) - } - driverMajor, err = strconv.Atoi(match[1]) - if err != nil { - return 0, 0, err - } - driverMinor, err = strconv.Atoi(match[2]) - if err != nil { - return 0, 0, err - } - return driverMajor, driverMinor, nil -} - -func (gpus RocmGPUInfoList) RefreshFreeMemory() error { - if len(gpus) == 0 { - return nil - } - for i := range gpus { - usedMemory, err := getFreeMemory(gpus[i].usedFilepath) - if err != nil { - return err - } - slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory)) - gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory - } - return nil -} - -func getFreeMemory(usedFile string) (uint64, error) { - buf, err := os.ReadFile(usedFile) - if err != nil { - return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err) - } - usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64) - if err != nil { - slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err) - return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err) - } - return usedMemory, nil -} - -func verifyKFDDriverAccess() error { - // Verify we have permissions - either running as root, or we have group access to the driver - fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0o666) - if err != nil { - if errors.Is(err, fs.ErrPermission) { - return fmt.Errorf("permissions not set up properly. Either run ollama as root, or add you user account to the render group. %w", err) - } else if errors.Is(err, fs.ErrNotExist) { - // Container runtime failure? - return fmt.Errorf("kfd driver not loaded. If running in a container, remember to include '--device /dev/kfd --device /dev/dri'") - } - return fmt.Errorf("failed to check permission on /dev/kfd: %w", err) - } - fd.Close() - return nil -} - -func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { - ids := []string{} - for _, info := range gpuInfo { - if info.Library != "rocm" { - continue - } - // If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number - if _, err := strconv.Atoi(info.ID); err == nil { - ids = append(ids, fmt.Sprintf("%d", info.filterID)) - } else { - ids = append(ids, info.ID) - } - } - if len(ids) == 0 { - return "" - } - - // There are 3 potential env vars to use to select GPUs. - // ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux - // GPU_DEVICE_ORDINAL supports numeric IDs only - // HIP_VISIBLE_DEVICES supports numeric IDs only - return "ROCR_VISIBLE_DEVICES=" + strings.Join(ids, ",") -} diff --git a/discover/amd_windows.go b/discover/amd_windows.go deleted file mode 100644 index ae28696a..00000000 --- a/discover/amd_windows.go +++ /dev/null @@ -1,226 +0,0 @@ -package discover - -import ( - "bytes" - "errors" - "fmt" - "log/slog" - "path/filepath" - "slices" - "strconv" - "strings" - - "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/format" -) - -const ( - - // TODO We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true - iGPUName = "AMD 2099 Graphics" -) - -var ( - // Used to validate if the given ROCm lib is usable - ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6 - RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob? -) - -// Only called once during bootstrap -func AMDGetGPUInfo() ([]RocmGPUInfo, error) { - resp := []RocmGPUInfo{} - hl, err := NewHipLib() - if err != nil { - slog.Debug(err.Error()) - return nil, err - } - defer hl.Release() - - driverMajor, driverMinor, err := hl.AMDDriverVersion() - if err != nil { - // For now this is benign, but we may eventually need to fail compatibility checks - slog.Debug("error looking up amd driver version", "error", err) - } - - // Note: the HIP library automatically handles subsetting to any *_VISIBLE_DEVICES the user specified - count := hl.HipGetDeviceCount() - if count == 0 { - err := fmt.Errorf("no compatible amdgpu devices detected") - slog.Info(err.Error()) - return nil, err - } - - libDir, err := AMDValidateLibDir() - if err != nil { - err = fmt.Errorf("unable to verify rocm library: %w", err) - slog.Warn(err.Error()) - return nil, err - } - - var supported []string - gfxOverride := envconfig.HsaOverrideGfxVersion() - if gfxOverride == "" { - supported, err = GetSupportedGFX(libDir) - if err != nil { - err = fmt.Errorf("failed to lookup supported GFX types: %w", err) - slog.Warn(err.Error()) - return nil, err - } - } else { - slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride) - } - - slog.Debug("detected hip devices", "count", count) - // TODO how to determine the underlying device ID when visible devices is causing this to subset? - for i := range count { - err = hl.HipSetDevice(i) - if err != nil { - slog.Warn("set device", "id", i, "error", err) - continue - } - - props, err := hl.HipGetDeviceProperties(i) - if err != nil { - slog.Warn("get properties", "id", i, "error", err) - continue - } - n := bytes.IndexByte(props.Name[:], 0) - name := string(props.Name[:n]) - // TODO is UUID actually populated on windows? - // Can luid be used on windows for setting visible devices (and is it actually set?) - n = bytes.IndexByte(props.GcnArchName[:], 0) - gfx := string(props.GcnArchName[:n]) - slog.Debug("hip device", "id", i, "name", name, "gfx", gfx) - // slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0 - // TODO Why isn't props.iGPU accurate!? - - freeMemory, totalMemory, err := hl.HipMemGetInfo() - if err != nil { - slog.Warn("get mem info", "id", i, "error", err) - continue - } - - gpuInfo := RocmGPUInfo{ - GpuInfo: GpuInfo{ - Library: "rocm", - memInfo: memInfo{ - TotalMemory: totalMemory, - FreeMemory: freeMemory, - }, - // Free memory reporting on Windows is not reliable until we bump to ROCm v6.2 - UnreliableFreeMemory: true, - - ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices - filterID: i, - DependencyPath: []string{libDir}, - MinimumMemory: rocmMinimumMemory, - Name: name, - Compute: gfx, - DriverMajor: driverMajor, - DriverMinor: driverMinor, - }, - index: i, - } - - // iGPU detection, remove this check once we can support an iGPU variant of the rocm library - if strings.EqualFold(name, iGPUName) || totalMemory < IGPUMemLimit { - reason := "unsupported Radeon iGPU detected skipping" - slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - continue - } - - // Strip off Target Features when comparing - if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) { - reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported) - slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - // HSA_OVERRIDE_GFX_VERSION not supported on windows - continue - } else { - slog.Debug("amdgpu is supported", "gpu", i, "gpu_type", gfx) - } - - slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory)) - slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory)) - - resp = append(resp, gpuInfo) - } - - return resp, nil -} - -func AMDValidateLibDir() (string, error) { - libDir, err := commonAMDValidateLibDir() - if err == nil { - return libDir, nil - } - - // Installer payload (if we're running from some other location) - rocmTargetDir := filepath.Join(LibOllamaPath, "rocm") - if rocmLibUsable(rocmTargetDir) { - slog.Debug("detected ollama installed ROCm at " + rocmTargetDir) - return rocmTargetDir, nil - } - - // Should not happen on windows since we include it in the installer, but stand-alone binary might hit this - slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm") - return "", errors.New("no suitable rocm found, falling back to CPU") -} - -func (gpus RocmGPUInfoList) RefreshFreeMemory() error { - if len(gpus) == 0 { - return nil - } - hl, err := NewHipLib() - if err != nil { - slog.Debug(err.Error()) - return err - } - defer hl.Release() - - for i := range gpus { - err := hl.HipSetDevice(gpus[i].index) - if err != nil { - return err - } - freeMemory, _, err := hl.HipMemGetInfo() - if err != nil { - slog.Warn("get mem info", "id", i, "error", err) - continue - } - slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(freeMemory)) - gpus[i].FreeMemory = freeMemory - } - return nil -} - -func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { - ids := []string{} - for _, info := range gpuInfo { - if info.Library != "rocm" { - continue - } - // If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number - if _, err := strconv.Atoi(info.ID); err == nil { - ids = append(ids, fmt.Sprintf("%d", info.filterID)) - } else { - ids = append(ids, info.ID) - } - } - if len(ids) == 0 { - return "" - } - - // There are 3 potential env vars to use to select GPUs. - // ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows - // HIP_VISIBLE_DEVICES supports numeric IDs only - // GPU_DEVICE_ORDINAL supports numeric IDs only - return "HIP_VISIBLE_DEVICES=" + strings.Join(ids, ",") -} diff --git a/discover/cpu_common.go b/discover/cpu_common.go deleted file mode 100644 index 2b9f7292..00000000 --- a/discover/cpu_common.go +++ /dev/null @@ -1,24 +0,0 @@ -package discover - -import ( - "os" - "path/filepath" - "runtime" - "strings" -) - -func IsNUMA() bool { - if runtime.GOOS != "linux" { - // numa support in llama.cpp is linux only - return false - } - ids := map[string]any{} - packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id") - for _, packageId := range packageIds { - id, err := os.ReadFile(packageId) - if err == nil { - ids[strings.TrimSpace(string(id))] = struct{}{} - } - } - return len(ids) > 1 -} diff --git a/discover/gpu_linux.go b/discover/cpu_linux.go similarity index 75% rename from discover/gpu_linux.go rename to discover/cpu_linux.go index 44c53b44..c3a0ef7f 100644 --- a/discover/gpu_linux.go +++ b/discover/cpu_linux.go @@ -4,7 +4,9 @@ import ( "bufio" "fmt" "io" + "log/slog" "os" + "path/filepath" "reflect" "regexp" "sort" @@ -13,47 +15,6 @@ import ( "github.com/ollama/ollama/format" ) -var CudartGlobs = []string{ - "/usr/local/cuda/lib64/libcudart.so*", - "/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*", - "/usr/lib/x86_64-linux-gnu/libcudart.so*", - "/usr/lib/wsl/lib/libcudart.so*", - "/usr/lib/wsl/drivers/*/libcudart.so*", - "/opt/cuda/lib64/libcudart.so*", - "/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*", - "/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*", - "/usr/lib/aarch64-linux-gnu/libcudart.so*", - "/usr/local/cuda/lib*/libcudart.so*", - "/usr/lib*/libcudart.so*", - "/usr/local/lib*/libcudart.so*", -} - -var NvmlGlobs = []string{} - -var NvcudaGlobs = []string{ - "/usr/local/cuda*/targets/*/lib/libcuda.so*", - "/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*", - "/usr/lib/*-linux-gnu/libcuda.so*", - "/usr/lib/wsl/lib/libcuda.so*", - "/usr/lib/wsl/drivers/*/libcuda.so*", - "/opt/cuda/lib*/libcuda.so*", - "/usr/local/cuda/lib*/libcuda.so*", - "/usr/lib*/libcuda.so*", - "/usr/local/lib*/libcuda.so*", -} - -var OneapiGlobs = []string{ - "/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*", - "/usr/lib*/libze_intel_gpu.so*", -} - -var ( - CudartMgmtName = "libcudart.so*" - NvcudaMgmtName = "libcuda.so*" - NvmlMgmtName = "" // not currently wired on linux - OneapiMgmtName = "libze_intel_gpu.so*" -) - func GetCPUMem() (memInfo, error) { var mem memInfo var total, available, free, buffers, cached, freeSwap uint64 @@ -106,16 +67,17 @@ type linuxCpuInfo struct { CoreID string `cpuinfo:"core id"` } -func GetCPUDetails() ([]CPU, error) { +func GetCPUDetails() []CPU { file, err := os.Open(CpuInfoFilename) if err != nil { - return nil, err + slog.Warn("failed to get CPU details", "error", err) + return nil } defer file.Close() return linuxCPUDetails(file) } -func linuxCPUDetails(file io.Reader) ([]CPU, error) { +func linuxCPUDetails(file io.Reader) []CPU { reColumns := regexp.MustCompile("\t+: ") scanner := bufio.NewScanner(file) cpuInfos := []linuxCpuInfo{} @@ -194,5 +156,17 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) { for _, k := range keys { result = append(result, *socketByID[k]) } - return result, nil + return result +} + +func IsNUMA() bool { + ids := map[string]any{} + packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id") + for _, packageId := range packageIds { + id, err := os.ReadFile(packageId) + if err == nil { + ids[strings.TrimSpace(string(id))] = struct{}{} + } + } + return len(ids) > 1 } diff --git a/discover/gpu_linux_test.go b/discover/cpu_linux_test.go similarity index 99% rename from discover/gpu_linux_test.go rename to discover/cpu_linux_test.go index c4d64e38..3a514478 100644 --- a/discover/gpu_linux_test.go +++ b/discover/cpu_linux_test.go @@ -2062,10 +2062,7 @@ power management: for k, v := range testCases { t.Run(k, func(t *testing.T) { buf := bytes.NewBufferString(v.input) - cpus, err := linuxCPUDetails(buf) - if err != nil { - t.Fatal(err) - } + cpus := linuxCPUDetails(buf) slog.Info("example", "scenario", k, "cpus", cpus) si := SystemInfo{ diff --git a/discover/gpu_windows.go b/discover/cpu_windows.go similarity index 82% rename from discover/gpu_windows.go rename to discover/cpu_windows.go index 2dc2f074..5f516b5d 100644 --- a/discover/gpu_windows.go +++ b/discover/cpu_windows.go @@ -26,29 +26,6 @@ var ( GetLogicalProcessorInformationEx = k32.NewProc("GetLogicalProcessorInformationEx") ) -var CudartGlobs = []string{ - "c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll", -} - -var NvmlGlobs = []string{ - "c:\\Windows\\System32\\nvml.dll", -} - -var NvcudaGlobs = []string{ - "c:\\windows\\system*\\nvcuda.dll", -} - -var OneapiGlobs = []string{ - "c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll", -} - -var ( - CudartMgmtName = "cudart64_*.dll" - NvcudaMgmtName = "nvcuda.dll" - NvmlMgmtName = "nvml.dll" - OneapiMgmtName = "ze_intel_gpu64.dll" -) - func GetCPUMem() (memInfo, error) { memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx} r1, _, err := globalMemoryStatusExProc.Call(uintptr(unsafe.Pointer(&memStatus))) @@ -122,27 +99,22 @@ func (pkg *winPackage) IsMember(target *GROUP_AFFINITY) bool { } func getLogicalProcessorInformationEx() ([]byte, error) { - buf := make([]byte, 1) + buf := make([]byte, 1024) bufSize := len(buf) - ret, _, err := GetLogicalProcessorInformationEx.Call( - uintptr(RelationAll), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&bufSize)), - ) - if ret != 0 { - return nil, fmt.Errorf("failed to determine size info ret:%d %w", ret, err) + var err error + for range 3 { + var ret uintptr + ret, _, err = GetLogicalProcessorInformationEx.Call( + uintptr(RelationAll), + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&bufSize)), + ) + if ret == 1 && bufSize <= len(buf) { + return buf, nil + } + buf = make([]byte, bufSize) } - - buf = make([]byte, bufSize) - ret, _, err = GetLogicalProcessorInformationEx.Call( - uintptr(RelationAll), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&bufSize)), - ) - if ret == 0 { - return nil, fmt.Errorf("failed to gather processor information ret:%d buflen:%d %w", ret, bufSize, err) - } - return buf, nil + return nil, fmt.Errorf("unable to determine CPU details: %w", err) } func processSystemLogicalProcessorInforationList(buf []byte) []*winPackage { @@ -217,10 +189,11 @@ func processSystemLogicalProcessorInforationList(buf []byte) []*winPackage { return packages } -func GetCPUDetails() ([]CPU, error) { +func GetCPUDetails() []CPU { buf, err := getLogicalProcessorInformationEx() if err != nil { - return nil, err + slog.Warn("failed to get CPU details", "error", err) + return nil } packages := processSystemLogicalProcessorInforationList(buf) cpus := make([]CPU, len(packages)) @@ -230,5 +203,10 @@ func GetCPUDetails() ([]CPU, error) { cpus[i].EfficiencyCoreCount = pkg.efficiencyCoreCount cpus[i].ThreadCount = pkg.threadCount } - return cpus, nil + return cpus +} + +func IsNUMA() bool { + // numa support in ggml is linux only + return false } diff --git a/discover/gpu_windows_test.go b/discover/cpu_windows_test.go similarity index 100% rename from discover/gpu_windows_test.go rename to discover/cpu_windows_test.go diff --git a/discover/cuda_common.go b/discover/cuda_common.go deleted file mode 100644 index a2c43420..00000000 --- a/discover/cuda_common.go +++ /dev/null @@ -1,64 +0,0 @@ -//go:build linux || windows - -package discover - -import ( - "fmt" - "log/slog" - "os" - "regexp" - "runtime" - "strconv" - "strings" -) - -// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed. -// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. -var CudaTegra string = os.Getenv("JETSON_JETPACK") - -func cudaVariant(gpuInfos []CudaGPUInfo) string { - if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" { - if CudaTegra != "" { - ver := strings.Split(CudaTegra, ".") - if len(ver) > 0 { - return "jetpack" + ver[0] - } - } else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil { - r := regexp.MustCompile(` R(\d+) `) - m := r.FindSubmatch(data) - if len(m) != 2 { - slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version") - } else { - if l4t, err := strconv.Atoi(string(m[1])); err == nil { - // Note: mapping from L4t -> JP is inconsistent (can't just subtract 30) - // https://developer.nvidia.com/embedded/jetpack-archive - switch l4t { - case 35: - return "jetpack5" - case 36: - return "jetpack6" - default: - slog.Info("unsupported L4T version", "nv_tegra_release", string(data)) - } - } - } - } - } - - // Check GPU compute capability FIRST, lowest common denominator if multi-gpu - for _, gpuInfo := range gpuInfos { - if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) { - // GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1) - return "v12" - } - } - - // GPU is Turing or newer (CC >= 7.5) - can use newer CUDA - if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 { - // The detected driver is older than 580 (Aug 2025) - // Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance - slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor)) - return "v12" - } - return "v13" -} diff --git a/discover/gpu.go b/discover/gpu.go index 4bb0d94e..2f394fdf 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -1,730 +1,207 @@ -//go:build linux || windows - package discover -/* -#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm -#cgo windows LDFLAGS: -lpthread - -#include "gpu_info.h" -*/ -import "C" - import ( - "fmt" + "context" "log/slog" "os" "path/filepath" + "regexp" "runtime" "strconv" "strings" - "sync" - "unsafe" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/ml" ) -type cudaHandles struct { - deviceCount int - cudart *C.cudart_handle_t - nvcuda *C.nvcuda_handle_t - nvml *C.nvml_handle_t +// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed. +// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. +var CudaTegra string = os.Getenv("JETSON_JETPACK") + +func GetCPUInfo() GpuInfo { + mem, err := GetCPUMem() + if err != nil { + slog.Warn("error looking up system memory", "error", err) + } + + return GpuInfo{ + memInfo: mem, + DeviceID: ml.DeviceID{ + Library: "cpu", + ID: "0", + }, + } } -type oneapiHandles struct { - oneapi *C.oneapi_handle_t - deviceCount int +func GetGPUInfo(ctx context.Context, runners []FilteredRunnerDiscovery) GpuInfoList { + devs := GPUDevices(ctx, runners) + return devInfoToInfoList(devs) } -const ( - cudaMinimumMemory = 457 * format.MebiByte - rocmMinimumMemory = 457 * format.MebiByte - // TODO OneAPI minimum memory -) - -var ( - gpuMutex sync.Mutex - bootstrapped bool - cpus []CPUInfo - cudaGPUs []CudaGPUInfo - nvcudaLibPath string - cudartLibPath string - oneapiLibPath string - nvmlLibPath string - rocmGPUs []RocmGPUInfo - oneapiGPUs []OneapiGPUInfo - - // If any discovered GPUs are incompatible, report why - unsupportedGPUs []UnsupportedGPUInfo - - // Keep track of errors during bootstrapping so that if GPUs are missing - // they expected to be present this may explain why - bootstrapErrors []error -) - -// With our current CUDA compile flags, older than 5.0 will not work properly -// (string values used to allow ldflags overrides at build time) -var ( - CudaComputeMajorMin = "5" - CudaComputeMinorMin = "0" -) -//change valute from 9 to 8 would release the gfx version limits ,refer to https://github.com/likelovewant/ollama-for-amd/issues/51 -var RocmComputeMajorMin = "8" - -// TODO find a better way to detect iGPU instead of minimum memory -const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU - -// Note: gpuMutex must already be held -func initCudaHandles() *cudaHandles { - // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing - - cHandles := &cudaHandles{} - // Short Circuit if we already know which library to use - // ignore bootstrap errors in this case since we already recorded them - if nvmlLibPath != "" { - cHandles.nvml, _, _ = loadNVMLMgmt([]string{nvmlLibPath}) - return cHandles - } - if nvcudaLibPath != "" { - cHandles.deviceCount, cHandles.nvcuda, _, _ = loadNVCUDAMgmt([]string{nvcudaLibPath}) - return cHandles - } - if cudartLibPath != "" { - cHandles.deviceCount, cHandles.cudart, _, _ = loadCUDARTMgmt([]string{cudartLibPath}) - return cHandles - } - - slog.Debug("searching for GPU discovery libraries for NVIDIA") - var cudartMgmtPatterns []string - - // Aligned with driver, we can't carry as payloads - nvcudaMgmtPatterns := NvcudaGlobs - cudartMgmtPatterns = append(cudartMgmtPatterns, filepath.Join(LibOllamaPath, "cuda_v*", CudartMgmtName)) - cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...) - - if len(NvmlGlobs) > 0 { - nvmlLibPaths := FindGPULibs(NvmlMgmtName, NvmlGlobs) - if len(nvmlLibPaths) > 0 { - nvml, libPath, err := loadNVMLMgmt(nvmlLibPaths) - if nvml != nil { - slog.Debug("nvidia-ml loaded", "library", libPath) - cHandles.nvml = nvml - nvmlLibPath = libPath - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - } - - nvcudaLibPaths := FindGPULibs(NvcudaMgmtName, nvcudaMgmtPatterns) - if len(nvcudaLibPaths) > 0 { - deviceCount, nvcuda, libPath, err := loadNVCUDAMgmt(nvcudaLibPaths) - if nvcuda != nil { - slog.Debug("detected GPUs", "count", deviceCount, "library", libPath) - cHandles.nvcuda = nvcuda - cHandles.deviceCount = deviceCount - nvcudaLibPath = libPath - return cHandles - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - - cudartLibPaths := FindGPULibs(CudartMgmtName, cudartMgmtPatterns) - if len(cudartLibPaths) > 0 { - deviceCount, cudart, libPath, err := loadCUDARTMgmt(cudartLibPaths) - if cudart != nil { - slog.Debug("detected GPUs", "library", libPath, "count", deviceCount) - cHandles.cudart = cudart - cHandles.deviceCount = deviceCount - cudartLibPath = libPath - return cHandles - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - - return cHandles -} - -// Note: gpuMutex must already be held -func initOneAPIHandles() *oneapiHandles { - oHandles := &oneapiHandles{} - - // Short Circuit if we already know which library to use - // ignore bootstrap errors in this case since we already recorded them - if oneapiLibPath != "" { - oHandles.deviceCount, oHandles.oneapi, _, _ = loadOneapiMgmt([]string{oneapiLibPath}) - return oHandles - } - - oneapiLibPaths := FindGPULibs(OneapiMgmtName, OneapiGlobs) - if len(oneapiLibPaths) > 0 { - var err error - oHandles.deviceCount, oHandles.oneapi, oneapiLibPath, err = loadOneapiMgmt(oneapiLibPaths) - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - } - - return oHandles -} - -func GetCPUInfo() GpuInfoList { - gpuMutex.Lock() - if !bootstrapped { - gpuMutex.Unlock() - GetGPUInfo() - } else { - gpuMutex.Unlock() - } - return GpuInfoList{cpus[0].GpuInfo} -} - -func GetGPUInfo() GpuInfoList { - // TODO - consider exploring lspci (and equivalent on windows) to check for - // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries - gpuMutex.Lock() - defer gpuMutex.Unlock() - needRefresh := true - var cHandles *cudaHandles - var oHandles *oneapiHandles - defer func() { - if cHandles != nil { - if cHandles.cudart != nil { - C.cudart_release(*cHandles.cudart) - } - if cHandles.nvcuda != nil { - C.nvcuda_release(*cHandles.nvcuda) - } - if cHandles.nvml != nil { - C.nvml_release(*cHandles.nvml) - } - } - if oHandles != nil { - if oHandles.oneapi != nil { - // TODO - is this needed? - C.oneapi_release(*oHandles.oneapi) - } - } - }() - - if !bootstrapped { - slog.Info("looking for compatible GPUs") - cudaComputeMajorMin, err := strconv.Atoi(CudaComputeMajorMin) - if err != nil { - slog.Error("invalid CudaComputeMajorMin setting", "value", CudaComputeMajorMin, "error", err) - } - cudaComputeMinorMin, err := strconv.Atoi(CudaComputeMinorMin) - if err != nil { - slog.Error("invalid CudaComputeMinorMin setting", "value", CudaComputeMinorMin, "error", err) - } - bootstrapErrors = []error{} - needRefresh = false - var memInfo C.mem_info_t - - mem, err := GetCPUMem() - if err != nil { - slog.Warn("error looking up system memory", "error", err) - } - - details, err := GetCPUDetails() - if err != nil { - slog.Warn("failed to lookup CPU details", "error", err) - } - cpus = []CPUInfo{ - { - GpuInfo: GpuInfo{ - memInfo: mem, - Library: "cpu", - ID: "0", - }, - CPUs: details, - }, - } - - // Load ALL libraries - cHandles = initCudaHandles() - - // NVIDIA - for i := range cHandles.deviceCount { - if cHandles.cudart != nil || cHandles.nvcuda != nil { - gpuInfo := CudaGPUInfo{ - GpuInfo: GpuInfo{ - Library: "cuda", - }, - index: i, - } - var driverMajor int - var driverMinor int - if cHandles.cudart != nil { - C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo) - driverMajor = int(cHandles.cudart.driver_major) - driverMinor = int(cHandles.cudart.driver_minor) - } else { - C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo) - driverMajor = int(cHandles.nvcuda.driver_major) - driverMinor = int(cHandles.nvcuda.driver_minor) - } - if memInfo.err != nil { - slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - continue - } - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor) - gpuInfo.computeMajor = int(memInfo.major) - gpuInfo.computeMinor = int(memInfo.minor) - gpuInfo.MinimumMemory = cudaMinimumMemory - gpuInfo.DriverMajor = driverMajor - gpuInfo.DriverMinor = driverMinor - - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - - if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) { - unsupportedGPUs = append(unsupportedGPUs, - UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - }) - slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor)) - continue - } - - // query the management library as well so we can record any skew between the two - // which represents overhead on the GPU we must set aside on subsequent updates - if cHandles.nvml != nil { - uuid := C.CString(gpuInfo.ID) - defer C.free(unsafe.Pointer(uuid)) - C.nvml_get_free(*cHandles.nvml, uuid, &memInfo.free, &memInfo.total, &memInfo.used) - if memInfo.err != nil { - slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - } else { - if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory { - gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory - slog.Info("detected OS VRAM overhead", - "id", gpuInfo.ID, - "library", gpuInfo.Library, - "compute", gpuInfo.Compute, - "driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor), - "name", gpuInfo.Name, - "overhead", format.HumanBytes2(gpuInfo.OSOverhead), - ) - } - } - } - - // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... - cudaGPUs = append(cudaGPUs, gpuInfo) - } - // Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths - variant := cudaVariant(cudaGPUs) - var variantPath string - // Start with our bundled libraries - if variant != "" { - variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant) - if _, err := os.Stat(variantPath); err != nil { - variantPath = "" - } - } - - for i := range cudaGPUs { - cudaGPUs[i].Variant = variant - if variantPath != "" { - // Put the variant directory first in the search path to avoid runtime linking to the wrong library - cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...) - } - } - } - - // Intel - if envconfig.IntelGPU() { - oHandles = initOneAPIHandles() - if oHandles != nil && oHandles.oneapi != nil { - for d := range oHandles.oneapi.num_drivers { - if oHandles.oneapi == nil { - // shouldn't happen - slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) - continue - } - devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) - for i := range devCount { - gpuInfo := OneapiGPUInfo{ - GpuInfo: GpuInfo{ - Library: "oneapi", - }, - driverIndex: int(d), - gpuIndex: int(i), - } - // TODO - split bootstrapping from updating free memory - C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) - // TODO - convert this to MinimumMemory based on testing... - var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. - memInfo.free = C.uint64_t(totalFreeMem) - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.DependencyPath = []string{LibOllamaPath} - oneapiGPUs = append(oneapiGPUs, gpuInfo) - } - } - } - } - - rocmGPUs, err = AMDGetGPUInfo() - - // The ID field is used in context of the filtered set of GPUS - // so we have to replace any of these numeric IDs with their - // placement in this set of GPUs - for i := range rocmGPUs { - if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil { - rocmGPUs[i].ID = strconv.Itoa(i) - } - } - if err != nil { - bootstrapErrors = append(bootstrapErrors, err) - } - bootstrapped = true - if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 { - slog.Info("no compatible GPUs were discovered") - } - - // TODO verify we have runners for the discovered GPUs, filter out any that aren't supported with good error messages - } - - // For detected GPUs, load library if not loaded - - // Refresh free memory usage - if needRefresh { - mem, err := GetCPUMem() - if err != nil { - slog.Warn("error looking up system memory", "error", err) - } else { - slog.Debug("updating system memory data", - slog.Group( - "before", - "total", format.HumanBytes2(cpus[0].TotalMemory), - "free", format.HumanBytes2(cpus[0].FreeMemory), - "free_swap", format.HumanBytes2(cpus[0].FreeSwap), - ), - slog.Group( - "now", - "total", format.HumanBytes2(mem.TotalMemory), - "free", format.HumanBytes2(mem.FreeMemory), - "free_swap", format.HumanBytes2(mem.FreeSwap), - ), - ) - cpus[0].FreeMemory = mem.FreeMemory - cpus[0].FreeSwap = mem.FreeSwap - } - - var memInfo C.mem_info_t - if cHandles == nil && len(cudaGPUs) > 0 { - cHandles = initCudaHandles() - } - for i, gpu := range cudaGPUs { - if cHandles.nvml != nil { - uuid := C.CString(gpu.ID) - defer C.free(unsafe.Pointer(uuid)) - C.nvml_get_free(*cHandles.nvml, uuid, &memInfo.free, &memInfo.total, &memInfo.used) - } else if cHandles.cudart != nil { - C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo) - } else if cHandles.nvcuda != nil { - C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free, &memInfo.total) - memInfo.used = memInfo.total - memInfo.free - } else { - // shouldn't happen - slog.Warn("no valid cuda library loaded to refresh vram usage") - break - } - if memInfo.err != nil { - slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) - C.free(unsafe.Pointer(memInfo.err)) - continue - } - if memInfo.free == 0 { - slog.Warn("error looking up nvidia GPU memory") - continue - } - if cHandles.nvml != nil && gpu.OSOverhead > 0 { - // When using the management library update based on recorded overhead - memInfo.free -= C.uint64_t(gpu.OSOverhead) - } - slog.Debug("updating cuda memory data", - "gpu", gpu.ID, - "name", gpu.Name, - "overhead", format.HumanBytes2(gpu.OSOverhead), - slog.Group( - "before", - "total", format.HumanBytes2(gpu.TotalMemory), - "free", format.HumanBytes2(gpu.FreeMemory), - ), - slog.Group( - "now", - "total", format.HumanBytes2(uint64(memInfo.total)), - "free", format.HumanBytes2(uint64(memInfo.free)), - "used", format.HumanBytes2(uint64(memInfo.used)), - ), - ) - cudaGPUs[i].FreeMemory = uint64(memInfo.free) - } - - if oHandles == nil && len(oneapiGPUs) > 0 { - oHandles = initOneAPIHandles() - } - for i, gpu := range oneapiGPUs { - if oHandles.oneapi == nil { - // shouldn't happen - slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount) - continue - } - C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo) - // TODO - convert this to MinimumMemory based on testing... - var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. - memInfo.free = C.uint64_t(totalFreeMem) - oneapiGPUs[i].FreeMemory = uint64(memInfo.free) - } - - err = RocmGPUInfoList(rocmGPUs).RefreshFreeMemory() - if err != nil { - slog.Debug("problem refreshing ROCm free memory", "error", err) - } - } - +func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList { resp := []GpuInfo{} - for _, gpu := range cudaGPUs { - resp = append(resp, gpu.GpuInfo) + // Our current packaging model places ggml-hip in the main directory + // but keeps rocm in an isolated directory. We have to add it to + // the [LD_LIBRARY_]PATH so ggml-hip will load properly + rocmDir := filepath.Join(LibOllamaPath, "rocm") + if _, err := os.Stat(rocmDir); err != nil { + rocmDir = "" } - for _, gpu := range rocmGPUs { - resp = append(resp, gpu.GpuInfo) - } - for _, gpu := range oneapiGPUs { - resp = append(resp, gpu.GpuInfo) + + for _, dev := range devs { + info := GpuInfo{ + DeviceID: dev.DeviceID, + filterID: dev.FilteredID, + Name: dev.Description, + memInfo: memInfo{ + TotalMemory: dev.TotalMemory, + FreeMemory: dev.FreeMemory, + }, + // TODO can we avoid variant + DependencyPath: dev.LibraryPath, + DriverMajor: dev.DriverMajor, + DriverMinor: dev.DriverMinor, + ComputeMajor: dev.ComputeMajor, + ComputeMinor: dev.ComputeMinor, + } + if dev.Library == "CUDA" || dev.Library == "ROCm" { + info.MinimumMemory = 457 * format.MebiByte + } + if dev.Library == "ROCm" && rocmDir != "" { + info.DependencyPath = append(info.DependencyPath, rocmDir) + } + // TODO any special processing of Vulkan devices? + resp = append(resp, info) } if len(resp) == 0 { - resp = append(resp, cpus[0].GpuInfo) + mem, err := GetCPUMem() + if err != nil { + slog.Warn("error looking up system memory", "error", err) + } + + resp = append(resp, GpuInfo{ + memInfo: mem, + DeviceID: ml.DeviceID{ + Library: "cpu", + ID: "0", + }, + }) } return resp } -func FindGPULibs(baseLibName string, defaultPatterns []string) []string { - // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them - gpuLibPaths := []string{} - slog.Debug("Searching for GPU library", "name", baseLibName) - - // search our bundled libraries first - patterns := []string{filepath.Join(LibOllamaPath, baseLibName)} - - var ldPaths []string - switch runtime.GOOS { - case "windows": - ldPaths = strings.Split(os.Getenv("PATH"), string(os.PathListSeparator)) - case "linux": - ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), string(os.PathListSeparator)) - } - - // then search the system's LD_LIBRARY_PATH - for _, p := range ldPaths { - p, err := filepath.Abs(p) - if err != nil { - continue - } - patterns = append(patterns, filepath.Join(p, baseLibName)) - } - - // finally, search the default patterns provided by the caller - patterns = append(patterns, defaultPatterns...) - slog.Debug("gpu library search", "globs", patterns) - for _, pattern := range patterns { - // Nvidia PhysX known to return bogus results - if strings.Contains(pattern, "PhysX") { - slog.Debug("skipping PhysX cuda library path", "path", pattern) - continue - } - // Ignore glob discovery errors - matches, _ := filepath.Glob(pattern) - for _, match := range matches { - // Resolve any links so we don't try the same lib multiple times - // and weed out any dups across globs - libPath := match - tmp := match - var err error - for ; err == nil; tmp, err = os.Readlink(libPath) { - if !filepath.IsAbs(tmp) { - tmp = filepath.Join(filepath.Dir(libPath), tmp) - } - libPath = tmp - } - new := true - for _, cmp := range gpuLibPaths { - if cmp == libPath { - new = false - break - } - } - if new { - gpuLibPaths = append(gpuLibPaths, libPath) - } - } - } - slog.Debug("discovered GPU libraries", "paths", gpuLibPaths) - return gpuLibPaths -} - -// Bootstrap the runtime library -// Returns: num devices, handle, libPath, error -func loadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string, error) { - var resp C.cudart_init_resp_t - resp.ch.verbose = getVerboseState() - var err error - for _, libPath := range cudartLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.cudart_init(lib, &resp) - if resp.err != nil { - err = fmt.Errorf("Unable to load cudart library %s: %s", libPath, C.GoString(resp.err)) - slog.Debug(err.Error()) - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - return int(resp.num_devices), &resp.ch, libPath, err - } - } - return 0, nil, "", err -} - -// Bootstrap the driver library -// Returns: num devices, handle, libPath, error -func loadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string, error) { - var resp C.nvcuda_init_resp_t - resp.ch.verbose = getVerboseState() - var err error - for _, libPath := range nvcudaLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.nvcuda_init(lib, &resp) - if resp.err != nil { - // Decide what log level based on the type of error message to help users understand why - switch resp.cudaErr { - case C.CUDA_ERROR_INSUFFICIENT_DRIVER, C.CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: - err = fmt.Errorf("version mismatch between driver and cuda driver library - reboot or upgrade may be required: library %s", libPath) - slog.Warn(err.Error()) - case C.CUDA_ERROR_NO_DEVICE: - err = fmt.Errorf("no nvidia devices detected by library %s", libPath) - slog.Info(err.Error()) - case C.CUDA_ERROR_UNKNOWN: - err = fmt.Errorf("unknown error initializing cuda driver library %s: %s. see https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for more information", libPath, C.GoString(resp.err)) - slog.Warn(err.Error()) - default: - msg := C.GoString(resp.err) - if strings.Contains(msg, "wrong ELF class") { - slog.Debug("skipping 32bit library", "library", libPath) - } else { - err = fmt.Errorf("Unable to load cudart library %s: %s", libPath, C.GoString(resp.err)) - slog.Info(err.Error()) - } - } - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - return int(resp.num_devices), &resp.ch, libPath, err - } - } - return 0, nil, "", err -} - -// Bootstrap the management library -// Returns: handle, libPath, error -func loadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string, error) { - var resp C.nvml_init_resp_t - resp.ch.verbose = getVerboseState() - var err error - for _, libPath := range nvmlLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.nvml_init(lib, &resp) - if resp.err != nil { - err = fmt.Errorf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)) - slog.Info(err.Error()) - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - return &resp.ch, libPath, err - } - } - return nil, "", err -} - -// bootstrap the Intel GPU library -// Returns: num devices, handle, libPath, error -func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, error) { - var resp C.oneapi_init_resp_t - num_devices := 0 - resp.oh.verbose = getVerboseState() - var err error - for _, libPath := range oneapiLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.oneapi_init(lib, &resp) - if resp.err != nil { - err = fmt.Errorf("Unable to load oneAPI management library %s: %s", libPath, C.GoString(resp.err)) - slog.Debug(err.Error()) - C.free(unsafe.Pointer(resp.err)) - } else { - err = nil - for i := range resp.oh.num_drivers { - num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i))) - } - return num_devices, &resp.oh, libPath, err - } - } - return 0, nil, "", err -} - -func getVerboseState() C.uint16_t { - if envconfig.LogLevel() < slog.LevelInfo { - return C.uint16_t(1) - } - return C.uint16_t(0) -} - // Given the list of GPUs this instantiation is targeted for, // figure out the visible devices environment variable +// +// If different libraries are detected, the first one is what we use func (l GpuInfoList) GetVisibleDevicesEnv() []string { if len(l) == 0 { return nil } - vd := []string{} - // Only filter the AMD GPUs at this level, let all NVIDIA devices through - if tmp := rocmGetVisibleDevicesEnv(l); tmp != "" { - vd = append(vd, tmp) + res := []string{} + envVar := rocmGetVisibleDevicesEnv(l) + if envVar != "" { + res = append(res, envVar) } - return vd + envVar = vkGetVisibleDevicesEnv(l) + if envVar != "" { + res = append(res, envVar) + } + return res } -func GetSystemInfo() SystemInfo { - gpus := GetGPUInfo() - gpuMutex.Lock() - defer gpuMutex.Unlock() - discoveryErrors := []string{} - for _, err := range bootstrapErrors { - discoveryErrors = append(discoveryErrors, err.Error()) +func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "ROCm" { + continue + } + // If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number + if info.filterID != "" { + ids = append(ids, info.filterID) + } else { + ids = append(ids, info.ID) + } } + if len(ids) == 0 { + return "" + } + envVar := "ROCR_VISIBLE_DEVICES=" + if runtime.GOOS != "linux" { + envVar = "HIP_VISIBLE_DEVICES=" + } + // There are 3 potential env vars to use to select GPUs. + // ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows + // HIP_VISIBLE_DEVICES supports numeric IDs only + // GPU_DEVICE_ORDINAL supports numeric IDs only + return envVar + strings.Join(ids, ",") +} + +func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "Vulkan" { + continue + } + if info.filterID != "" { + ids = append(ids, info.filterID) + } else { + ids = append(ids, info.ID) + } + } + if len(ids) == 0 { + return "" + } + envVar := "GGML_VK_VISIBLE_DEVICES=" + return envVar + strings.Join(ids, ",") +} + +// GetSystemInfo returns the last cached state of the GPUs on the system +func GetSystemInfo() SystemInfo { + deviceMu.Lock() + defer deviceMu.Unlock() + gpus := devInfoToInfoList(devices) if len(gpus) == 1 && gpus[0].Library == "cpu" { gpus = []GpuInfo{} } return SystemInfo{ - System: cpus[0], - GPUs: gpus, - UnsupportedGPUs: unsupportedGPUs, - DiscoveryErrors: discoveryErrors, + System: CPUInfo{ + CPUs: GetCPUDetails(), + GpuInfo: GetCPUInfo(), + }, + GPUs: gpus, } } + +func cudaJetpack() string { + if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" { + if CudaTegra != "" { + ver := strings.Split(CudaTegra, ".") + if len(ver) > 0 { + return "jetpack" + ver[0] + } + } else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil { + r := regexp.MustCompile(` R(\d+) `) + m := r.FindSubmatch(data) + if len(m) != 2 { + slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version") + } else { + if l4t, err := strconv.Atoi(string(m[1])); err == nil { + // Note: mapping from L4t -> JP is inconsistent (can't just subtract 30) + // https://developer.nvidia.com/embedded/jetpack-archive + switch l4t { + case 35: + return "jetpack5" + case 36: + return "jetpack6" + default: + // Newer Jetson systems use the SBSU runtime + slog.Debug("unrecognized L4T version", "nv_tegra_release", string(data)) + } + } + } + } + } + return "" +} diff --git a/discover/gpu_darwin.go b/discover/gpu_darwin.go index 29b44ff5..6f55b4c5 100644 --- a/discover/gpu_darwin.go +++ b/discover/gpu_darwin.go @@ -1,5 +1,3 @@ -//go:build darwin - package discover /* @@ -11,7 +9,6 @@ import "C" import ( "log/slog" - "runtime" "syscall" "github.com/ollama/ollama/format" @@ -21,39 +18,6 @@ const ( metalMinimumMemory = 512 * format.MebiByte ) -func GetGPUInfo() GpuInfoList { - mem, _ := GetCPUMem() - if runtime.GOARCH == "amd64" { - return []GpuInfo{ - { - Library: "cpu", - memInfo: mem, - }, - } - } - info := GpuInfo{ - Library: "metal", - ID: "0", - } - info.TotalMemory = uint64(C.getRecommendedMaxVRAM()) - - // TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work) - info.FreeMemory = info.TotalMemory - - info.MinimumMemory = metalMinimumMemory - return []GpuInfo{info} -} - -func GetCPUInfo() GpuInfoList { - mem, _ := GetCPUMem() - return []GpuInfo{ - { - Library: "cpu", - memInfo: mem, - }, - } -} - func GetCPUMem() (memInfo, error) { return memInfo{ TotalMemory: uint64(C.getPhysicalMemory()), @@ -62,13 +26,7 @@ func GetCPUMem() (memInfo, error) { }, nil } -func (l GpuInfoList) GetVisibleDevicesEnv() []string { - // No-op on darwin - return nil -} - -func GetSystemInfo() SystemInfo { - mem, _ := GetCPUMem() +func GetCPUDetails() []CPU { query := "hw.perflevel0.physicalcpu" perfCores, err := syscall.SysctlUint32(query) if err != nil { @@ -81,19 +39,16 @@ func GetSystemInfo() SystemInfo { query = "hw.logicalcpu" logicalCores, _ := syscall.SysctlUint32(query) - return SystemInfo{ - System: CPUInfo{ - GpuInfo: GpuInfo{ - memInfo: mem, - }, - CPUs: []CPU{ - { - CoreCount: int(perfCores + efficiencyCores), - EfficiencyCoreCount: int(efficiencyCores), - ThreadCount: int(logicalCores), - }, - }, + return []CPU{ + { + CoreCount: int(perfCores + efficiencyCores), + EfficiencyCoreCount: int(efficiencyCores), + ThreadCount: int(logicalCores), }, - GPUs: GetGPUInfo(), } } + +func IsNUMA() bool { + // numa support in ggml is linux only + return false +} diff --git a/discover/gpu_info.h b/discover/gpu_info.h deleted file mode 100644 index ee7ff4c3..00000000 --- a/discover/gpu_info.h +++ /dev/null @@ -1,72 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_H__ -#define __GPU_INFO_H__ -#include -#include -#include - -#ifndef _WIN32 -#include -#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags) -#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym) -#define LOAD_ERR() strdup(dlerror()) -#define UNLOAD_LIBRARY(handle) dlclose(handle) -#else -#include -#define LOAD_LIBRARY(lib, flags) LoadLibrary(lib) -#define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym) -#define UNLOAD_LIBRARY(handle) FreeLibrary(handle) -#define LOAD_ERR() ({\ - LPSTR messageBuffer = NULL; \ - size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, \ - NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); \ - char *resp = strdup(messageBuffer); \ - LocalFree(messageBuffer); \ - resp; \ -}) - -#endif - -#ifndef LOG -#define LOG(verbose, ...) \ - do { \ - if (verbose) { \ - fprintf(stderr, __VA_ARGS__); \ - } \ - } while (0) -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -#define GPU_ID_LEN 64 -#define GPU_NAME_LEN 96 - -typedef struct mem_info { - char *err; // If non-nill, caller responsible for freeing - char gpu_id[GPU_ID_LEN]; - char gpu_name[GPU_NAME_LEN]; - uint64_t total; - uint64_t free; - uint64_t used; - - // Compute Capability - int major; - int minor; - int patch; -} mem_info_t; - -void cpu_check_ram(mem_info_t *resp); - -#ifdef __cplusplus -} -#endif - -#include "gpu_info_cudart.h" -#include "gpu_info_nvcuda.h" -#include "gpu_info_nvml.h" -#include "gpu_info_oneapi.h" - -#endif // __GPU_INFO_H__ -#endif // __APPLE__ diff --git a/discover/gpu_info_cudart.c b/discover/gpu_info_cudart.c deleted file mode 100644 index 76c17b9d..00000000 --- a/discover/gpu_info_cudart.c +++ /dev/null @@ -1,181 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? - -#include -#include -#include "gpu_info_cudart.h" - -void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { - cudartReturn_t ret; - resp->err = NULL; - resp->num_devices = 0; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - {"cudaSetDevice", (void *)&resp->ch.cudaSetDevice}, - {"cudaDeviceSynchronize", (void *)&resp->ch.cudaDeviceSynchronize}, - {"cudaDeviceReset", (void *)&resp->ch.cudaDeviceReset}, - {"cudaMemGetInfo", (void *)&resp->ch.cudaMemGetInfo}, - {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount}, - {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute}, - {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion}, - {"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(cudart_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", cudart_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - cudart_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - for (i = 0; l[i].s != NULL; i++) { - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!*(l[i].p)) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - ret = (*resp->ch.cudaSetDevice)(0); - if (ret != CUDART_SUCCESS) { - LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) { - resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama"); - return; - } - snprintf(buf, buflen, "cudart init failure: %d", ret); - resp->err = strdup(buf); - return; - } - - int version = 0; - - // Report driver version if we're in verbose mode, ignore errors - ret = (*resp->ch.cudaDriverGetVersion)(&version); - if (ret != CUDART_SUCCESS) { - LOG(resp->ch.verbose, "cudaDriverGetVersion failed: %d\n", ret); - } else { - resp->ch.driver_major = version / 1000; - resp->ch.driver_minor = (version - (resp->ch.driver_major * 1000)) / 10; - LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", resp->ch.driver_major, resp->ch.driver_minor); - } - - ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices); - if (ret != CUDART_SUCCESS) { - LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - return; - } -} - - -void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) { - resp->err = NULL; - cudartMemory_t memInfo = {0,0,0}; - cudartReturn_t ret; - const int buflen = 256; - char buf[buflen + 1]; - - if (h.handle == NULL) { - resp->err = strdup("cudart handle isn't initialized"); - return; - } - - ret = (*h.cudaSetDevice)(i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device failed to initialize"); - resp->err = strdup(buf); - return; - } - - cudaDeviceProp_t props; - ret = (*h.cudaGetDeviceProperties)(&props, i); - if (ret != CUDART_SUCCESS) { - LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret); - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - resp->major = 0; - resp->minor = 0; - } else { - int allNull = 1; - for (int j = 0; j < 16; j++) { - if (props.uuid.bytes[j] != 0) { - allNull = 0; - break; - } - } - if (allNull != 0) { - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - } else { - // GPU-d110a105-ac29-1d54-7b49-9c90440f215b - snprintf(&resp->gpu_id[0], GPU_ID_LEN, - "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", - props.uuid.bytes[0], - props.uuid.bytes[1], - props.uuid.bytes[2], - props.uuid.bytes[3], - props.uuid.bytes[4], - props.uuid.bytes[5], - props.uuid.bytes[6], - props.uuid.bytes[7], - props.uuid.bytes[8], - props.uuid.bytes[9], - props.uuid.bytes[10], - props.uuid.bytes[11], - props.uuid.bytes[12], - props.uuid.bytes[13], - props.uuid.bytes[14], - props.uuid.bytes[15] - ); - } - resp->major = props.major; - resp->minor = props.minor; - - // TODO add other useful properties from props - } - ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret); - resp->err = strdup(buf); - return; - } - - resp->total = memInfo.total; - resp->free = memInfo.free; - resp->used = memInfo.used; - - LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "\n", resp->gpu_id, resp->total); - LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "\n", resp->gpu_id, resp->free); - LOG(h.verbose, "[%s] CUDA usedMem %" PRId64 "\n", resp->gpu_id, resp->used); - LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); -} - -void cudart_release(cudart_handle_t h) { - LOG(h.verbose, "releasing cudart library\n"); - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -#endif // __APPLE__ diff --git a/discover/gpu_info_cudart.h b/discover/gpu_info_cudart.h deleted file mode 100644 index 893f3f7b..00000000 --- a/discover/gpu_info_cudart.h +++ /dev/null @@ -1,145 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_CUDART_H__ -#define __GPU_INFO_CUDART_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum cudartReturn_enum { - CUDART_SUCCESS = 0, - CUDART_ERROR_INVALID_VALUE = 1, - CUDART_ERROR_MEMORY_ALLOCATION = 2, - CUDART_ERROR_INSUFFICIENT_DRIVER = 35, - // Other values omitted for now... -} cudartReturn_t; - -typedef enum cudartDeviceAttr_enum { - cudartDevAttrComputeCapabilityMajor = 75, - cudartDevAttrComputeCapabilityMinor = 76, - - // TODO - not yet wired up but may be useful for Jetson or other - // integrated GPU scenarios with shared memory - cudaDevAttrIntegrated = 18 - -} cudartDeviceAttr_t; - -typedef void *cudartDevice_t; // Opaque is sufficient -typedef struct cudartMemory_st { - size_t total; - size_t free; - size_t used; -} cudartMemory_t; - -typedef struct cudaUUID { - unsigned char bytes[16]; -} cudaUUID_t; -typedef struct cudaDeviceProp { - char name[256]; /**< ASCII string identifying device */ - cudaUUID_t uuid; /**< 16-byte unique identifier */ - char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */ - unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */ - size_t totalGlobalMem; /**< Global memory available on device in bytes */ - size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */ - int regsPerBlock; /**< 32-bit registers available per block */ - int warpSize; /**< Warp size in threads */ - size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */ - int maxThreadsPerBlock; /**< Maximum number of threads per block */ - int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */ - int maxGridSize[3]; /**< Maximum size of each dimension of a grid */ - int clockRate; /**< Clock frequency in kilohertz */ - size_t totalConstMem; /**< Constant memory available on device in bytes */ - int major; /**< Major compute capability */ - int minor; /**< Minor compute capability */ - size_t textureAlignment; /**< Alignment requirement for textures */ - size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */ - int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */ - int multiProcessorCount; /**< Number of multiprocessors on device */ - int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */ - int integrated; /**< Device is integrated as opposed to discrete */ - int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */ - int computeMode; /**< Compute mode (See ::cudaComputeMode) */ - int maxTexture1D; /**< Maximum 1D texture size */ - int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */ - int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */ - int maxTexture2D[2]; /**< Maximum 2D texture dimensions */ - int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */ - int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */ - int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */ - int maxTexture3D[3]; /**< Maximum 3D texture dimensions */ - int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */ - int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */ - int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */ - int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */ - int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */ - int maxSurface1D; /**< Maximum 1D surface size */ - int maxSurface2D[2]; /**< Maximum 2D surface dimensions */ - int maxSurface3D[3]; /**< Maximum 3D surface dimensions */ - int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */ - int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */ - int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */ - int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */ - size_t surfaceAlignment; /**< Alignment requirements for surfaces */ - int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */ - int ECCEnabled; /**< Device has ECC support enabled */ - int pciBusID; /**< PCI bus ID of the device */ - int pciDeviceID; /**< PCI device ID of the device */ - int pciDomainID; /**< PCI domain ID of the device */ - int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */ - int asyncEngineCount; /**< Number of asynchronous engines */ - int unifiedAddressing; /**< Device shares a unified address space with the host */ - int memoryClockRate; /**< Peak memory clock frequency in kilohertz */ - int memoryBusWidth; /**< Global memory bus width in bits */ - int l2CacheSize; /**< Size of L2 cache in bytes */ - int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */ - int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */ - int streamPrioritiesSupported; /**< Device supports stream priorities */ - int globalL1CacheSupported; /**< Device supports caching globals in L1 */ - int localL1CacheSupported; /**< Device supports caching locals in L1 */ - size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */ - int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */ - int managedMemory; /**< Device supports allocating managed memory on this system */ - int isMultiGpuBoard; /**< Device is on a multi-GPU board */ - int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */ - int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */ - int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */ - int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */ - int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */ - int computePreemptionSupported; /**< Device supports Compute Preemption */ - int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */ - int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */ - int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */ - size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */ - int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */ - int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */ - int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */ - int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */ - size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */ - } cudaDeviceProp_t; - -typedef struct cudart_handle { - void *handle; - uint16_t verbose; - int driver_major; - int driver_minor; - cudartReturn_t (*cudaSetDevice)(int device); - cudartReturn_t (*cudaDeviceSynchronize)(void); - cudartReturn_t (*cudaDeviceReset)(void); - cudartReturn_t (*cudaMemGetInfo)(size_t *, size_t *); - cudartReturn_t (*cudaGetDeviceCount)(int *); - cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device); - cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion); - cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device); -} cudart_handle_t; - -typedef struct cudart_init_resp { - char *err; // If err is non-null handle is invalid - cudart_handle_t ch; - int num_devices; -} cudart_init_resp_t; - -void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp); -void cudart_bootstrap(cudart_handle_t ch, int device_id, mem_info_t *resp); -// TODO - if we keep this library longer term, add cudart_get_free -void cudart_release(cudart_handle_t ch); - -#endif // __GPU_INFO_CUDART_H__ -#endif // __APPLE__ diff --git a/discover/gpu_info_nvcuda.c b/discover/gpu_info_nvcuda.c deleted file mode 100644 index d2d0b683..00000000 --- a/discover/gpu_info_nvcuda.c +++ /dev/null @@ -1,251 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? - -#include -#include -#include "gpu_info_nvcuda.h" - -void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) { - LOG(resp->ch.verbose, "initializing %s\n", nvcuda_lib_path); - CUresult ret; - resp->err = NULL; - resp->num_devices = 0; - resp->cudaErr = CUDA_SUCCESS; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - - {"cuInit", (void *)&resp->ch.cuInit}, - {"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion}, - {"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount}, - {"cuDeviceGet", (void *)&resp->ch.cuDeviceGet}, - {"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute}, - {"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid}, - {"cuDeviceGetName", (void *)&resp->ch.cuDeviceGetName}, - {"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3}, - {"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2}, - {"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - nvcuda_lib_path, msg); - free(msg); - resp->err = strdup(buf); - resp->cudaErr = -1; - return; - } - - for (i = 0; l[i].s != NULL; i++) { - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!*(l[i].p)) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - resp->cudaErr = -1; - return; - } - LOG(resp->ch.verbose, "dlsym: %s - %p\n", l[i].s, *l[i].p); - } - - LOG(resp->ch.verbose, "calling cuInit\n"); - ret = (*resp->ch.cuInit)(0); - if (ret != CUDA_SUCCESS) { - LOG(resp->ch.verbose, "cuInit err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "cuda driver library init failure: %d", ret); - resp->err = strdup(buf); - resp->cudaErr = ret; - return; - } - - int version = 0; - resp->ch.driver_major = 0; - resp->ch.driver_minor = 0; - - // Report driver version if we're in verbose mode, ignore errors - LOG(resp->ch.verbose, "calling cuDriverGetVersion\n"); - ret = (*resp->ch.cuDriverGetVersion)(&version); - if (ret != CUDA_SUCCESS) { - LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret); - } else { - LOG(resp->ch.verbose, "raw version 0x%x\n", version); - resp->ch.driver_major = version / 1000; - resp->ch.driver_minor = (version - (resp->ch.driver_major * 1000)) / 10; - LOG(resp->ch.verbose, "CUDA driver version: %d.%d\n", resp->ch.driver_major, resp->ch.driver_minor); - } - - LOG(resp->ch.verbose, "calling cuDeviceGetCount\n"); - ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices); - if (ret != CUDA_SUCCESS) { - LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - resp->cudaErr = ret; - return; - } - LOG(resp->ch.verbose, "device count %d\n", resp->num_devices); -} - -const int buflen = 256; -void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) { - resp->err = NULL; - nvcudaMemory_t memInfo = {0,0}; - CUresult ret; - CUdevice device = -1; - CUcontext ctx = NULL; - char buf[buflen + 1]; - CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; - - if (h.handle == NULL) { - resp->err = strdup("cuda driver library handle isn't initialized"); - return; - } - - ret = (*h.cuDeviceGet)(&device, i); - if (ret != CUDA_SUCCESS) { - snprintf(buf, buflen, "cuda driver library device failed to initialize"); - resp->err = strdup(buf); - return; - } - - int major = 0; - int minor = 0; - ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret); - } else { - ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret); - } else { - resp->minor = minor; - resp->major = major; - } - } - - ret = (*h.cuDeviceGetUuid)(&uuid, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret); - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); - } else { - // GPU-d110a105-ac29-1d54-7b49-9c90440f215b - snprintf(&resp->gpu_id[0], GPU_ID_LEN, - "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", - uuid.bytes[0], - uuid.bytes[1], - uuid.bytes[2], - uuid.bytes[3], - uuid.bytes[4], - uuid.bytes[5], - uuid.bytes[6], - uuid.bytes[7], - uuid.bytes[8], - uuid.bytes[9], - uuid.bytes[10], - uuid.bytes[11], - uuid.bytes[12], - uuid.bytes[13], - uuid.bytes[14], - uuid.bytes[15] - ); - } - - ret = (*h.cuDeviceGetName)(&resp->gpu_name[0], GPU_NAME_LEN, device); - if (ret != CUDA_SUCCESS) { - LOG(h.verbose, "[%d] device name lookup failure: %d\n", i, ret); - resp->gpu_name[0] = '\0'; - } - - // To get memory we have to set (and release) a context - ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device); - if (ret != CUDA_SUCCESS) { - snprintf(buf, buflen, "cuda driver library failed to get device context %d", ret); - resp->err = strdup(buf); - return; - } - - ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total); - if (ret != CUDA_SUCCESS) { - snprintf(buf, buflen, "cuda driver library device memory info lookup failure %d", ret); - resp->err = strdup(buf); - // Best effort on failure... - (*h.cuCtxDestroy)(ctx); - return; - } - - resp->total = memInfo.total; - resp->free = memInfo.free; - - LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "mb\n", resp->gpu_id, resp->total / 1024 / 1024); - LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "mb\n", resp->gpu_id, resp->free / 1024 / 1024); - LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); - - - - ret = (*h.cuCtxDestroy)(ctx); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library failed to release device context %d", ret); - } -} - -void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total) { - CUresult ret; - CUcontext ctx = NULL; - CUdevice device = -1; - *free = 0; - *total = 0; - - ret = (*h.cuDeviceGet)(&device, i); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library device failed to initialize"); - return; - } - - - // To get memory we have to set (and release) a context - ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library failed to get device context %d", ret); - return; - } - - ret = (*h.cuMemGetInfo_v2)(free, total); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library device memory info lookup failure %d", ret); - // Best effort on failure... - (*h.cuCtxDestroy)(ctx); - return; - } - - ret = (*h.cuCtxDestroy)(ctx); - if (ret != CUDA_SUCCESS) { - LOG(1, "cuda driver library failed to release device context %d", ret); - } -} - -void nvcuda_release(nvcuda_handle_t h) { - LOG(h.verbose, "releasing cuda driver library\n"); - UNLOAD_LIBRARY(h.handle); - // TODO and other context release logic? - h.handle = NULL; -} - -#endif // __APPLE__ diff --git a/discover/gpu_info_nvcuda.h b/discover/gpu_info_nvcuda.h deleted file mode 100644 index ef2fe8a3..00000000 --- a/discover/gpu_info_nvcuda.h +++ /dev/null @@ -1,79 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_NVCUDA_H__ -#define __GPU_INFO_NVCUDA_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum cudaError_enum { - CUDA_SUCCESS = 0, - CUDA_ERROR_INVALID_VALUE = 1, - CUDA_ERROR_OUT_OF_MEMORY = 2, - CUDA_ERROR_NOT_INITIALIZED = 3, - CUDA_ERROR_INSUFFICIENT_DRIVER = 35, - CUDA_ERROR_NO_DEVICE = 100, - CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803, - CUDA_ERROR_UNKNOWN = 999, - // Other values omitted for now... -} CUresult; - -typedef enum CUdevice_attribute_enum { - CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75, - CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76, - - // TODO - not yet wired up but may be useful for Jetson or other - // integrated GPU scenarios with shared memory - CU_DEVICE_ATTRIBUTE_INTEGRATED = 18 - -} CUdevice_attribute; - -typedef void *nvcudaDevice_t; // Opaque is sufficient -typedef struct nvcudaMemory_st { - uint64_t total; - uint64_t free; -} nvcudaMemory_t; - -typedef struct nvcudaDriverVersion { - int major; - int minor; -} nvcudaDriverVersion_t; - -typedef struct CUuuid_st { - unsigned char bytes[16]; -} CUuuid; - -typedef int CUdevice; -typedef void* CUcontext; - -typedef struct nvcuda_handle { - void *handle; - uint16_t verbose; - int driver_major; - int driver_minor; - CUresult (*cuInit)(unsigned int Flags); - CUresult (*cuDriverGetVersion)(int *driverVersion); - CUresult (*cuDeviceGetCount)(int *); - CUresult (*cuDeviceGet)(CUdevice* device, int ordinal); - CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev); - CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2 - CUresult (*cuDeviceGetName)(char *name, int len, CUdevice dev); - - // Context specific aspects - CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev); - CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total); - CUresult (*cuCtxDestroy)(CUcontext ctx); -} nvcuda_handle_t; - -typedef struct nvcuda_init_resp { - char *err; // If err is non-null handle is invalid - nvcuda_handle_t ch; - int num_devices; - CUresult cudaErr; -} nvcuda_init_resp_t; - -void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp); -void nvcuda_bootstrap(nvcuda_handle_t ch, int device_id, mem_info_t *resp); -void nvcuda_get_free(nvcuda_handle_t ch, int device_id, uint64_t *free, uint64_t *total); -void nvcuda_release(nvcuda_handle_t ch); - -#endif // __GPU_INFO_NVCUDA_H__ -#endif // __APPLE__ diff --git a/discover/gpu_info_nvml.c b/discover/gpu_info_nvml.c deleted file mode 100644 index 342a3aa4..00000000 --- a/discover/gpu_info_nvml.c +++ /dev/null @@ -1,104 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? - -#include - -#include "gpu_info_nvml.h" - -void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) { - nvmlReturn_t ret; - resp->err = NULL; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2}, - {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown}, - {"nvmlDeviceGetHandleByUUID", (void *)&resp->ch.nvmlDeviceGetHandleByUUID}, - {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - nvml_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - // TODO once we've squashed the remaining corner cases remove this log - // LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path); - - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - // LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s); - - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!*(l[i].p)) { - resp->ch.handle = NULL; - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.handle); - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - ret = (*resp->ch.nvmlInit_v2)(); - if (ret != NVML_SUCCESS) { - LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "nvml vram init failure: %d", ret); - resp->err = strdup(buf); - return; - } -} - - -void nvml_get_free(nvml_handle_t h, char *uuid, uint64_t *free, uint64_t *total, uint64_t *used) { - nvmlDevice_t device; - nvmlMemory_t memInfo = {0}; - nvmlReturn_t ret; - ret = (*h.nvmlDeviceGetHandleByUUID)((const char *)(uuid), &device); - if (ret != NVML_SUCCESS) { - LOG(1, "unable to get device handle %s: %d", uuid, ret); - *free = 0; - return; - } - - ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo); - if (ret != NVML_SUCCESS) { - LOG(1, "device memory info lookup failure %s: %d", uuid, ret); - *free = 0; - return; - } - *free = memInfo.free; - *total = memInfo.total; - *used = memInfo.used; -} - - -void nvml_release(nvml_handle_t h) { - LOG(h.verbose, "releasing nvml library\n"); - nvmlReturn_t ret; - ret = (*h.nvmlShutdown)(); - if (ret != NVML_SUCCESS) { - LOG(1, "error during nvmlShutdown %d", ret); - } - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -#endif // __APPLE__ \ No newline at end of file diff --git a/discover/gpu_info_nvml.h b/discover/gpu_info_nvml.h deleted file mode 100644 index 90880233..00000000 --- a/discover/gpu_info_nvml.h +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_NVML_H__ -#define __GPU_INFO_NVML_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum nvmlReturn_enum { - NVML_SUCCESS = 0, - // Other values omitted for now... -} nvmlReturn_t; -typedef void *nvmlDevice_t; // Opaque is sufficient -typedef struct nvmlMemory_st { - unsigned long long total; - unsigned long long free; - unsigned long long used; -} nvmlMemory_t; - -typedef enum nvmlBrandType_enum -{ - NVML_BRAND_UNKNOWN = 0, -} nvmlBrandType_t; - -typedef struct nvml_handle { - void *handle; - uint16_t verbose; - nvmlReturn_t (*nvmlInit_v2)(void); - nvmlReturn_t (*nvmlShutdown)(void); - nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); - nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); -} nvml_handle_t; - -typedef struct nvml_init_resp { - char *err; // If err is non-null handle is invalid - nvml_handle_t ch; -} nvml_init_resp_t; - -typedef struct nvml_compute_capability { - char *err; - int major; - int minor; -} nvml_compute_capability_t; - -void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp); -void nvml_get_free(nvml_handle_t ch, char *uuid, uint64_t *free, uint64_t *total, uint64_t *used); -void nvml_release(nvml_handle_t ch); - -#endif // __GPU_INFO_NVML_H__ -#endif // __APPLE__ \ No newline at end of file diff --git a/discover/gpu_info_oneapi.c b/discover/gpu_info_oneapi.c deleted file mode 100644 index 3ff708ea..00000000 --- a/discover/gpu_info_oneapi.c +++ /dev/null @@ -1,259 +0,0 @@ -#ifndef __APPLE__ - -#include "gpu_info_oneapi.h" - -#include - -void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) { - ze_result_t ret; - resp->err = NULL; - resp->oh.devices = NULL; - resp->oh.num_devices = NULL; - resp->oh.drivers = NULL; - resp->oh.num_drivers = 0; - const int buflen = 256; - char buf[buflen + 1]; - int i, d; - struct lookup { - char *s; - void **p; - } l[] = { - {"zesInit", (void *)&resp->oh.zesInit}, - {"zesDriverGet", (void *)&resp->oh.zesDriverGet}, - {"zesDeviceGet", (void *)&resp->oh.zesDeviceGet}, - {"zesDeviceGetProperties", (void *)&resp->oh.zesDeviceGetProperties}, - {"zesDeviceEnumMemoryModules", - (void *)&resp->oh.zesDeviceEnumMemoryModules}, - {"zesMemoryGetProperties", (void *)&resp->oh.zesMemoryGetProperties}, - {"zesMemoryGetState", (void *)&resp->oh.zesMemoryGetState}, - {NULL, NULL}, - }; - - resp->oh.handle = LOAD_LIBRARY(oneapi_lib_path, RTLD_LAZY); - if (!resp->oh.handle) { - char *msg = LOAD_ERR(); - snprintf(buf, buflen, - "Unable to load %s library to query for Intel GPUs: %s\n", - oneapi_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->oh.verbose, - "wiring Level-Zero management library functions in %s\n", - oneapi_lib_path); - - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s); - - *l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s); - if (!*(l[i].p)) { - resp->oh.handle = NULL; - char *msg = LOAD_ERR(); - LOG(resp->oh.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->oh.handle); - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - LOG(resp->oh.verbose, "calling zesInit\n"); - - ret = (*resp->oh.zesInit)(0); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesInit err: %x\n", ret); - snprintf(buf, buflen, "oneapi vram init failure: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - - LOG(resp->oh.verbose, "calling zesDriverGet\n"); - ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get driver count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers); - resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t)); - resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t)); - memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t)); - resp->oh.devices = - malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t *)); - ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get driver count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - - for (d = 0; d < resp->oh.num_drivers; d++) { - LOG(resp->oh.verbose, "calling zesDeviceGet count %d: %p\n", d, resp->oh.drivers[d]); - ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d], - &resp->oh.num_devices[d], NULL); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get device count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - resp->oh.devices[d] = - malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t)); - ret = (*resp->oh.zesDeviceGet)( - resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]); - if (ret != ZE_RESULT_SUCCESS) { - LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret); - snprintf(buf, buflen, "unable to get device count: %x", ret); - resp->err = strdup(buf); - oneapi_release(resp->oh); - return; - } - } - - return; -} - -void oneapi_check_vram(oneapi_handle_t h, int driver, int device, - mem_info_t *resp) { - ze_result_t ret; - resp->err = NULL; - uint64_t totalMem = 0; - uint64_t usedMem = 0; - const int buflen = 256; - char buf[buflen + 1]; - int i, d, m; - - if (h.handle == NULL) { - resp->err = strdup("Level-Zero handle not initialized"); - return; - } - - if (driver > h.num_drivers || device > h.num_devices[driver]) { - resp->err = strdup("driver of device index out of bounds"); - return; - } - - resp->total = 0; - resp->free = 0; - - zes_device_ext_properties_t ext_props; - ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES; - ext_props.pNext = NULL; - - zes_device_properties_t props; - props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES; - props.pNext = &ext_props; - - ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props); - if (ret != ZE_RESULT_SUCCESS) { - snprintf(buf, buflen, "unable to get device properties: %d", ret); - resp->err = strdup(buf); - return; - } - - snprintf(&resp->gpu_name[0], GPU_NAME_LEN, "%s", props.modelName); - - // TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax - // (this is probably wrong...) - // TODO - the driver isn't included - what if there are multiple drivers? - snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device); - - if (h.verbose) { - // When in verbose mode, report more information about - // the card we discover. - LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device, - props.modelName); - LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device, - props.brandName); - LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device, - props.vendorName); - LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device, - props.serialNumber); - LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device, - props.boardNumber); - } - - // TODO - // Compute Capability equivalent in resp->major, resp->minor, resp->patch - - uint32_t memCount = 0; - ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, - NULL); - if (ret != ZE_RESULT_SUCCESS) { - snprintf(buf, buflen, "unable to enumerate Level-Zero memory modules: %x", - ret); - resp->err = strdup(buf); - return; - } - - LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount); - - zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t)); - (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems); - - for (m = 0; m < memCount; m++) { - zes_mem_state_t state; - state.stype = ZES_STRUCTURE_TYPE_MEM_STATE; - state.pNext = NULL; - ret = (*h.zesMemoryGetState)(mems[m], &state); - if (ret != ZE_RESULT_SUCCESS) { - snprintf(buf, buflen, "unable to get memory state: %x", ret); - resp->err = strdup(buf); - free(mems); - return; - } - - resp->total += state.size; - resp->free += state.free; - } - - free(mems); -} - -void oneapi_release(oneapi_handle_t h) { - int d; - LOG(h.verbose, "releasing oneapi library\n"); - for (d = 0; d < h.num_drivers; d++) { - if (h.devices != NULL && h.devices[d] != NULL) { - free(h.devices[d]); - } - } - if (h.devices != NULL) { - free(h.devices); - h.devices = NULL; - } - if (h.num_devices != NULL) { - free(h.num_devices); - h.num_devices = NULL; - } - if (h.drivers != NULL) { - free(h.drivers); - h.drivers = NULL; - } - h.num_drivers = 0; - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -int oneapi_get_device_count(oneapi_handle_t h, int driver) { - if (h.handle == NULL || h.num_devices == NULL) { - return 0; - } - if (driver > h.num_drivers) { - return 0; - } - return (int)h.num_devices[driver]; -} - -#endif // __APPLE__ diff --git a/discover/gpu_info_oneapi.h b/discover/gpu_info_oneapi.h deleted file mode 100644 index 97fcecd9..00000000 --- a/discover/gpu_info_oneapi.h +++ /dev/null @@ -1,203 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_ONEAPI_H__ -#define __GPU_INFO_ONEAPI_H__ -#include "gpu_info.h" - -#define ZE_MAX_DEVICE_NAME 256 -#define ZE_MAX_DEVICE_UUID_SIZE 16 -#define ZES_STRING_PROPERTY_SIZE 64 -#define ZE_BIT(_i) (1 << _i) - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum ze_result_t { - ZE_RESULT_SUCCESS = 0, - // Other values omitted for now... -} ze_result_t; - -typedef uint8_t ze_bool_t; -typedef struct _zes_driver_handle_t *zes_driver_handle_t; -typedef struct _zes_device_handle_t *zes_device_handle_t; -typedef struct _zes_mem_handle_t *zes_mem_handle_t; - -typedef enum _ze_structure_type_t { - ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff -} ze_structure_type_t; - -typedef enum _zes_structure_type_t { - ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1, - ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb, - ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e, - ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES = 0x2d, - ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff -} zes_structure_type_t; - -typedef enum _zes_mem_type_t { - ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff -} zes_mem_type_t; - -typedef enum _zes_mem_loc_t { - ZES_MEM_LOC_SYSTEM = 0, - ZES_MEM_LOC_DEVICE = 1, - ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff -} zes_mem_loc_t; - -typedef enum _zes_mem_health_t { - ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff -} zes_mem_health_t; - -typedef struct _ze_device_uuid_t { - uint8_t id[ZE_MAX_DEVICE_UUID_SIZE]; -} ze_device_uuid_t; - -typedef struct _zes_uuid_t { - uint8_t id[ZE_MAX_DEVICE_UUID_SIZE]; -} zes_uuid_t; - -typedef enum _ze_device_type_t { - ZE_DEVICE_TYPE_GPU = 1, - ZE_DEVICE_TYPE_CPU = 2, - ZE_DEVICE_TYPE_FPGA = 3, - ZE_DEVICE_TYPE_MCA = 4, - ZE_DEVICE_TYPE_VPU = 5, - ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff -} ze_device_type_t; - -typedef enum _zes_device_type_t { - ZES_DEVICE_TYPE_GPU = 1, - ZES_DEVICE_TYPE_CPU = 2, - ZES_DEVICE_TYPE_FPGA = 3, - ZES_DEVICE_TYPE_MCA = 4, - ZES_DEVICE_TYPE_VPU = 5, - ZES_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff -} zes_device_type_t; - -typedef uint32_t ze_device_property_flags_t; -typedef enum _ze_device_property_flag_t { - ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0), - ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1), - ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2), - ZE_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3), - ZE_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff -} ze_device_property_flag_t; - -typedef uint32_t zes_device_property_flags_t; -typedef enum _zes_device_property_flag_t { - ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0), - ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1), - ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2), - ZES_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3), - ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff -} zes_device_property_flag_t; - -typedef struct _ze_device_properties_t { - ze_structure_type_t stype; - void *pNext; - ze_device_type_t type; - uint32_t vendorId; - uint32_t deviceId; - ze_device_property_flags_t flags; - uint32_t subdeviceId; - uint32_t coreClockRate; - uint64_t maxMemAllocSize; - uint32_t maxHardwareContexts; - uint32_t maxCommandQueuePriority; - uint32_t numThreadsPerEU; - uint32_t physicalEUSimdWidth; - uint32_t numEUsPerSubslice; - uint32_t numSubslicesPerSlice; - uint32_t numSlices; - uint64_t timerResolution; - uint32_t timestampValidBits; - uint32_t kernelTimestampValidBits; - ze_device_uuid_t uuid; - char name[ZE_MAX_DEVICE_NAME]; -} ze_device_properties_t; - -typedef struct _zes_device_properties_t { - zes_structure_type_t stype; - void *pNext; - ze_device_properties_t core; - uint32_t numSubdevices; - char serialNumber[ZES_STRING_PROPERTY_SIZE]; - char boardNumber[ZES_STRING_PROPERTY_SIZE]; - char brandName[ZES_STRING_PROPERTY_SIZE]; - char modelName[ZES_STRING_PROPERTY_SIZE]; - char vendorName[ZES_STRING_PROPERTY_SIZE]; - char driverVersion[ZES_STRING_PROPERTY_SIZE]; -} zes_device_properties_t; - -typedef struct _zes_device_ext_properties_t { - zes_structure_type_t stype; - void *pNext; - zes_uuid_t uuid; - zes_device_type_t type; - zes_device_property_flags_t flags; -} zes_device_ext_properties_t; - -typedef struct _zes_mem_properties_t { - zes_structure_type_t stype; - void *pNext; - zes_mem_type_t type; - ze_bool_t onSubdevice; - uint32_t subdeviceId; - zes_mem_loc_t location; - uint64_t physicalSize; - int32_t busWidth; - int32_t numChannels; -} zes_mem_properties_t; - -typedef struct _zes_mem_state_t { - zes_structure_type_t stype; - const void *pNext; - zes_mem_health_t health; - uint64_t free; - uint64_t size; -} zes_mem_state_t; - -typedef struct oneapi_handle { - void *handle; - uint16_t verbose; - - uint32_t num_drivers; - zes_driver_handle_t *drivers; - uint32_t *num_devices; - zes_device_handle_t **devices; - - // TODO Driver major, minor information - // int driver_major; - // int driver_minor; - - ze_result_t (*zesInit)(int); - ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers); - ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount, - zes_device_handle_t *phDevices); - ze_result_t (*zesDeviceGetProperties)(zes_device_handle_t hDevice, - zes_device_properties_t *pProperties); - ze_result_t (*zesDeviceEnumMemoryModules)(zes_device_handle_t hDevice, - uint32_t *pCount, - zes_mem_handle_t *phMemory); - ze_result_t (*zesMemoryGetProperties)(zes_mem_handle_t hMemory, - zes_mem_properties_t *pProperties); - ze_result_t (*zesMemoryGetState)(zes_mem_handle_t hMemory, - zes_mem_state_t *pState); - -} oneapi_handle_t; - -typedef struct oneapi_init_resp { - char *err; // If err is non-null handle is invalid - oneapi_handle_t oh; -} oneapi_init_resp_t; - -typedef struct oneapi_version_resp { - ze_result_t status; - char *str; // Contains version or error string if status != 0 -} oneapi_version_resp_t; - -void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp); -void oneapi_check_vram(oneapi_handle_t h, int driver, int device, - mem_info_t *resp); -void oneapi_release(oneapi_handle_t h); -int oneapi_get_device_count(oneapi_handle_t h, int driver); - -#endif // __GPU_INFO_INTEL_H__ -#endif // __APPLE__ diff --git a/discover/gpu_test.go b/discover/gpu_test.go deleted file mode 100644 index 0c6ef7ba..00000000 --- a/discover/gpu_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package discover - -import ( - "runtime" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestBasicGetGPUInfo(t *testing.T) { - info := GetGPUInfo() - assert.NotEmpty(t, len(info)) - assert.Contains(t, "cuda rocm cpu metal", info[0].Library) - if info[0].Library != "cpu" { - assert.Greater(t, info[0].TotalMemory, uint64(0)) - assert.Greater(t, info[0].FreeMemory, uint64(0)) - } -} - -func TestCPUMemInfo(t *testing.T) { - info, err := GetCPUMem() - require.NoError(t, err) - switch runtime.GOOS { - case "darwin": - t.Skip("CPU memory not populated on darwin") - case "linux", "windows": - assert.Greater(t, info.TotalMemory, uint64(0)) - assert.Greater(t, info.FreeMemory, uint64(0)) - default: - return - } -} - -func TestByLibrary(t *testing.T) { - type testCase struct { - input []GpuInfo - expect int - } - - testCases := map[string]*testCase{ - "empty": {input: []GpuInfo{}, expect: 0}, - "cpu": {input: []GpuInfo{{Library: "cpu"}}, expect: 1}, - "cpu + GPU": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}}, expect: 2}, - "cpu + 2 GPU no variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}, {Library: "cuda"}}, expect: 2}, - "cpu + 2 GPU same variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v11"}}, expect: 2}, - "cpu + 2 GPU diff variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v12"}}, expect: 3}, - } - - for k, v := range testCases { - t.Run(k, func(t *testing.T) { - resp := (GpuInfoList)(v.input).ByLibrary() - if len(resp) != v.expect { - t.Fatalf("expected length %d, got %d => %+v", v.expect, len(resp), resp) - } - }) - } -} - -// TODO - add some logic to figure out card type through other means and actually verify we got back what we expected diff --git a/discover/runner.go b/discover/runner.go new file mode 100644 index 00000000..9da24675 --- /dev/null +++ b/discover/runner.go @@ -0,0 +1,600 @@ +package discover + +// Runner based GPU discovery + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "math/rand" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/ml" +) + +var ( + deviceMu sync.Mutex + devices []ml.DeviceInfo + libDirs map[string]struct{} + rocmDir string + exe string + bootstrapped bool +) + +func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.DeviceInfo { + deviceMu.Lock() + defer deviceMu.Unlock() + startDiscovery := time.Now() + msg := "overall device VRAM discovery took" + defer func() { + slog.Debug(msg, "duration", time.Since(startDiscovery)) + }() + + if !bootstrapped { + msg = "GPU bootstrap discovery took" + libDirs = make(map[string]struct{}) + var err error + exe, err = os.Executable() + if err != nil { + slog.Error("unable to lookup executable path", "error", err) + return nil + } + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval + } + files, err := filepath.Glob(filepath.Join(LibOllamaPath, "*", "*ggml-*")) + if err != nil { + slog.Debug("unable to lookup runner library directories", "error", err) + } + for _, file := range files { + libDirs[filepath.Dir(file)] = struct{}{} + } + + // Our current packaging model places ggml-hip in the main directory + // but keeps rocm in an isolated directory. We have to add it to + // the [LD_LIBRARY_]PATH so ggml-hip will load properly + rocmDir = filepath.Join(LibOllamaPath, "rocm") + if _, err := os.Stat(rocmDir); err != nil { + rocmDir = "" + } + + if len(libDirs) == 0 { + libDirs[""] = struct{}{} + } + + slog.Info("discovering available GPUs...") + requested := envconfig.LLMLibrary() + jetpack := cudaJetpack() + + // For our initial discovery pass, we gather all the known GPUs through + // all the libraries that were detected. This pass may include GPUs that + // are enumerated, but not actually supported. + // We run this in serial to avoid potentially initializing a GPU multiple + // times concurrently leading to memory contention + // TODO refactor so we group the lib dirs and do serial per version, but parallel for different libs + for dir := range libDirs { + var dirs []string + if dir != "" { + if requested != "" && filepath.Base(dir) != requested { + slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir) + continue + } else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack { + continue + } + } + if dir == "" { + dirs = []string{LibOllamaPath} + } else { + dirs = []string{LibOllamaPath, dir} + } + // Typically bootstrapping takes < 1s, but on some systems, with devices + // in low power/idle mode, initialization can take multiple seconds. We + // set a long timeout just for bootstrap discovery to reduce the chance + // of giving up too quickly + ctx1stPass, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // For this pass, we retain duplicates in case any are incompatible with some libraries + devices = append(devices, bootstrapDevices(ctx1stPass, dirs, nil)...) + } + + // In the second pass, we more deeply initialize the GPUs to weed out devices that + // aren't supported by a given library. We run this phase in parallel to speed up discovery. + slog.Debug("filtering out unsupported or overlapping GPU library combinations", "count", len(devices)) + ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + var wg sync.WaitGroup + needsDelete := make([]bool, len(devices)) + supportedMu := sync.Mutex{} + supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index + for i := range devices { + libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1] + if devices[i].Library == "Metal" { + continue + } + slog.Debug("verifying GPU is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "pci_id", devices[i].PCIID) + wg.Add(1) + go func(i int) { + defer wg.Done() + var envVar string + id := devices[i].ID + if devices[i].Library == "ROCm" { + if runtime.GOOS != "linux" { + envVar = "HIP_VISIBLE_DEVICES" + } else { + envVar = "ROCR_VISIBLE_DEVICES" + } + } else if devices[i].Library == "CUDA" { + envVar = "CUDA_VISIBLE_DEVICES" + } else if devices[i].Library == "Vulkan" { + id = devices[i].FilteredID + envVar = "GGML_VK_VISIBLE_DEVICES" + } else { + slog.Error("Unknown Library:" + devices[i].Library) + } + + extraEnvs := []string{ + "GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs + envVar + "=" + id, // Filter to just this one GPU + } + if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 { + needsDelete[i] = true + } else { + supportedMu.Lock() + if _, ok := supported[devices[i].Library]; !ok { + supported[devices[i].Library] = make(map[string]map[string]int) + } + if _, ok := supported[devices[i].Library][libDir]; !ok { + supported[devices[i].Library][libDir] = make(map[string]int) + } + supported[devices[i].Library][libDir][devices[i].ID] = i + supportedMu.Unlock() + } + }(i) + } + wg.Wait() + logutil.Trace("supported GPU library combinations", "supported", supported) + + filterOutVulkanThatAreSupportedByOtherGPU(needsDelete) + + // Mark for deletion any overlaps - favoring the library version that can cover all GPUs if possible + filterOverlapByLibrary(supported, needsDelete) + + // TODO if we ever support multiple ROCm library versions this algorithm will need to be adjusted to keep the rocmID numeric value correct + rocmID := 0 + for i := 0; i < len(needsDelete); i++ { + if needsDelete[i] { + logutil.Trace("removing unsupported or overlapping GPU combination", "libDir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1], "description", devices[i].Description, "compute", devices[i].Compute(), "pci_id", devices[i].PCIID) + devices = append(devices[:i], devices[i+1:]...) + needsDelete = append(needsDelete[:i], needsDelete[i+1:]...) + i-- + } else if devices[i].Library == "ROCm" { + if _, err := strconv.Atoi(devices[i].ID); err == nil { + // Replace the numeric ID with the post-filtered IDs + devices[i].FilteredID = devices[i].ID + devices[i].ID = strconv.Itoa(rocmID) + } + rocmID++ + } + } + + // Now filter out any overlap with different libraries (favor CUDA/HIP over others) + for i := 0; i < len(devices); i++ { + for j := i + 1; j < len(devices); j++ { + // For this pass, we only drop exact duplicates + switch devices[i].Compare(devices[j]) { + case ml.SameBackendDevice: + // Same library and device, skip it + devices = append(devices[:j], devices[j+1:]...) + j-- + continue + case ml.DuplicateDevice: + // Different library, choose based on priority + var droppedDevice ml.DeviceInfo + if devices[i].Library == "CUDA" || devices[i].Library == "ROCm" { + droppedDevice = devices[j] + } else { + droppedDevice = devices[i] + devices[i] = devices[j] + } + devices = append(devices[:j], devices[j+1:]...) + j-- + + typeStr := "discrete" + if droppedDevice.Integrated { + typeStr = "iGPU" + } + slog.Debug("dropping duplicate device", + "id", droppedDevice.ID, + "library", droppedDevice.Library, + "compute", droppedDevice.Compute(), + "name", droppedDevice.Name, + "description", droppedDevice.Description, + "libdirs", strings.Join(droppedDevice.LibraryPath, ","), + "driver", droppedDevice.Driver(), + "pci_id", droppedDevice.PCIID, + "type", typeStr, + "total", format.HumanBytes2(droppedDevice.TotalMemory), + "available", format.HumanBytes2(droppedDevice.FreeMemory), + ) + continue + } + } + } + + // Reset the libDirs to what we actually wind up using for future refreshes + libDirs = make(map[string]struct{}) + for _, dev := range devices { + dir := dev.LibraryPath[len(dev.LibraryPath)-1] + if dir != LibOllamaPath { + libDirs[dir] = struct{}{} + } + } + if len(libDirs) == 0 { + libDirs[""] = struct{}{} + } + + bootstrapped = true + } else { + if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { + // metal never updates free VRAM + return devices + } + + slog.Debug("refreshing free memory") + updated := make([]bool, len(devices)) + allDone := func() bool { + allDone := true + for _, done := range updated { + if !done { + allDone = false + break + } + } + return allDone + } + + // First try to use existing runners to refresh VRAM since they're already + // active on GPU(s) + for _, runner := range runners { + if runner == nil { + continue + } + deviceIDs := runner.GetActiveDeviceIDs() + if len(deviceIDs) == 0 { + // Skip this runner since it doesn't have active GPU devices + continue + } + + // Check to see if this runner is active on any devices that need a refresh + skip := true + devCheck: + for _, dev := range deviceIDs { + for i := range devices { + if dev == devices[i].DeviceID { + if !updated[i] { + skip = false + break devCheck + } + } + } + } + if skip { + continue + } + + // Typical refresh on existing runner is ~500ms but allow longer if the system + // is under stress before giving up and using stale data. + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + start := time.Now() + updatedDevices := runner.GetDeviceInfos(ctx) + slog.Debug("existing runner discovery took", "duration", time.Since(start)) + for _, u := range updatedDevices { + for i := range devices { + if u.DeviceID == devices[i].DeviceID { + updated[i] = true + devices[i].FreeMemory = u.FreeMemory + break + } + } + } + // Short circuit if we've updated all the devices + if allDone() { + break + } + } + if !allDone() { + slog.Debug("unable to refresh all GPUs with existing runners, performing bootstrap discovery") + + // Bootstrapping may take longer in some cases (AMD windows), but we + // would rather use stale free data to get the model running sooner + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + for dir := range libDirs { + updatedDevices := bootstrapDevices(ctx, []string{LibOllamaPath, dir}, nil) + for _, u := range updatedDevices { + for i := range devices { + if u.DeviceID == devices[i].DeviceID { + updated[i] = true + devices[i].FreeMemory = u.FreeMemory + break + } + } + // TODO - consider evaluating if new devices have appeared (e.g. hotplug) + } + if allDone() { + break + } + } + if !allDone() { + slog.Warn("unable to refresh free memory, using old values") + } + } + } + + return devices +} + +func filterOutVulkanThatAreSupportedByOtherGPU(needsDelete []bool) { + // Filter out Vulkan devices that share a PCI ID with a non-Vulkan device that is not marked for deletion + for i := range devices { + if devices[i].Library != "Vulkan" || needsDelete[i] { + continue + } + if devices[i].PCIID == "" { + continue + } + for j := range devices { + if i == j { + continue + } + if devices[j].PCIID == "" { + continue + } + if devices[j].PCIID == devices[i].PCIID && devices[j].Library != "Vulkan" && !needsDelete[j] { + needsDelete[i] = true + slog.Debug("dropping Vulkan duplicate by PCI ID", + "vulkan_id", devices[i].ID, + "vulkan_libdir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1], + "pci_id", devices[i].PCIID, + "kept_library", devices[j].Library, + "kept_id", devices[j].ID, + ) + break + } + } + } +} + +func filterOverlapByLibrary(supported map[string]map[string]map[string]int, needsDelete []bool) { + // For multi-GPU systems, use the newest version that supports all the GPUs + for _, byLibDirs := range supported { + libDirs := make([]string, 0, len(byLibDirs)) + for libDir := range byLibDirs { + libDirs = append(libDirs, libDir) + } + sort.Sort(sort.Reverse(sort.StringSlice(libDirs))) + anyMissing := false + var newest string + for _, newest = range libDirs { + for _, libDir := range libDirs { + if libDir == newest { + continue + } + if len(byLibDirs[newest]) != len(byLibDirs[libDir]) { + anyMissing = true + break + } + for dev := range byLibDirs[newest] { + if _, found := byLibDirs[libDir][dev]; !found { + anyMissing = true + break + } + } + } + if !anyMissing { + break + } + } + // Now we can mark overlaps for deletion + for _, libDir := range libDirs { + if libDir == newest { + continue + } + for dev, i := range byLibDirs[libDir] { + if _, found := byLibDirs[newest][dev]; found { + needsDelete[i] = true + } + } + } + } +} + +type bootstrapRunner struct { + port int + cmd *exec.Cmd +} + +func (r *bootstrapRunner) GetPort() int { + return r.port +} + +func (r *bootstrapRunner) HasExited() bool { + if r.cmd != nil && r.cmd.ProcessState != nil { + return true + } + return false +} + +func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []string) []ml.DeviceInfo { + // TODO DRY out with llm/server.go + slog.Debug("spawning runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs) + start := time.Now() + defer func() { + slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs) + }() + port := 0 + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + port = l.Addr().(*net.TCPAddr).Port + l.Close() + } + } + if port == 0 { + slog.Debug("ResolveTCPAddr failed, using random port") + port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range + } + params := []string{"runner", "--ollama-engine", "--port", strconv.Itoa(port)} + var pathEnv string + switch runtime.GOOS { + case "windows": + pathEnv = "PATH" + case "darwin": + pathEnv = "DYLD_LIBRARY_PATH" + default: + pathEnv = "LD_LIBRARY_PATH" + } + libraryPaths := append([]string{LibOllamaPath}, ollamaLibDirs...) + if rocmDir != "" { + libraryPaths = append(libraryPaths, rocmDir) + } + // Note: we always put our dependency paths first + // since these are the exact version we compiled/linked against + if libraryPath, ok := os.LookupEnv(pathEnv); ok { + libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...) + } + + cmd := exec.Command(exe, params...) + cmd.Env = os.Environ() + if envconfig.LogLevel() == logutil.LevelTrace { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + } + + // cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored + pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) + pathNeeded := true + ollamaPathNeeded := true + extraDone := make([]bool, len(extraEnvs)) + for i := range cmd.Env { + cmp := strings.SplitN(cmd.Env[i], "=", 2) + if strings.EqualFold(cmp[0], pathEnv) { + cmd.Env[i] = pathEnv + "=" + pathEnvVal + pathNeeded = false + } else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") { + cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ollamaLibDirs, string(filepath.ListSeparator)) + ollamaPathNeeded = false + } else { + for j := range extraEnvs { + if extraDone[j] { + continue + } + extra := strings.SplitN(extraEnvs[j], "=", 2) + if cmp[0] == extra[0] { + cmd.Env[i] = extraEnvs[j] + extraDone[j] = true + } + } + } + } + if pathNeeded { + cmd.Env = append(cmd.Env, pathEnv+"="+pathEnvVal) + } + if ollamaPathNeeded { + cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator))) + } + for i := range extraDone { + if !extraDone[i] { + cmd.Env = append(cmd.Env, extraEnvs[i]) + } + } + logutil.Trace("starting runner for device discovery", "env", cmd.Env, "cmd", cmd) + if err := cmd.Start(); err != nil { + slog.Warn("unable to start discovery subprocess", "cmd", cmd, "error", err) + return nil + } + go func() { + cmd.Wait() // exit status ignored + }() + + defer cmd.Process.Kill() + devices, err := GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd}) + if err != nil { + if cmd.ProcessState != nil && cmd.ProcessState.ExitCode() >= 0 { + // Expected during bootstrapping while we filter out unsupported AMD GPUs + logutil.Trace("runner exited", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "code", cmd.ProcessState.ExitCode()) + } else { + slog.Info("failure during GPU discovery", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs, "error", err) + } + } + logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices) + + return devices +} + +func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]ml.DeviceInfo, error) { + var moreDevices []ml.DeviceInfo + port := runner.GetPort() + tick := time.Tick(10 * time.Millisecond) + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("failed to finish discovery before timeout") + case <-tick: + r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + r.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(r) + if err != nil { + // slog.Warn("failed to send request", "error", err) + if runner.HasExited() { + return nil, fmt.Errorf("runner crashed") + } + continue + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + // old runner, fall back to bootstrapping model + return nil, fmt.Errorf("llamarunner free vram reporting not supported") + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + slog.Warn("failed to read response", "error", err) + continue + } + if resp.StatusCode != 200 { + logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body) + return nil, fmt.Errorf("runner error: %s", string(body)) + } + + if err := json.Unmarshal(body, &moreDevices); err != nil { + slog.Warn("unmarshal encode response", "error", err) + continue + } + return moreDevices, nil + } + } +} diff --git a/discover/runner_test.go b/discover/runner_test.go new file mode 100644 index 00000000..9ea19046 --- /dev/null +++ b/discover/runner_test.go @@ -0,0 +1,108 @@ +package discover + +import ( + "testing" + + "github.com/ollama/ollama/app/lifecycle" +) + +func init() { + lifecycle.InitLogging() +} + +func TestFilterOverlapByLibrary(t *testing.T) { + type testcase struct { + name string + inp map[string]map[string]map[string]int + exp []bool + } + for _, tc := range []testcase{ + { + name: "empty", + inp: map[string]map[string]map[string]int{}, + exp: []bool{}, // needs deletion + }, + { + name: "single no overlap", + inp: map[string]map[string]map[string]int{ + "CUDA": { + "cuda_v12": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0, + }, + }, + }, + exp: []bool{false}, + }, + { + name: "100% overlap pick 2nd", + inp: map[string]map[string]map[string]int{ + "CUDA": { + "cuda_v12": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0, + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 1, + }, + "cuda_v13": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 2, + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 3, + }, + }, + }, + exp: []bool{true, true, false, false}, + }, + { + name: "100% overlap pick 1st", + inp: map[string]map[string]map[string]int{ + "CUDA": { + "cuda_v13": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0, + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 1, + }, + "cuda_v12": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 2, + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 3, + }, + }, + }, + exp: []bool{false, false, true, true}, + }, + { + name: "partial overlap pick older", + inp: map[string]map[string]map[string]int{ + "CUDA": { + "cuda_v13": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0, + }, + "cuda_v12": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 1, + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 2, + }, + }, + }, + exp: []bool{true, false, false}, + }, + { + name: "no overlap", + inp: map[string]map[string]map[string]int{ + "CUDA": { + "cuda_v13": { + "GPU-d7b00605-c0c8-152d-529d-e03726d5dc52": 0, + }, + "cuda_v12": { + "GPU-cd6c3216-03d2-a8eb-8235-2ffbf571712e": 1, + }, + }, + }, + exp: []bool{false, false}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + needsDelete := make([]bool, len(tc.exp)) + filterOverlapByLibrary(tc.inp, needsDelete) + for i, exp := range tc.exp { + if needsDelete[i] != exp { + t.Fatalf("expected: %v\ngot: %v", tc.exp, needsDelete) + } + } + }) + } +} diff --git a/discover/types.go b/discover/types.go index 1027aaac..adb2f43a 100644 --- a/discover/types.go +++ b/discover/types.go @@ -1,10 +1,14 @@ package discover import ( - "fmt" + "context" "log/slog" + "path/filepath" + "runtime" + "strings" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/ml" ) type memInfo struct { @@ -15,8 +19,8 @@ type memInfo struct { // Beginning of an `ollama info` command type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? + ml.DeviceID memInfo - Library string `json:"library,omitempty"` // Optional variant to select (e.g. versions, cpu feature flags) Variant string `json:"variant"` @@ -27,19 +31,16 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? // Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly DependencyPath []string `json:"lib_path,omitempty"` - // Extra environment variables specific to the GPU as list of [key=value] - EnvWorkarounds []string `json:"envs,omitempty"` - // Set to true if we can NOT reliably discover FreeMemory. A value of true indicates // the FreeMemory is best effort, and may over or under report actual memory usage // False indicates FreeMemory can generally be trusted on this GPU UnreliableFreeMemory bool // GPU information - ID string `json:"gpu_id"` // string to use for selection of this specific GPU - filterID int //nolint:unused,nolintlint // AMD Workaround: The numeric ID of the device used to filter out other devices - Name string `json:"name"` // user friendly name if available - Compute string `json:"compute"` // Compute Capability or gfx + filterID string // AMD/Vulkan Workaround: The numeric ID of the device used to filter out other devices + Name string `json:"name"` // user friendly name if available + ComputeMajor int `json:"compute_major"` // Compute Capability or gfx + ComputeMinor int `json:"compute_minor"` // Driver Information - TODO no need to put this on each GPU DriverMajor int `json:"driver_major,omitempty"` @@ -70,37 +71,8 @@ type CPU struct { ThreadCount int } -type CudaGPUInfo struct { - GpuInfo - OSOverhead uint64 // Memory overhead between the driver library and management library - index int //nolint:unused,nolintlint - computeMajor int //nolint:unused,nolintlint - computeMinor int //nolint:unused,nolintlint -} -type CudaGPUInfoList []CudaGPUInfo - -type RocmGPUInfo struct { - GpuInfo - usedFilepath string //nolint:unused,nolintlint - index int //nolint:unused,nolintlint -} -type RocmGPUInfoList []RocmGPUInfo - -type OneapiGPUInfo struct { - GpuInfo - driverIndex int //nolint:unused,nolintlint - gpuIndex int //nolint:unused,nolintlint -} -type OneapiGPUInfoList []OneapiGPUInfo - type GpuInfoList []GpuInfo -type UnsupportedGPUInfo struct { - GpuInfo - Reason string `json:"reason"` -} - -// Split up the set of gpu info's by Library and variant func (l GpuInfoList) ByLibrary() []GpuInfoList { resp := []GpuInfoList{} libs := []string{} @@ -125,18 +97,47 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList { return resp } -// Report the GPU information into the log an Info level -func (l GpuInfoList) LogDetails() { - for _, g := range l { +func LogDetails(devices []ml.DeviceInfo) { + for _, dev := range devices { + var libs []string + for _, dir := range dev.LibraryPath { + if strings.Contains(dir, filepath.Join("lib", "ollama")) { + libs = append(libs, filepath.Base(dir)) + } + } + typeStr := "discrete" + if dev.Integrated { + typeStr = "iGPU" + } slog.Info("inference compute", - "id", g.ID, - "library", g.Library, - "variant", g.Variant, - "compute", g.Compute, - "driver", fmt.Sprintf("%d.%d", g.DriverMajor, g.DriverMinor), - "name", g.Name, - "total", format.HumanBytes2(g.TotalMemory), - "available", format.HumanBytes2(g.FreeMemory), + "id", dev.ID, + "library", dev.Library, + "compute", dev.Compute(), + "name", dev.Name, + "description", dev.Description, + "libdirs", strings.Join(libs, ","), + "driver", dev.Driver(), + "pci_id", dev.PCIID, + "type", typeStr, + "total", format.HumanBytes2(dev.TotalMemory), + "available", format.HumanBytes2(dev.FreeMemory), + ) + } + // CPU inference + if len(devices) == 0 { + dev, _ := GetCPUMem() + slog.Info("inference compute", + "id", "cpu", + "library", "cpu", + "compute", "", + "name", "cpu", + "description", "cpu", + "libdirs", "ollama", + "driver", "", + "pci_id", "", + "type", "", + "total", format.HumanBytes2(dev.TotalMemory), + "available", format.HumanBytes2(dev.FreeMemory), ) } } @@ -149,16 +150,15 @@ func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory } type SystemInfo struct { - System CPUInfo `json:"system"` - GPUs []GpuInfo `json:"gpus"` - UnsupportedGPUs []UnsupportedGPUInfo `json:"unsupported_gpus"` - DiscoveryErrors []string `json:"discovery_errors"` + System CPUInfo `json:"system"` + GPUs []GpuInfo `json:"gpus"` } // Return the optimal number of threads to use for inference func (si SystemInfo) GetOptimalThreadCount() int { if len(si.System.CPUs) == 0 { - return 0 + // Fall back to Go's num CPU + return runtime.NumCPU() } coreCount := 0 @@ -173,9 +173,10 @@ func (si SystemInfo) GetOptimalThreadCount() int { func (l GpuInfoList) FlashAttentionSupported() bool { for _, gpu := range l { supportsFA := gpu.Library == "cpu" || - gpu.Library == "metal" || - (gpu.Library == "cuda" && gpu.DriverMajor >= 7) || - gpu.Library == "rocm" + gpu.Name == "Metal" || gpu.Library == "Metal" || + (gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || // We don't have kernels for Jetson Xavier + gpu.Library == "ROCm" || + gpu.Library == "Vulkan" if !supportsFA { return false @@ -183,3 +184,31 @@ func (l GpuInfoList) FlashAttentionSupported() bool { } return true } + +type BaseRunner interface { + // GetPort returns the localhost port number the runner is running on + GetPort() int + + // HasExited indicates if the runner is no longer running. This can be used during + // bootstrap to detect if a given filtered device is incompatible and triggered an assert + HasExited() bool +} + +type RunnerDiscovery interface { + BaseRunner + + // GetDeviceInfos will perform a query of the underlying device libraries + // for device identification and free VRAM information + // During bootstrap scenarios, this routine may take seconds to complete + GetDeviceInfos(ctx context.Context) []ml.DeviceInfo +} + +type FilteredRunnerDiscovery interface { + RunnerDiscovery + + // GetActiveDeviceIDs returns the filtered set of devices actively in + // use by this runner for running models. If the runner is a bootstrap runner, no devices + // will be active yet so no device IDs are returned. + // This routine will not query the underlying device and will return immediately + GetActiveDeviceIDs() []ml.DeviceID +} diff --git a/docs/gpu.md b/docs/gpu.md index 464788cc..910f82d1 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -9,15 +9,20 @@ Check your compute compatibility to see if your card is supported: | ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- | | 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` | | | NVIDIA Professioal | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` | -| 9.0 | NVIDIA | `H200` `H100` | +| 11.0 | Jetson | `T4000` `T5000` (Requires driver 580 or newer) | +| 10.3 | NVIDIA Professioal | `B300` `GB300` (Requires driver 580 or newer) | +| 10.0 | NVIDIA Professioal | `B200` `GB200` (Requires driver 580 or newer) | +| 9.0 | NVIDIA | `H200` `H100` `GH200` | | 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` | | | NVIDIA Professional | `L4` `L40` `RTX 6000` | +| 8.7 | Jetson | `Orin Nano` `Orin NX` `AGX Orin` | | 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` | | | NVIDIA Professional | `A40` `RTX A6000` `RTX A5000` `RTX A4000` `RTX A3000` `RTX A2000` `A10` `A16` `A2` | | 8.0 | NVIDIA | `A100` `A30` | | 7.5 | GeForce GTX/RTX | `GTX 1650 Ti` `TITAN RTX` `RTX 2080 Ti` `RTX 2080` `RTX 2070` `RTX 2060` | | | NVIDIA Professional | `T4` `RTX 5000` `RTX 4000` `RTX 3000` `T2000` `T1200` `T1000` `T600` `T500` | | | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` | +| 7.2 | Jetson | `Xavier NX` `AGX Xavier` (Jetpack 5) | | 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` | | 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` | | | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` | @@ -51,20 +56,23 @@ sudo modprobe nvidia_uvm` Ollama supports the following AMD GPUs: ### Linux Support -| Family | Cards and accelerators | -| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | -| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` | -| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` | -| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` | +| Family | Cards and accelerators | +| -------------- | -------------------------------------------------------------------------------------------------------------------- | +| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` | +| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` | +| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` | ### Windows Support -With ROCm v6.1, the following GPUs are supported on Windows. +With ROCm v6.2, the following GPUs are supported on Windows. | Family | Cards and accelerators | | -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | | AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` | | AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` | +### Known Workarounds + +- The RX Vega 56 requires `HSA_ENABLE_SDMA=0` to disable SDMA ### Overrides on Linux Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In @@ -85,8 +93,6 @@ At this time, the known supported GPU types on linux are the following LLVM Targ This table shows some example GPUs that map to these LLVM targets: | **LLVM Target** | **An Example GPU** | |-----------------|---------------------| -| gfx900 | Radeon RX Vega 56 | -| gfx906 | Radeon Instinct MI50 | | gfx908 | Radeon Instinct MI100 | | gfx90a | Radeon Instinct MI210 | | gfx940 | Radeon Instinct MI300 | diff --git a/docs/macos.md b/docs/macos.md index 9617bdc7..26fb23c7 100644 --- a/docs/macos.md +++ b/docs/macos.md @@ -2,7 +2,7 @@ ## System Requirements -* MacOS Monterey (v12) or newer +* MacOS Sonoma (v14) or newer * Apple M series (CPU and GPU support) or x86 (CPU only) diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 7647b12f..18c014d1 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -38,26 +38,14 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs. ## LLM libraries -Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library. - -In the server log, you will see a message that looks something like this (varies from release to release): - -``` -Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5] -``` +Ollama includes multiple LLM libraries compiled for different GPU libraries and versions. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. **Experimental LLM Library Override** -You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use: +You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to limit autodetection, so for example, if you have both CUDA and AMD GPUs, but want to force the CUDA v13 only, use: ```shell -OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve -``` - -You can see what features your CPU has with the following. - -```shell -cat /proc/cpuinfo| grep flags | head -1 +OLLAMA_LLM_LIBRARY="cuda_v13" ollama serve ``` ## Installing older or pre-release versions on Linux diff --git a/envconfig/config.go b/envconfig/config.go index 09243ab9..d155bd8f 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -24,6 +24,9 @@ func Host() *url.URL { switch { case !ok: scheme, hostport = "http", s + if s == "ollama.com" { + scheme, hostport = "https", "ollama.com:443" + } case scheme == "http": defaultPort = "80" case scheme == "https": @@ -145,8 +148,8 @@ func Remotes() []string { return r } -func Bool(k string) func() bool { - return func() bool { +func BoolWithDefault(k string) func(defaultValue bool) bool { + return func(defaultValue bool) bool { if s := Var(k); s != "" { b, err := strconv.ParseBool(s) if err != nil { @@ -156,7 +159,14 @@ func Bool(k string) func() bool { return b } - return false + return defaultValue + } +} + +func Bool(k string) func() bool { + withDefault := BoolWithDefault(k) + return func() bool { + return withDefault(false) } } @@ -177,7 +187,7 @@ func LogLevel() slog.Level { var ( // FlashAttention enables the experimental flash attention feature. - FlashAttention = Bool("OLLAMA_FLASH_ATTENTION") + FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION") // KvCacheType is the quantization type for the K/V cache. KvCacheType = String("OLLAMA_KV_CACHE_TYPE") // NoHistory disables readline history. @@ -210,6 +220,7 @@ var ( CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES") HipVisibleDevices = String("HIP_VISIBLE_DEVICES") RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES") + VkVisibleDevices = String("GGML_VK_VISIBLE_DEVICES") GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL") HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION") ) @@ -263,7 +274,7 @@ type EnvVar struct { func AsMap() map[string]EnvVar { ret := map[string]EnvVar{ "OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, - "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"}, + "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"}, "OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"}, "OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, @@ -300,6 +311,7 @@ func AsMap() map[string]EnvVar { ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible by numeric ID"} ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible by UUID or numeric ID"} + ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"} ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"} ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"} ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"} diff --git a/envconfig/config_test.go b/envconfig/config_test.go index f232f1cd..ddd86a11 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -37,6 +37,7 @@ func TestHost(t *testing.T) { "https": {"https://1.2.3.4", "https://1.2.3.4:443"}, "https port": {"https://1.2.3.4:4321", "https://1.2.3.4:4321"}, "proxy path": {"https://example.com/ollama", "https://example.com:443/ollama"}, + "ollama.com": {"ollama.com", "https://ollama.com:443"}, } for name, tt := range cases { diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 58803f58..fcb3d9fd 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -870,11 +870,6 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool { return true } - if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) { - // gpt-oss uses attention with sinks which does not support quantized cache types - slog.Warn("model only supports non-quantized cache types", "model", arch) - return false - } return slices.Contains([]string{"q8_0", "q4_0"}, cacheType) } @@ -898,7 +893,10 @@ func (f GGML) SupportsFlashAttention() bool { // FlashAttention checks if the model should enable flash attention func (f GGML) FlashAttention() bool { return slices.Contains([]string{ + "gemma3", "gptoss", "gpt-oss", + "qwen3", + "qwen3moe", }, f.KV().String("general.architecture")) } diff --git a/fs/ggml/type.go b/fs/ggml/type.go index 1a31a5fd..fb69352b 100644 --- a/fs/ggml/type.go +++ b/fs/ggml/type.go @@ -229,7 +229,7 @@ const ( TensorTypeMXFP4 ) -// ParseFileType parses the provided GGUF file type +// ParseTensorType parses the provided GGUF tensor type // Only Ollama supported types are considered valid func ParseTensorType(s string) (TensorType, error) { switch s { diff --git a/integration/basic_test.go b/integration/basic_test.go index 0a6b9253..41406147 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -17,16 +17,21 @@ func TestBlueSky(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() // Set up the test data - req := api.GenerateRequest{ - Model: smol, - Prompt: blueSkyPrompt, + req := api.ChatRequest{ + Model: smol, + Messages: []api.Message{ + { + Role: "user", + Content: blueSkyPrompt, + }, + }, Stream: &stream, Options: map[string]any{ "temperature": 0, "seed": 123, }, } - GenerateTestHelper(ctx, t, req, blueSkyExpected) + ChatTestHelper(ctx, t, req, blueSkyExpected) } func TestUnicode(t *testing.T) { @@ -34,10 +39,15 @@ func TestUnicode(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) defer cancel() // Set up the test data - req := api.GenerateRequest{ + req := api.ChatRequest{ // DeepSeek has a Unicode tokenizer regex, making it a unicode torture test - Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage? - Prompt: "天空为什么是蓝色的?", // Why is the sky blue? + Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage? + Messages: []api.Message{ + { + Role: "user", + Content: "天空为什么是蓝色的?", // Why is the sky blue? + }, + }, Stream: &stream, Options: map[string]any{ "temperature": 0, @@ -57,9 +67,14 @@ func TestUnicode(t *testing.T) { if err != nil { t.Fatalf("failed to load model %s: %s", req.Model, err) } + defer func() { + // best effort unload once we're done with the model + client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil }) + }() + skipIfNotGPULoaded(ctx, t, client, req.Model, 100) - DoGenerate(ctx, t, client, req, []string{ + DoChat(ctx, t, client, req, []string{ "散射", // scattering "频率", // frequency }, 120*time.Second, 120*time.Second) @@ -69,9 +84,14 @@ func TestExtendedUnicodeOutput(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() // Set up the test data - req := api.GenerateRequest{ - Model: "gemma2:2b", - Prompt: "Output some smily face emoji", + req := api.ChatRequest{ + Model: "gemma2:2b", + Messages: []api.Message{ + { + Role: "user", + Content: "Output some smily face emoji", + }, + }, Stream: &stream, Options: map[string]any{ "temperature": 0, @@ -83,7 +103,7 @@ func TestExtendedUnicodeOutput(t *testing.T) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatal(err) } - DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second) + DoChat(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second) } func TestUnicodeModelDir(t *testing.T) { @@ -108,14 +128,19 @@ func TestUnicodeModelDir(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() - req := api.GenerateRequest{ - Model: smol, - Prompt: blueSkyPrompt, + req := api.ChatRequest{ + Model: smol, + Messages: []api.Message{ + { + Role: "user", + Content: blueSkyPrompt, + }, + }, Stream: &stream, Options: map[string]any{ "temperature": 0, "seed": 123, }, } - GenerateTestHelper(ctx, t, req, blueSkyExpected) + ChatTestHelper(ctx, t, req, blueSkyExpected) } diff --git a/integration/concurrency_test.go b/integration/concurrency_test.go index 3104eacc..cb44e900 100644 --- a/integration/concurrency_test.go +++ b/integration/concurrency_test.go @@ -20,9 +20,9 @@ import ( ) // Send multiple requests in parallel (concurrently) to a single model and ensure responses are expected -func TestConcurrentGenerate(t *testing.T) { +func TestConcurrentChat(t *testing.T) { // Assumes all requests have the same model - req, resp := GenerateRequests() + req, resp := ChatRequests() numParallel := int(envconfig.NumParallel() + 1) iterLimit := 3 @@ -57,7 +57,7 @@ func TestConcurrentGenerate(t *testing.T) { slog.Info("Starting", "thread", i, "iter", j) // On slower GPUs it can take a while to process the concurrent requests // so we allow a much longer initial timeout - DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second) + DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second) } }(i) } @@ -109,6 +109,8 @@ func TestMultiModelStress(t *testing.T) { defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() + initialTimeout := 120 * time.Second + streamTimeout := 20 * time.Second // Make sure all the models are pulled before we get started for _, model := range chosenModels { @@ -147,6 +149,8 @@ chooseModels: for _, m := range models.Models { if m.SizeVRAM == 0 { slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount]) + initialTimeout = 240 * time.Second + streamTimeout = 30 * time.Second break chooseModels } } @@ -163,7 +167,7 @@ chooseModels: wg.Add(1) go func(i int) { defer wg.Done() - reqs, resps := GenerateRequests() + reqs, resps := ChatRequests() for j := 0; j < 3; j++ { if time.Now().Sub(started) > softTimeout { slog.Info("exceeded soft timeout, winding down test") @@ -171,11 +175,8 @@ chooseModels: } k := r.Int() % len(reqs) reqs[k].Model = chosenModels[i] - slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Prompt) - DoGenerate(ctx, t, client, reqs[k], resps[k], - 120*time.Second, // Be extra patient for the model to load initially - 10*time.Second, // Once results start streaming, fail if they stall - ) + slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Messages[0].Content) + DoChat(ctx, t, client, reqs[k], resps[k], initialTimeout, streamTimeout) } }(i) } diff --git a/integration/context_test.go b/integration/context_test.go index 9d13f7ac..1d2d6554 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -21,9 +21,14 @@ func TestLongInputContext(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up the test data - req := api.GenerateRequest{ - Model: smol, - Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?", + req := api.ChatRequest{ + Model: smol, + Messages: []api.Message{ + { + Role: "user", + Content: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?", + }, + }, Stream: &stream, Options: map[string]any{ "temperature": 0, @@ -36,7 +41,7 @@ func TestLongInputContext(t *testing.T) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("PullIfMissing failed: %v", err) } - DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second) + DoChat(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second) } func TestContextExhaustion(t *testing.T) { @@ -48,9 +53,14 @@ func TestContextExhaustion(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up the test data - req := api.GenerateRequest{ - Model: smol, - Prompt: "Write me a story in english with a lot of emojis", + req := api.ChatRequest{ + Model: smol, + Messages: []api.Message{ + { + Role: "user", + Content: "Write me a story in english with a lot of emojis", + }, + }, Stream: &stream, Options: map[string]any{ "temperature": 0, @@ -63,12 +73,12 @@ func TestContextExhaustion(t *testing.T) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("PullIfMissing failed: %v", err) } - DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second) + DoChat(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second) } // Send multiple generate requests with prior context and ensure the response is coherant and expected func TestParallelGenerateWithHistory(t *testing.T) { - modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model + modelName := "gpt-oss:20b" req, resp := GenerateRequests() numParallel := 2 iterLimit := 2 @@ -78,15 +88,23 @@ func TestParallelGenerateWithHistory(t *testing.T) { defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() + initialTimeout := 120 * time.Second + streamTimeout := 20 * time.Second // Get the server running (if applicable) warm the model up with a single initial request - slog.Info("loading", "model", modelOverride) + slog.Info("loading", "model", modelName) err := client.Generate(ctx, - &api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}}, + &api.GenerateRequest{Model: modelName, KeepAlive: &api.Duration{Duration: 10 * time.Second}}, func(response api.GenerateResponse) error { return nil }, ) if err != nil { - t.Fatalf("failed to load model %s: %s", modelOverride, err) + t.Fatalf("failed to load model %s: %s", modelName, err) + } + gpuPercent := getGPUPercent(ctx, t, client, modelName) + if gpuPercent < 80 { + slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent) + initialTimeout = 240 * time.Second + streamTimeout = 30 * time.Second } var wg sync.WaitGroup @@ -95,7 +113,7 @@ func TestParallelGenerateWithHistory(t *testing.T) { go func(i int) { defer wg.Done() k := i % len(req) - req[k].Model = modelOverride + req[k].Model = modelName for j := 0; j < iterLimit; j++ { if time.Now().Sub(started) > softTimeout { slog.Info("exceeded soft timeout, winding down test") @@ -104,7 +122,7 @@ func TestParallelGenerateWithHistory(t *testing.T) { slog.Info("Starting", "thread", i, "iter", j) // On slower GPUs it can take a while to process the concurrent requests // so we allow a much longer initial timeout - c := DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second) + c := DoGenerate(ctx, t, client, req[k], resp[k], initialTimeout, streamTimeout) req[k].Context = c req[k].Prompt = "tell me more!" } @@ -155,7 +173,7 @@ func TestGenerateWithHistory(t *testing.T) { // Send multiple chat requests with prior context and ensure the response is coherant and expected func TestParallelChatWithHistory(t *testing.T) { - modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model + modelName := "gpt-oss:20b" req, resp := ChatRequests() numParallel := 2 iterLimit := 2 @@ -165,15 +183,23 @@ func TestParallelChatWithHistory(t *testing.T) { defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() + initialTimeout := 120 * time.Second + streamTimeout := 20 * time.Second // Get the server running (if applicable) warm the model up with a single initial empty request - slog.Info("loading", "model", modelOverride) + slog.Info("loading", "model", modelName) err := client.Generate(ctx, - &api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}}, + &api.GenerateRequest{Model: modelName, KeepAlive: &api.Duration{Duration: 10 * time.Second}}, func(response api.GenerateResponse) error { return nil }, ) if err != nil { - t.Fatalf("failed to load model %s: %s", modelOverride, err) + t.Fatalf("failed to load model %s: %s", modelName, err) + } + gpuPercent := getGPUPercent(ctx, t, client, modelName) + if gpuPercent < 80 { + slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent) + initialTimeout = 240 * time.Second + streamTimeout = 30 * time.Second } var wg sync.WaitGroup @@ -182,7 +208,7 @@ func TestParallelChatWithHistory(t *testing.T) { go func(i int) { defer wg.Done() k := i % len(req) - req[k].Model = modelOverride + req[k].Model = modelName for j := 0; j < iterLimit; j++ { if time.Now().Sub(started) > softTimeout { slog.Info("exceeded soft timeout, winding down test") @@ -191,7 +217,7 @@ func TestParallelChatWithHistory(t *testing.T) { slog.Info("Starting", "thread", i, "iter", j) // On slower GPUs it can take a while to process the concurrent requests // so we allow a much longer initial timeout - assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second) + assistant := DoChat(ctx, t, client, req[k], resp[k], initialTimeout, streamTimeout) if assistant == nil { t.Fatalf("didn't get an assistant response for context") } diff --git a/integration/library_models_test.go b/integration/library_models_test.go index 49e1097b..89968848 100644 --- a/integration/library_models_test.go +++ b/integration/library_models_test.go @@ -15,7 +15,7 @@ import ( // First run of this scenario on a target system will take a long time to download // ~1.5TB of models. Set a sufficiently large -timeout for your network speed -func TestLibraryModelsGenerate(t *testing.T) { +func TestLibraryModelsChat(t *testing.T) { softTimeout, hardTimeout := getTimeouts(t) slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout) ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) @@ -43,9 +43,14 @@ func TestLibraryModelsGenerate(t *testing.T) { t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch)) } } - req := api.GenerateRequest{ - Model: model, - Prompt: blueSkyPrompt, + req := api.ChatRequest{ + Model: model, + Messages: []api.Message{ + { + Role: "user", + Content: blueSkyPrompt, + }, + }, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: map[string]interface{}{ "temperature": 0.1, @@ -58,13 +63,13 @@ func TestLibraryModelsGenerate(t *testing.T) { anyResp = []string{"select", "from"} } else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" { anyResp = []string{"yes", "no", "safe", "unsafe"} - } else if model == "openthinker" || model == "nexusraven" { + } else if model == "openthinker" { anyResp = []string{"plugin", "im_sep", "components", "function call"} } else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" { - req.Prompt = "def fibonacci():" + req.Messages[0].Content = "def fibonacci():" anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"} } - DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second) + DoChat(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second) }) } } diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go index 9bf11257..e3591565 100644 --- a/integration/llm_image_test.go +++ b/integration/llm_image_test.go @@ -34,17 +34,22 @@ func TestVisionModels(t *testing.T) { if err != nil { t.Fatal(err) } - req := api.GenerateRequest{ - Model: v.model, - Prompt: "what does the text in this image say?", + req := api.ChatRequest{ + Model: v.model, + Messages: []api.Message{ + { + Role: "user", + Content: "what does the text in this image say?", + Images: []api.ImageData{ + image, + }, + }, + }, Stream: &stream, Options: map[string]any{ "seed": 42, "temperature": 0.0, }, - Images: []api.ImageData{ - image, - }, } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -56,8 +61,15 @@ func TestVisionModels(t *testing.T) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatal(err) } + // Preload to skip if we're less than 80% on GPU to avoid extremely slow tests + err = client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil }) + if err != nil { + t.Fatalf("failed to load model %s: %s", req.Model, err) + } + skipIfNotGPULoaded(ctx, t, client, req.Model, 80) + // llava models on CPU can be quite slow to start - DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second) + DoChat(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second) }) } } diff --git a/integration/model_arch_test.go b/integration/model_arch_test.go index 721d95c5..b09b6773 100644 --- a/integration/model_arch_test.go +++ b/integration/model_arch_test.go @@ -19,7 +19,7 @@ import ( "github.com/ollama/ollama/format" ) -func TestModelsGenerate(t *testing.T) { +func TestModelsChat(t *testing.T) { softTimeout, hardTimeout := getTimeouts(t) slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout) ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) @@ -65,16 +65,41 @@ func TestModelsGenerate(t *testing.T) { } } } + initialTimeout := 120 * time.Second + streamTimeout := 30 * time.Second + slog.Info("loading", "model", model) + err := client.Generate(ctx, + &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 10 * time.Second}}, + func(response api.GenerateResponse) error { return nil }, + ) + if err != nil { + t.Fatalf("failed to load model %s: %s", model, err) + } + gpuPercent := getGPUPercent(ctx, t, client, model) + if gpuPercent < 80 { + slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent) + initialTimeout = 240 * time.Second + streamTimeout = 40 * time.Second + } + // TODO - fiddle with context size - req := api.GenerateRequest{ - Model: model, - Prompt: blueSkyPrompt, + req := api.ChatRequest{ + Model: model, + Messages: []api.Message{ + { + Role: "user", + Content: blueSkyPrompt, + }, + }, + KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: map[string]interface{}{ "temperature": 0, "seed": 123, }, } - DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second) + DoChat(ctx, t, client, req, blueSkyExpected, initialTimeout, streamTimeout) + // best effort unload once we're done with the model + client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil }) }) } } @@ -128,8 +153,9 @@ func TestModelsEmbed(t *testing.T) { } } req := api.EmbeddingRequest{ - Model: model, - Prompt: "why is the sky blue?", + Model: model, + Prompt: "why is the sky blue?", + KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: map[string]interface{}{ "temperature": 0, "seed": 123, @@ -139,6 +165,10 @@ func TestModelsEmbed(t *testing.T) { if err != nil { t.Fatalf("embeddings call failed %s", err) } + defer func() { + // best effort unload once we're done with the model + client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil }) + }() if len(resp.Embedding) == 0 { t.Errorf("zero length embedding response") } diff --git a/integration/model_perf_test.go b/integration/model_perf_test.go index 3d6ba923..339dff17 100644 --- a/integration/model_perf_test.go +++ b/integration/model_perf_test.go @@ -173,9 +173,14 @@ func doModelPerfTest(t *testing.T, chatModels []string) { slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent) continue } - req := api.GenerateRequest{ - Model: model, - Prompt: tc.prompt, + req := api.ChatRequest{ + Model: model, + Messages: []api.Message{ + { + Role: "user", + Content: tc.prompt, + }, + }, KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns Options: map[string]interface{}{ "temperature": 0, @@ -184,7 +189,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) { }, } atLeastOne := false - var resp api.GenerateResponse + var resp api.ChatResponse stream := false req.Stream = &stream @@ -198,7 +203,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) { ) defer cancel() - err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error { + err = client.Chat(genCtx, &req, func(rsp api.ChatResponse) error { resp = rsp return nil }) @@ -214,13 +219,13 @@ func doModelPerfTest(t *testing.T, chatModels []string) { } loaded = true for _, expResp := range tc.anyResp { - if strings.Contains(strings.ToLower(resp.Response), expResp) { + if strings.Contains(strings.ToLower(resp.Message.Content), expResp) { atLeastOne = true break } } if !atLeastOne { - t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response) + t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Message.Content) } models, err := client.ListRunning(ctx) if err != nil { diff --git a/integration/quantization_test.go b/integration/quantization_test.go index 30564749..be0d81e7 100644 --- a/integration/quantization_test.go +++ b/integration/quantization_test.go @@ -74,9 +74,14 @@ func TestQuantization(t *testing.T) { } stream := true - genReq := api.GenerateRequest{ - Model: newName, - Prompt: blueSkyPrompt, + chatReq := api.ChatRequest{ + Model: newName, + Messages: []api.Message{ + { + Role: "user", + Content: blueSkyPrompt, + }, + }, KeepAlive: &api.Duration{Duration: 3 * time.Second}, Options: map[string]any{ "seed": 42, @@ -91,8 +96,8 @@ func TestQuantization(t *testing.T) { reqCtx, reqCancel := context.WithCancel(ctx) atLeastOne := false var buf bytes.Buffer - genfn := func(response api.GenerateResponse) error { - buf.Write([]byte(response.Response)) + chatfn := func(response api.ChatResponse) error { + buf.Write([]byte(response.Message.Content)) fullResp := strings.ToLower(buf.String()) for _, resp := range blueSkyExpected { if strings.Contains(fullResp, resp) { @@ -108,14 +113,14 @@ func TestQuantization(t *testing.T) { done := make(chan int) var genErr error go func() { - genErr = client.Generate(reqCtx, &genReq, genfn) + genErr = client.Chat(reqCtx, &chatReq, chatfn) done <- 0 }() select { case <-done: if genErr != nil && !atLeastOne { - t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt) + t.Fatalf("failed with %s request prompt %s ", chatReq.Model, chatReq.Messages[0].Content) } case <-ctx.Done(): t.Error("outer test context done while waiting for generate") diff --git a/integration/testdata/embed.json b/integration/testdata/embed.json index 73f47b51..c80fd1b0 100644 --- a/integration/testdata/embed.json +++ b/integration/testdata/embed.json @@ -7,5 +7,7 @@ "mxbai-embed-large": [-0.184430, -0.044229, 0.257643, -0.340901, -0.494335, 0.091339, -0.128759, 0.432637, 0.724959, 0.991389, -0.708882, -0.403970, 0.461882, 0.291021, -0.155411, -0.517737, -0.507892, 0.593476, -0.632355, 0.989990, -0.793276, 0.410782, -0.710714, -0.392212, 0.671396, 0.544336, 0.793681, 0.426529, 0.050131, 0.794476, -0.365169, 0.158634, -0.525134, -1.073285, 0.289087, 0.421985, -0.181116, -0.277893, 0.411325, -0.823697, 0.158263, -0.776151, 0.180766, -1.537495, -0.595382, -0.798639, -0.887673, -0.828362, -0.192054, 0.534375, 0.244046, 0.513111, 0.792217, -0.443562, -0.240577, 0.568054, -0.602470, -0.063402, -0.021810, -0.225963, -0.385676, -0.335298, -0.146143, -0.481427, 0.290886, 1.367729, -0.337673, -0.311578, -0.178011, 0.634559, 0.491958, 0.487708, 0.187159, -0.580782, 0.243836, 0.014322, -0.103781, 0.495297, -0.637951, 0.210895, 0.204865, -0.016029, -0.246061, 0.233346, -0.478157, -0.558347, -0.311118, 0.440266, -0.601550, 0.339132, 0.016395, 0.852818, -0.050276, -0.467989, 0.183518, -0.068032, -0.359197, -0.066681, 0.317749, 0.686872, 0.762136, 0.893017, -0.400377, 0.071887, -0.772756, 0.480981, -0.071267, 1.038188, -0.365921, -0.162451, 0.155094, 0.300965, 0.372498, 0.537590, -0.089499, 0.980337, -0.072252, 0.809650, -0.189996, -0.506872, -0.215468, 1.173936, -0.340243, -0.363687, 0.601494, -0.585390, -0.938341, 0.217940, -0.634498, 0.187639, 0.266674, -1.292217, -0.016839, 0.576413, -0.752896, -0.109006, 0.398096, -0.092125, -0.786432, -0.872161, 0.586165, 0.230709, -0.610144, 1.061853, -0.008028, 0.958815, 0.546893, 0.003219, -0.232854, -0.319982, -0.712559, -0.571618, 0.456460, 0.300718, -0.310005, 0.265730, -0.712883, -0.181774, 0.033142, -0.831168, -0.386635, 0.305330, 0.146375, -0.147643, -0.508379, 0.049415, -0.583293, -1.235440, -0.286194, -0.458141, 0.601308, 0.189742, -0.213759, 0.017279, 0.678105, 0.299103, 0.288363, -0.392151, -0.355677, 0.605839, 0.084698, -0.495191, 0.340367, 0.208737, -0.714317, 0.616687, -0.832112, 0.137910, -0.016345, 0.026855, 0.327972, 0.264948, -0.523514, 1.131886, -0.610398, 1.097030, 0.031263, 0.209969, -0.003950, -1.267444, 0.253044, -0.410123, -0.748738, 0.343910, -0.210202, -0.356721, 0.702087, 0.403481, 0.189649, 0.880100, 0.735530, -0.637482, 0.405021, -0.697255, 1.082935, -0.942586, -0.234460, 0.495627, -0.500333, -0.753946, 0.741444, 0.219077, 0.507746, 0.372319, -0.497832, 0.029056, -0.059532, 0.239361, 0.017333, 0.100514, 0.191668, -0.091566, 0.462802, 0.598301, -0.290409, -0.111971, 0.614868, 1.004610, -0.487192, -0.792706, -0.577011, 0.257568, 0.659747, -0.104938, -0.072255, -0.133366, 0.286197, -0.385580, -0.091709, 1.294983, 0.131856, 0.270540, 0.619632, 0.055126, -0.731469, 0.141043, 0.452648, -0.496773, -0.064298, -0.505497, -0.191061, 0.260015, 0.774404, -0.500005, 0.262618, 0.206279, -0.323111, 0.526082, 0.249043, -0.876898, -0.593859, -0.076233, -0.071404, 0.014999, -0.666098, 0.547257, 0.527260, -0.847083, 0.987440, -0.077034, 0.606362, 0.281647, -0.173747, -0.422577, -0.772435, 0.463466, -0.053370, 0.354388, -0.074551, 0.331071, -0.292800, -0.152313, -0.420847, -0.976438, 0.296503, -0.807020, 0.583320, 0.734794, -1.042924, -1.068834, 0.194230, 0.557421, -0.683329, -0.252993, -0.338887, 0.656994, -0.026126, 0.629963, 0.230679, 0.283844, -1.162887, 0.978764, 0.159263, 0.428557, -0.670335, -0.119131, 1.104897, -0.354662, -0.088493, -0.344159, -1.152556, -0.733712, -0.644582, 0.547945, -0.314649, 1.112930, -0.085785, -0.122471, 0.757697, -0.561914, -0.842027, 0.031444, 0.055126, 1.339829, -0.080786, -0.376824, 0.171051, -0.127106, 0.201467, 0.252141, 0.623374, 0.608344, 0.600774, 0.526192, -0.781339, 0.752040, -0.183211, -0.013682, -0.889619, -0.015756, -0.667047, -0.146755, -0.277918, -0.041418, -1.103346, 0.021202, -0.888572, -0.395570, -0.323405, -0.508976, -0.464945, 0.108451, -0.271122, 0.136861, -0.008576, -0.532034, -0.096445, -0.467949, -0.033445, 0.234179, 0.706482, -0.457873, -1.118891, 0.268313, 0.440962, 0.082270, 0.201520, 1.338068, 0.030613, -1.423740, -0.353576, -0.758182, 0.194265, 0.556422, 0.504734, -0.047644, 0.118117, 0.189478, -0.292784, -0.570022, 0.192147, -0.311218, 0.780930, -0.819373, -0.494229, -0.080453, 0.161174, -0.068893, -0.302501, -0.122194, -0.887550, 0.187576, -0.238606, 0.254111, -0.480286, -0.010553, -0.235043, -0.213408, -0.693196, -0.871761, -1.177420, 0.421552, 0.150834, 0.428866, 0.173524, 0.333427, -1.617997, 0.589071, 0.319783, 0.598216, -0.285953, 0.030320, -0.250552, 0.620356, -0.410909, -0.402419, -0.542502, -0.859507, 0.544731, 0.140452, -0.465390, 0.210549, -0.443584, 0.258778, -0.097309, 0.538018, 0.906112, -0.435498, 0.769745, -0.112336, -0.146874, 0.774157, 0.303549, 0.353710, -0.594408, -0.050291, 0.460403, 0.166917, -0.557056, -0.427867, -0.571792, -1.193821, 0.455408, -0.514805, -0.114097, 0.755632, -0.036822, 0.244069, 0.017853, 0.320361, 0.600486, 0.345410, 0.248439, -0.065360, -0.147230, 0.681437, -0.376803, 0.619161, 0.224959, 1.061212, 0.348496, 0.348574, -0.257807, -0.132066, -0.019669, -1.047518, 0.230262, -0.404979, -0.217094, 0.350816, -0.226377, 0.359037, -0.590859, -0.638165, -0.439882, -1.194314, 0.977341, 1.182708, 0.566933, 0.653037, 0.293562, 0.393298, -0.785287, -0.710926, 0.050302, -0.501394, -0.292375, -0.478182, 0.249010, -0.654171, 0.677037, 0.281663, -1.149832, -0.215366, 0.208423, -0.789884, -0.023901, 0.393357, -0.861098, -0.378299, 0.416720, 0.888443, -1.049108, 0.301513, 0.703898, 0.510410, 0.783207, -0.712019, -0.682901, 0.652063, 0.001561, -0.840966, -0.029026, 0.395154, 0.613149, 0.474359, -0.393713, -0.282076, -0.544940, 0.096447, -0.421520, -0.848015, 0.425615, 0.298550, 0.063038, 0.395801, 0.638161, 0.510125, -0.514485, 0.667536, 0.213428, 0.246997, -0.401899, -0.264696, -0.292083, -0.163617, -0.314678, 0.010437, -0.826144, -0.622993, -0.570733, 0.776689, 0.390569, 0.167244, -0.921968, -0.253110, -0.087688, -0.611201, 0.519791, -0.400586, -0.565133, -0.155832, -0.869806, -0.069586, -0.332179, -0.325593, -0.176851, 0.292933, 0.544992, 0.309212, 0.693073, -1.008915, 0.640825, 0.803863, 0.050733, -0.121608, -1.046594, 0.325992, -0.824169, 0.366661, 0.620026, -1.136593, -0.807474, 0.276858, 0.043236, -0.494065, 0.081646, 0.808731, 0.832728, -0.723262, -1.256246, -0.325050, 0.524266, 0.039813, -0.704321, -0.126269, 0.095609, 0.082202, 0.086101, -0.826447, -0.920888, 0.154007, -0.300407, -0.268665, 1.098288, 0.614392, 0.120873, -0.454987, 0.571017, 1.130542, -0.767184, 0.506830, -0.420690, -0.584786, -0.044088, -0.406440, -1.072179, -0.057334, 1.156670, -0.761770, -0.425323, 0.149341, 0.284262, -0.712262, -0.567695, -0.224796, -0.110017, 0.330071, -0.340900, 0.158696, -0.727833, -0.147613, 0.423538, 0.510003, 0.606281, -0.080528, 0.186704, 0.072766, -0.565217, 0.119597, 0.906172, -0.123952, -0.762860, -1.154191, -0.660310, -0.160749, 0.144111, 1.347811, -0.369728, -0.820122, 0.006073, 0.964751, 0.018731, 0.465869, -0.101001, -0.335978, -0.754300, -0.334722, -0.480149, 0.628781, 0.483267, 0.457232, 0.117753, -0.521813, 0.551801, 0.084965, -0.318940, -0.066070, -0.098802, -1.013111, 0.143144, -0.174489, -0.640646, -0.311500, 0.376085, 0.679693, 1.051810, 0.688921, 0.696959, 0.970073, 0.269663, -0.530280, 0.100102, -0.269014, -0.388141, 0.306651, -0.285064, -1.084083, -0.004578, -1.495795, -0.214574, -0.831116, 1.321367, -0.018022, -0.032926, 0.408718, -0.809160, -0.638017, 1.053982, -0.068500, 0.827540, 0.449735, -0.892859, -0.777452, -0.191327, -0.306583, 0.189726, -0.296111, 0.082857, -0.138756, -0.403071, -0.173177, 0.587296, -0.515637, 0.072134, -0.039649, 0.172985, 0.354365, 0.193783, 0.270315, 0.342777, -0.521108, 0.215977, -0.537544, -0.070306, -0.460761, 0.478134, -0.382760, 0.653429, 0.395981, -0.165739, 0.720424, 0.143897, -0.038856, 0.363653, -0.709262, 0.169733, -0.667103, -0.367752, -0.539746, 0.300732, 1.065408, -0.421486, 0.107752, -0.725340, -0.318851, 0.935453, -0.989472, 0.480412, 0.058036, 0.393100, -0.074872, 0.293552, -0.245239, 0.005302, -0.016539, -0.376620, 0.350751, 0.242006, -0.198443, 0.433967, -0.524284, 0.116899, 0.421698, -0.111016, -0.967211, -0.320426, -0.062399, 0.177091, 0.348322, -0.699875, 0.141990, 0.826276, -0.848617, 0.025896, -0.130040, -0.034494, -0.888120, 0.856335, 0.249695, 0.814619, -0.131306, 0.120094, 0.643379, -0.611613, -0.357224, 0.776030, 0.960886, 0.129156, 0.408306, 0.367817, 0.269316, 0.070774, -0.516437, -0.052576, 0.091347, 0.260328, 0.024363, 0.875942, -0.280627, -0.102265, -0.039784, -0.072848, -0.560500, 0.449904, 0.261367, -0.364486, 0.101212, 0.669331, 0.761311, 0.717449, -0.052709, 0.197153, -0.262678, -0.031469, -0.055003, 0.644974, 1.261077, 0.392719, 0.296242, -0.339485, -0.021194, 0.603353, 0.280036, 0.629497, 0.364465, 0.452728, 0.643538, 0.378894, 0.091038, 0.456975, -0.513915, -0.126489, -0.009895, 0.325235, -0.066843, -0.528270, -0.705070, 0.692832, -0.033567, -0.373095, 0.114325, -0.005283, -0.594922, -0.727208, -0.231234, -0.534589, -0.161451, 0.769014, -0.490739, -0.689905, 0.751271, 0.116013, 0.617775, 0.171751, 0.616911, -0.557983, 0.512084, 1.131404, 0.902862, 0.123900, -0.130886, -0.095275, -0.146384, 0.973198, -0.939116, 0.415987, 0.141768, -1.016851, 0.149851, 0.220852, -0.581875, 0.281200, 1.005536, 0.315062, -0.401580, 0.134673, -0.673693, 0.345668, 0.241293, -0.400795, -0.281749, 0.368495, 0.381264, -0.130412, 0.837386, -1.158837, -0.186571, 0.424287, 0.333846, -0.583837, 0.640093, -0.476582, -0.713870, 0.361890, 0.745934, -0.563484, 0.294523, -0.299165, -0.072987, 0.605953, 0.568644, -0.225395, 0.680188, 0.766460, -0.218583, 0.227105, 0.441971, 0.476706, 0.851809, -0.150688, -0.313922, 0.075485, -0.767643, 0.136117, 1.051392, 0.787572, -0.419714, 0.122675, -0.540518, 0.298909, 0.135754, 0.669249, -0.181985, -0.089787, 0.697183, -0.446291, -0.030833, -1.476093, 3.910861, -0.303977, 0.346165, 0.212608, -0.071260, 0.537566, 0.450875, -0.119921, 0.219233, -0.238928, 0.069097, 0.273725, 0.292697, 0.470801, -0.138657, 0.494341, -0.798581, 0.083814, -0.197650, 0.258464, -0.970952, 1.141625, 0.720694, 0.143707, -0.231931, 0.017647, -0.545387, -0.972465, 0.025321, 0.373469, 0.364320, -0.454340, 0.282703, -0.805100, 0.071744, 0.888706, 0.679937, 0.149148, -0.405140, 0.212629, -0.196323, -0.542577, -0.039924, -0.108804, -0.079743, 0.055312, -0.170608, -0.294922, 0.139467, -0.190695, 1.359862, -0.148537, -0.612390, -0.195168, -1.193715, -0.320659, 0.292875, 0.053914, -0.631212, -0.395676, 0.342212, -0.614968, -0.445423, 0.628682, 0.036913, -0.490761, 0.384823, 0.048369, -0.272953, 0.008680, 0.474012, 0.105767, 0.458249, -0.761752, 0.951160, 0.299774, -0.386913, 0.840293, 0.739001, 0.305865, -0.168413, -0.653828, -0.398312, -0.069632, 0.567475, -0.260468, -0.812760, -0.575514, -0.053884, 0.415675, 0.269411, 0.248425, -0.074951, -0.186381, -0.143542], "paraphrase-multilingual": [-0.019807, -0.124781, -0.010519, 0.035812, -0.103448, 0.051982, 0.035322, 0.030018, -0.179976, 0.194586, 0.129194, 0.157071, 0.083678, 0.074628, 0.093773, -0.367580, 0.002608, 0.086277, 0.050985, -0.005689, -0.038710, 0.071398, 0.010391, -0.059942, 0.007196, -0.066065, -0.010554, -0.011521, 0.145288, 0.120511, -0.139100, -0.096199, -0.045498, -0.109749, 0.046571, 0.023483, -0.086807, 0.150124, -0.067052, -0.100689, -0.004482, -0.014063, -0.062190, 0.071008, -0.107359, 0.012106, 0.026683, 0.107762, -0.002190, -0.121664, 0.057639, 0.175526, -0.129658, 0.061670, 0.274528, 0.052475, -0.124988, 0.189575, 0.027682, 0.105478, -0.010325, -0.008585, 0.156806, 0.021770, -0.119687, -0.030621, 0.061486, 0.089130, 0.080578, 0.004526, -0.163631, -0.035526, -0.044562, 0.036523, -0.202825, 0.050263, 0.022896, 0.042070, 0.126741, 0.073518, 0.199230, -0.121035, -0.013655, -0.071069, -0.065983, 0.313145, -0.021707, 0.124713, -0.039624, 0.225527, -0.015417, -0.164423, -0.142655, -0.059337, 0.030137, 0.127238, 0.127086, -0.082194, -0.081504, 0.325473, 0.274064, 0.185700, -0.021754, 0.175575, 0.002501, -0.045027, 0.057571, -0.260881, -0.035121, -0.142682, 0.209513, -0.166192, 0.007538, -0.121503, -0.079821, -0.121559, 0.157354, -0.130091, -0.088810, -0.004192, 0.023477, 0.050395, 0.015282, 0.022486, 0.027325, 0.041678, -0.146638, 0.171089, 0.150886, -0.087244, -0.011451, -0.035348, -0.045925, 0.063444, -0.065683, -0.126295, -0.046725, -0.017725, -0.119099, -0.096294, 0.124213, -0.001037, -0.077951, 0.116946, -0.128626, 0.076870, 0.015107, -0.013591, 0.030020, 0.049803, 0.057727, 0.192952, -0.265347, -0.031025, -0.077450, 0.015170, -0.168407, -0.094748, 0.057666, -0.069248, 0.034561, -0.111670, 0.047948, -0.082442, -0.038034, 0.005981, -0.336813, 0.151752, -0.080341, -0.163140, 0.234783, -0.070792, 0.098568, -0.062491, -0.038122, -0.056743, -0.216298, 0.015405, 0.036285, -0.018388, -0.129567, 0.114494, 0.100684, 0.136078, -0.278469, -0.029172, -0.025171, -0.035048, -0.017327, -0.020234, 0.006405, 0.059504, -0.055152, 0.047702, -0.109771, -0.095923, 0.154146, -0.082645, 0.002055, 0.063278, 0.045186, -0.016451, 0.120333, -0.030705, -0.125732, 0.082911, 0.183584, 0.005612, 0.086614, -0.122572, 0.187004, 0.008749, 0.122742, -0.099332, -0.099544, -0.030457, -0.014596, 0.159668, -0.182861, -0.038095, -0.018787, -0.129022, -0.070407, 0.040420, -0.078966, 0.110361, -0.051468, 0.023479, -0.055557, -0.074713, -0.025666, 0.041186, -0.000058, 0.008151, -0.078964, 0.127330, -0.045430, -0.043395, -0.025994, -0.305759, -0.000632, 0.091581, -0.041979, -0.096488, 0.007829, -0.035366, -0.129597, 0.031931, 0.011414, 0.026075, 0.070006, 0.143212, -0.131706, -0.065480, -0.091587, -0.089944, 0.304327, 0.096218, -0.155311, 0.154486, 0.056186, -0.002324, 0.134550, -0.185795, -0.054339, 0.010738, 0.268656, 0.230560, 0.050754, -0.097614, 0.096583, 0.082153, -0.127167, -0.107377, -0.047550, 0.109379, -0.032336, 0.005514, -0.189381, 0.015142, -0.220278, -0.155431, -0.080936, -0.017348, 0.057081, 0.040142, 0.024299, 0.038554, -0.014053, 0.088013, 0.058415, 0.047141, -0.052754, 0.062682, 0.094209, -0.061054, -0.029627, 0.057371, 0.032965, -0.137422, -0.197806, -0.105999, -0.003994, -0.005150, 0.015822, 0.145214, 0.171718, -0.092218, 0.165397, 0.172935, -0.016241, -0.069164, -0.034006, 0.263521, -0.112738, 0.144954, -0.008142, 0.109327, -0.000139, 0.203327, -0.000758, -0.102171, -0.004223, -0.122857, -0.078052, -0.005030, 0.179426, -0.008189, 0.172658, -0.182432, -0.028655, 0.246079, 0.040135, -0.001440, -0.101024, -0.116102, 0.035103, -0.111655, -0.171831, 0.053297, -0.021837, 0.020048, 0.071553, 0.017092, -0.495468, 0.006690, -0.174933, -0.039871, 0.017558, 0.093333, -0.067826, -0.026449, -0.034882, -0.078675, -0.026006, -0.127709, 0.073291, -0.096413, 0.173521, 0.141467, 0.049000, -0.128893, -0.095217, 0.197807, 0.064243, 0.147542, 0.107418, 0.088213, -0.047051, -0.014437, 0.377273, -0.041961, 0.123879, -0.009810, 0.105710, 0.168773, -0.020232, -0.108163, -0.050267, -0.069577, -0.031271, 0.047579, -0.278478, -0.072615, -0.059372, 0.114844, 0.055385, -0.052592, 0.140747, -0.053970, -0.049484, -0.056079, -0.052369, -0.061402, -0.010092, 0.040888, -0.010542, -0.008642, 0.127806, 0.142922, 0.061796, 0.215661, -0.121110, 0.177801, 0.082593, -0.098139, 0.160477, -0.112506, -0.128137, 0.010061, -0.246614, -0.134404, 0.134328, 0.037165, -0.056656, 0.085682, -0.002025, -0.048427, 0.047335, -0.152925, 0.076913, 0.144639, 0.002542, -0.008786, -0.207630, -0.092424, -0.056038, 0.039837, 0.130480, -0.019214, 0.085709, -0.068168, -0.057661, 0.256396, 0.000436, 0.002165, 0.008250, 0.435296, -0.023791, 0.112853, 0.118685, 0.015178, 0.142689, -0.139655, 0.084141, 0.053003, -0.127661, 0.121614, 0.090306, -0.053635, 0.143329, -0.020410, -0.130167, -0.062897, -0.043274, -0.012359, 0.014011, -0.309357, 0.110538, -0.099683, 0.018306, 0.439442, 0.034141, 0.002030, 0.026504, -0.224360, -0.192707, 0.154315, 0.020682, -0.212653, -0.198598, 0.103733, -0.084605, 0.123315, -0.190156, 0.051589, -0.114352, -0.215452, 0.227831, 0.089644, -0.156986, -0.110336, 0.023221, 0.186123, -0.009580, -0.108279, -0.008263, -0.079465, -0.019248, 0.037930, -0.005270, 0.017321, -0.003298, 0.294424, -0.011487, 0.139208, -0.054023, -0.135061, 0.010541, -0.181049, -0.041205, -0.110344, 0.128945, -0.090110, -0.092730, -0.029277, 0.101132, 0.017030, 0.041486, -0.143502, 0.224712, -0.052848, -0.128890, -0.150927, 0.027277, 0.097778, 0.225844, 0.132758, 0.049771, -0.195139, -0.030116, 0.007751, -0.079459, 0.195759, 0.028297, 0.147042, -0.010751, -0.044499, 0.024308, -0.101806, 0.131116, -0.123838, -0.073508, 0.129509, -0.011302, 0.326354, -0.237273, 0.024596, 0.004420, -0.039178, 0.025751, 0.013973, 0.154100, 0.041046, 0.024320, -0.092331, 0.075485, 0.194852, 0.043371, -0.251192, 0.134674, 0.052031, -0.132075, 0.094175, -0.014784, -0.095276, -0.167319, 0.093634, -0.053208, -0.299019, -0.019493, 0.110037, -0.111475, -0.098528, -0.045980, 0.011906, -0.084867, 0.071568, -0.053325, 0.037509, -0.058839, 0.001778, 0.058313, 0.127749, 0.036488, -0.065275, -0.057004, 0.002167, -0.194989, 0.068705, -0.069410, 0.112359, -0.152019, -0.107722, 0.070784, -0.017405, -0.203961, -0.063757, -0.000544, 0.104791, -0.084216, 0.204668, 0.103679, -0.267183, -0.073881, -0.051626, -0.263557, 0.077896, -0.046059, 0.181407, 0.004982, -0.028577, -0.070820, 0.120156, 0.068127, -0.016167, 0.168783, -0.009547, 0.057545, -0.206602, -0.138948, -0.287059, -0.089665, 0.193052, 0.181721, 0.076652, 0.230598, 0.038210, -0.065900, 0.351109, 0.163837, -0.106730, 0.004680, 0.054401, -0.162431, 0.109289, -0.027845, -0.077752, 0.074426, -0.206153, -0.205087, -0.047387, -0.115959, -0.012581, 0.006516, 0.137222, 0.024973, 0.067576, 0.079758, 0.005901, -0.085006, -0.211992, 0.079703, 0.164714, 0.012983, -0.047775, 0.009934, 0.166054, -0.117008, 0.112174, -0.081620, 0.252085, -0.095814, -0.160737, 0.098616, 0.049302, -0.169005, 0.056813, -0.110345, -0.072744, 0.016748, 0.018266, 0.276841, -0.109161, -0.030222, -0.091865, -0.098636, -0.029673, -0.037370, -0.277655, 0.068380, 0.040822, -0.014380, 0.363860, -0.091828, -0.034534, 0.108802, -0.056442, -0.141440, 0.096531, -0.126003, -0.072285, -0.014293, -0.315917, 0.013416, -0.057672, -0.064211, 0.077573, -0.015361, 0.105270, 0.046737, 0.073715, 0.133964, -0.039862, 0.192067, -0.038854, -0.035655, 0.101362, 0.148665, -0.078182, 0.041527, -0.077087, 0.026681, 0.089204, 0.506013, 0.121540, -0.163288, -0.046427, 0.129322, 0.186661, 0.032343, 0.020226, 0.031071, -0.050872, 0.091166, -0.050102, -0.042110, 0.055500, -0.027633, -0.272802, 0.198007, -0.049932, 0.015780, 0.053894, 0.063445, 0.013361, -0.017767, 0.103368, -0.049283, -0.161567, -0.018339, 0.159721, 0.019753, 0.256000, 0.122950, -0.067329, 0.049447, -0.039212, -0.101245, -0.019110, 0.068606, -0.009369, -0.081864, -0.116030, -0.107591, -0.032567, -0.213658, 0.024803, 0.012063, 0.073045, 0.151132, 0.040293, 0.111463, -0.057375, 0.336502, -0.153928, 0.049947, -0.022919, 0.136091, -0.179530, -0.101300, 0.034927, 0.026369, -0.290807, -0.027303, 0.077214, 0.085054, -0.088758], "snowflake-arctic-embed": [0.164476, -0.981777, -0.405218, 0.399810, 0.901198, 0.409591, -0.077627, -0.677190, 0.222725, -1.757181, 1.154365, 0.970361, 0.139148, 0.673119, -0.024305, 0.273795, 0.692573, -0.239678, -0.362082, -0.275700, -0.206364, -0.501303, 0.699528, 0.320007, -0.261514, -0.199023, 0.255197, 0.461451, 0.586028, 0.502643, 0.727292, -0.206270, 0.097371, -0.161835, 0.680590, 0.230389, 0.173242, -0.845818, -0.187537, -0.595398, 0.080072, -0.614428, 0.249609, 0.753781, -0.356874, -0.436827, 0.524961, -0.157355, 0.518234, -1.566906, 0.572488, -0.467955, 0.191558, 0.039816, -0.793020, -0.215021, 1.121415, 1.650410, 0.526585, -1.186473, -0.232328, -0.854596, -0.380662, 0.417444, -0.008091, 0.964398, -0.264849, -0.478139, 0.551200, 0.654829, -0.477421, -0.520961, -0.090849, -0.448812, 0.104905, -0.738188, 0.303336, 0.398035, 1.183559, 0.649098, 0.404940, -0.358590, -0.979204, -1.484936, 0.228276, 0.803336, 2.641596, -0.125927, 0.113146, -0.385871, 0.499152, 0.051917, 0.334905, 1.279890, -0.545813, 0.604924, -0.420765, 0.912452, 0.772270, -0.737417, 0.391128, -1.199134, 0.121847, 1.555495, 0.648331, 0.196339, -0.591679, 0.363930, -0.068456, -0.155599, 0.527852, -0.488703, -0.712850, -0.531144, 0.479999, -0.559684, 1.147967, -0.265582, 0.119726, 1.675230, 0.942336, 0.065473, 0.428287, -0.342958, -0.162591, 1.297977, -0.338609, -0.096736, -0.088885, 0.330610, -1.069823, -0.485881, 0.355422, 0.058099, -0.582748, -0.080651, -2.783385, -1.813708, 0.929544, 0.427284, 0.167461, -1.018789, -0.186063, 0.125848, 1.110493, -0.323993, 0.468688, -0.310807, 0.267761, -0.193082, -0.649354, 0.090465, 0.213910, -0.901647, 0.184187, -2.126019, 0.618628, -0.386999, 0.338013, -0.291322, 0.601014, -0.248482, -0.011206, -0.109841, -0.738318, -0.745902, 1.245500, 1.346687, -0.500503, 0.614734, -0.478978, 1.417879, 0.647242, 0.600458, 0.093502, -0.399006, 0.019264, -0.670275, 0.760402, -0.139396, -0.833422, 0.600008, 0.150120, 0.215607, -0.787541, 1.722837, 0.167324, -0.535421, 0.388938, -1.382614, -0.172650, -0.562894, 0.249094, 0.258224, -0.438660, -0.845207, 0.875777, 0.783044, -0.391563, 0.029483, -0.291418, 0.204866, 0.673864, 0.580254, 0.495731, 1.010963, -0.346271, -0.046538, 0.067540, -0.395137, -0.387492, 0.393826, 0.172326, 0.251920, -0.889290, 0.045292, -0.161041, 0.922710, -0.320204, 0.351821, -0.392186, 0.629528, -0.211839, 0.032394, 0.861603, -0.016760, -0.558076, 0.262017, -0.085449, -0.318123, -0.498436, -0.133505, 0.664525, -0.666853, 0.140894, 0.074495, -0.730992, 0.992944, 0.263796, -0.169161, 0.421966, -0.251835, 0.246833, 0.467468, 0.229798, 0.471774, 0.010803, 0.537420, 1.005422, -0.544047, -1.095315, 0.525546, -0.378814, 0.772719, -0.635745, -0.187179, 0.751029, -1.497753, 0.605500, 0.040281, -0.410345, 0.186229, -0.747669, 0.437304, 0.144941, -0.459204, -0.198767, 0.449451, 0.858884, -0.359434, 0.437780, -0.007321, 0.043643, -0.462933, 0.042202, 0.678946, -0.236253, 0.311505, 0.022989, -0.236843, -0.317470, -0.867559, -0.468267, -0.032692, -0.619554, 0.740736, -0.394311, 0.164591, 0.771053, 0.628858, -0.159988, -0.132335, -0.270476, -0.244661, -0.045490, -1.068461, 0.361834, 0.681828, -0.072550, 0.995301, -0.476299, -0.130403, -0.443094, -1.400598, -1.715192, -0.609242, 0.392083, -1.302736, -1.254964, 0.315025, 1.056481, -0.284517, 0.145024, -0.197186, -1.191084, -0.475434, -0.337662, 0.478131, -0.134051, -0.541338, 0.065506, 0.982383, -0.017134, 0.082724, 0.754355, -0.607289, -0.561618, -0.672752, 0.071788, 0.801023, 0.425337, 0.566067, 0.911838, -0.390513, -0.408622, -0.555813, -0.248295, -0.697827, 0.293418, 0.759617, 0.671161, -0.225396, 0.400199, -0.615734, -0.089381, 0.535295, 0.435778, 1.210370, -0.322341, 1.353689, -0.435054, 1.075088, 0.098468, 0.002876, 0.152754, -0.287845, 1.134301, -0.570370, -0.841934, 0.699961, 0.661102, -1.207625, -0.006113, -0.292638, -1.588339, 0.058787, -0.000471, -0.147217, 1.015599, 1.263788, 0.532431, 0.427873, -0.974371, 0.469206, -2.081517, -0.388685, -0.406079, 0.724670, -0.209924, 1.067930, -0.204612, 0.721137, 0.562673, -0.707059, -1.473862, 0.808875, 0.731063, -0.227392, -0.899608, -0.376536, 0.129298, -0.727531, -0.080100, -0.603700, -0.313701, -0.383994, -0.551388, -0.945553, -0.780997, -1.005637, 1.145955, 1.714179, -0.614432, -0.243966, -0.018248, -0.135699, -0.490564, -0.513082, -0.206577, -0.301805, 0.211078, 0.084133, 0.834956, 0.017944, -0.600084, -0.209110, -0.153675, -0.477361, 0.291817, 0.594291, -0.062999, 0.252667, 0.253603, -0.102821, -0.086670, -0.257069, 0.350693, 0.477220, -0.893321, 0.729113, 0.591172, -0.885421, -0.189146, -0.826989, 0.353572, -0.433705, 0.448902, 0.380207, 0.027303, -0.379125, 0.636641, 0.343187, -0.146485, -0.199786, 0.017955, -0.172499, -0.219285, -0.017873, -0.242158, 0.135226, 0.705949, 1.523271, -0.805220, -0.542296, -2.357319, -0.279120, -0.574042, 0.293555, -0.602039, -0.058580, -1.546088, -0.100937, 0.281502, -0.406502, -0.351410, -0.459636, 1.166577, -0.482386, 0.655048, -0.003645, -0.069005, 1.035454, 0.287968, -0.083988, -0.167982, 0.045657, 0.909229, 0.163557, 0.200277, 0.661766, -0.859409, -0.271063, 0.376515, -0.768590, -1.227650, 0.251031, -0.152929, -0.052574, 0.698705, -1.078161, -0.407318, 1.014991, 0.492574, -0.408082, -0.391707, -0.306255, -0.208311, -0.109343, 0.148285, 0.494520, 0.340868, -1.124670, -0.289814, -0.446556, 0.735685, 0.302573, -0.316504, 0.691067, -0.925421, -1.132813, 0.672036, -0.337327, -0.476443, 0.034346, -0.878503, -0.416229, 0.471957, 0.982983, -0.279663, 0.109959, 0.352966, -0.322536, -0.531440, -0.073239, -0.904272, -0.925687, 0.689411, -0.287299, 0.071599, 0.040955, 0.363247, -0.329313, -0.734854, 0.052484, 0.374941, -0.262850, 0.674385, -1.883402, -0.469937, -0.481781, -1.376337, 0.911426, -0.421970, 0.758645, 0.866087, 0.054294, -2.286275, -1.039941, 0.392261, 0.600235, 0.566748, -1.022966, -0.349735, -0.143037, -0.283288, -0.109223, 0.977295, 0.156285, -0.241752, -0.327293, 0.096074, 1.355821, 0.886651, 0.834343, 0.149908, -0.404873, -0.431915, 0.959970, 0.061419, -0.285103, 0.111263, 0.202683, 0.376271, -0.116067, 0.138792, 0.445979, 0.290989, -0.242874, -0.346761, 0.078914, 0.034136, -0.870164, -0.230687, -0.679115, -0.453920, -0.498161, -0.709209, -0.146147, -0.017312, -0.411055, -0.364331, 1.790577, -0.176075, -0.723053, 0.661552, 0.051030, 0.182555, 0.162354, -0.562494, 0.585837, 0.219536, 0.723892, 0.267529, -0.493820, -0.000944, 1.073650, -0.610506, 1.969175, -0.321756, 0.165205, 0.560576, -0.445916, 0.147305, 0.081469, 0.733030, -0.639512, -0.128523, 0.821294, -0.156377, 0.661212, -0.447247, 0.747086, 1.197289, 1.120626, 0.644541, -1.037759, -0.927256, 1.108188, 0.443766, -0.037838, 1.038198, -0.777633, -0.206474, -0.710854, -0.355862, -0.180103, -0.558386, -0.997519, -0.045973, 0.604964, 0.637836, 0.303163, -0.184672, -0.063542, -0.034336, 0.758078, -0.627832, -1.903602, -0.933951, 0.025073, -0.847317, 0.818236, -0.455283, -0.528798, -0.199017, 0.047115, 0.782431, -0.855221, 0.439269, -0.285996, 0.625571, 0.096328, 0.089619, -0.090534, 0.529576, 4.501905, 1.003226, -0.528412, -1.492817, 0.117703, 1.167368, 0.004324, 0.388524, -0.685737, 0.061340, 0.331465, -0.331502, 1.002149, 0.419753, 0.302038, 0.609501, -0.266256, -1.288667, 0.551195, -0.368805, -0.667456, 0.202925, -1.491046, -0.230838, -0.461002, -0.123896, -0.036718, -1.897493, -0.502035, 0.408753, 0.170793, 0.273150, -0.000815, 0.124118, 0.156100, 0.594571, 0.917966, 0.389394, 0.366612, 0.521227, 0.562131, -0.225260, -0.873686, -0.046855, -0.135722, -0.503301, 0.022896, 0.998964, -0.616207, -0.396044, 0.910686, -0.176650, -0.609326, -0.449190, 0.576747, -0.521846, 0.431436, -1.468931, -1.514708, -0.154545, -0.435662, -0.673081, 0.473723, -0.590864, 0.025936, 0.215279, -0.016715, 1.134850, 0.661093, -0.513720, -0.121693, -0.326415, -0.136230, -0.568424, 1.253833, 0.408191, 0.712330, 0.746802, 0.040251, -0.059956, 0.024462, 0.911898, 0.040565, 0.012627, -0.030156, -0.520014, 0.850057, -0.270045, -0.847478, -0.157882, 0.443254, 0.806046, -0.186409, -0.422478, -1.133786, 1.079063, -0.166138, -0.571667, -0.098509, 0.364370, 0.282150, 1.241138, -0.254863, 0.275489, -1.065522, -0.333116, 0.098885, -0.880379, -0.746474, 0.167898, 0.622477, 0.259032, 1.088299, -0.092418, 0.559675, 0.407819, -0.206228, -1.064644, 1.147439, -0.502136, 0.062495, -1.109216, 0.489472, -0.802628, 0.369118, -0.378995, -0.544698, 0.353553, -0.191862, -1.536034, -0.484893, 0.444807, 0.157272, 0.381181, 0.449889, -0.075108, -0.352240, 0.955138, -0.743568, -0.823916, 0.352233, -0.961116, -1.893218, 0.018346, -0.022620, 0.857100, -0.413084, 0.308824, -0.227558, 0.461872, -0.424433, -0.114400, -0.592457, 0.122129, 0.790703, 0.366691, 0.552071, -0.423679, -0.394332, -0.144287, -0.164119, -0.047966, 0.180729, -0.003952, -0.355909, -0.133109, 0.789147, 0.631095, -0.126298, 0.862269, 0.101255, 1.096273, 0.545555, -0.639577, -2.018642, -0.439312, 0.003968, 0.327283, -0.092810, -1.042530, 1.364452, 1.179238, -0.883948, -0.371361, -0.329252, -0.025818, 0.505112, 0.683491, -1.012862, 0.493483, 0.539515, 0.500419, 0.288749, -0.114866, 0.352669, -0.456255, -0.580752, 0.585715, 0.144576, -0.035104, 1.233467, -0.195198, 1.108539, 1.116311, -0.347662, 0.016133, -0.638461, -0.484053, -0.320613, -0.085975, -0.232120, 0.161251, -0.568741, 0.960369, 0.155364, 0.735644, 0.690310, -0.882789, -0.913967, 0.081527, -0.160942, 0.896419, 0.815935, 0.775975, -0.035229, 0.086067, 0.087852, -0.830845, -0.041790, -1.243464, 1.077813, -1.032597, -0.740396, 0.250708, 0.160607, -0.222441, 0.088197, -0.714924, 0.323647, -0.979571, 0.903878, -0.037072, 0.791384, -0.732045, -0.145367, 0.280290, -0.850576, 1.411916, 0.231850, 0.276533, -1.070890, -0.301702, 0.003253, 0.272342, 0.515268, -0.294163, 0.183392, 0.166822, 0.171130, 0.058901, 0.402616, 1.227121, 0.264214, 0.447733, -1.129231, -0.026984, 1.065866, -0.700682, 0.365534, 0.382542, -0.918139, -0.707215, -0.394204, -0.508742, 0.480149, -0.082983, 0.365794, 0.358934, 0.430533, 0.000919, 0.917651, 0.456906, -1.114516, 1.789799, 0.563742, -0.845717, -0.513698, 0.248768, -0.953782, 1.674861, 0.771962, -0.836244, -0.015618, 0.296709, 0.385047, -1.406090, 0.869633, 1.178896, -0.181703, -0.002030, 1.396937, 0.106670, 1.051165, 0.232139, -0.785353, 0.440807, 0.134374, 0.422115, 0.017052, -0.285855, 0.881638, 0.943586, -0.419304, 0.852863, 0.640232, -0.067155, -0.269846, -0.091488, 0.728749, 0.800561, -0.179447, 0.737550, -0.039372, -0.298867, 2.224916, -0.833340, 0.586230, 0.680057, 0.273743, -0.536826, -0.445305, 0.109167, 0.042689, 0.324641, -0.135530, 0.299774, 0.135228, -0.322364, -0.536501, 0.250821, -0.529266, -0.036560, -0.006006, 0.202638, -0.135642, -0.427857, 0.223352, 0.747627, 0.093975, 0.408209, -0.207240, -0.228368, 0.782170, -0.550407, -0.078093, 0.006059, 0.011183, -1.023877, -0.775297], - "snowflake-arctic-embed2": [-0.337318, 0.485787, -0.037816, -0.943875, -0.819299, -0.257385, -0.115470, 0.246724, 0.048614, 0.159151, -0.467606, -0.364392, 0.089869, -0.209655, 0.342226, -0.527060, 0.520997, 0.927532, -0.102562, 0.333813, -0.854380, -0.701242, -1.463815, 0.799778, 0.750539, 0.757705, -0.125063, 0.527705, -0.437741, 0.078491, 0.460214, 0.255947, -0.031090, -0.345135, 0.058851, -0.327729, -0.372813, 0.352275, -1.168406, 0.354936, 0.625492, 0.045635, -0.242759, 0.650628, 0.195748, -0.495107, -0.539670, -0.986722, -1.069306, -0.014932, -0.385889, 0.215507, 0.333816, -0.158572, 0.246042, -0.687132, 0.207916, -0.342494, -0.347905, -0.563665, 0.336679, -0.059624, -0.155887, -0.246520, 0.296986, 0.569967, 0.131530, -0.355191, -0.582369, 0.490316, -0.415379, -0.019140, -0.214617, 1.085840, 0.019224, -1.180745, -0.544194, -0.182204, -0.471391, 0.877849, 1.787677, 0.196131, 0.338737, 0.554189, 0.723178, -0.052438, -0.270815, 0.443365, 0.101404, -0.692780, 0.004322, -0.050623, -0.693687, -0.116200, 0.434660, 0.065080, -0.055940, 0.122773, -0.999912, -0.499409, -0.359269, 0.027620, -0.399372, -0.299647, -0.744792, 0.102263, 1.084825, 0.028898, 0.323312, 0.014242, 1.325412, 0.983624, -0.325036, 0.526028, -0.157539, -0.063860, 0.436522, 0.116374, 0.118433, 0.614439, -0.139657, 0.522618, 0.017510, -0.188138, -0.677374, -0.840603, 0.192689, -0.135996, -0.894670, -0.158343, -0.792459, -0.136472, -0.355442, -0.123314, -0.910940, 0.186382, 0.334950, 0.204000, 0.222174, 0.263186, 0.094970, 0.061765, 0.345430, -0.235054, 0.172441, 0.881599, 0.841009, -0.169058, -0.269911, -0.217716, 0.359628, 0.208658, 0.652820, 0.442545, -0.161419, 0.418893, 0.292317, 0.231373, -0.805518, -0.000739, 0.297150, 0.121066, -0.408190, 0.273577, 0.463801, 0.064206, 0.312172, 1.092058, -0.371314, -0.277224, -0.683628, -0.435973, 0.045403, -0.140166, 0.184559, 1.853358, 0.416705, -0.374452, 0.760777, 0.248660, -0.569295, -0.954281, -0.347827, 0.531861, -0.570648, 0.556323, 0.206901, 0.252571, -0.043244, -0.062258, 1.049511, -0.402070, -0.068134, -0.149358, -0.012464, 0.620048, 0.654902, -0.538302, -0.245287, -0.066978, -1.405453, -0.445957, 0.331479, -0.495953, 0.923955, -0.328841, -0.644721, -0.372834, 0.357546, 0.478619, -0.081360, 0.340657, -0.122412, -0.597997, 0.235506, 0.016301, -0.058082, 0.446411, -0.802173, 0.115207, -0.464422, -0.257083, -0.133011, -0.359320, 0.389579, 0.485856, -0.053931, 1.149238, -0.967310, 0.020607, -0.235731, -0.358982, -0.698047, 0.653281, 0.734305, -0.836348, 0.074222, -0.177832, -0.486657, -0.344304, -0.443823, -0.255469, -0.606071, 0.069794, 0.069820, 0.494822, -0.536611, -0.175762, -0.448531, -0.522376, -0.108621, -0.271191, -0.141843, 0.071029, 0.171164, -0.195819, -0.059490, 0.026950, -0.433273, 0.016244, 0.146567, -0.032891, -0.039686, 0.323199, -0.057771, -0.176835, 0.470351, 0.048500, 0.327727, -0.158381, 0.162835, -0.407448, -0.555830, -0.465591, 0.264512, 0.354612, 0.218764, 0.031698, -0.265124, 0.312480, 0.181667, -0.338958, 0.186351, 0.053644, 0.812065, -0.862652, -0.026800, 0.572852, 0.005986, 0.828237, 0.090118, 0.063922, 0.076976, 0.096964, 0.180304, 0.781934, -0.003830, -0.027061, 0.221362, 0.449681, 0.125572, -0.095162, 0.018868, -0.360262, -0.373733, -0.392008, -0.125284, -0.212061, 0.159567, -0.233902, -0.235149, -0.190911, -0.028427, 0.344431, -0.155667, 0.722263, -0.144527, -0.199895, 0.188895, 0.894280, 0.140612, 0.698334, -0.078967, -0.845755, 0.500688, -0.028362, -0.309373, -0.050033, 0.393043, -0.684940, -0.012917, 0.442933, -0.152553, -0.068629, -0.237759, -0.239215, 0.132807, 0.019395, 0.123185, 0.242981, 0.786300, 0.018274, 0.157075, 0.240215, 0.229825, -0.137675, -0.203565, -0.245311, -0.036812, 0.430710, -0.207664, -0.132277, 0.557027, 0.452612, -0.331802, 0.004795, 0.139062, 0.078491, -0.501776, 0.156317, -0.092398, 0.078616, -0.144665, -0.419595, 0.396099, 0.319320, -0.084284, -0.013825, 0.811091, -0.228181, -0.249798, 0.043037, -0.014254, 0.145196, -0.379182, -0.241216, 0.270687, 0.331287, 0.078576, 0.225569, 0.075139, 0.206449, -0.216213, -0.179613, -0.196487, -0.121997, 0.634396, 0.243545, -0.646855, -0.196892, 0.164843, -0.165656, -0.017864, -0.220435, -0.315971, -0.428766, -0.276434, 0.298087, 0.034363, 0.339730, -0.001861, -0.061919, -0.482472, -0.097411, -0.183378, -0.040443, -0.111079, 0.394592, 0.943151, -0.304478, 0.354390, 0.196057, 0.199277, 0.341486, 0.218786, 0.193412, 0.226260, -0.177706, -0.272467, 0.395993, -0.259079, 0.001724, 0.371750, 0.350838, 0.290101, -0.419872, -0.302239, 0.187943, -0.047100, 0.501532, 0.395721, 0.057455, 0.260134, -0.393160, 0.164219, 0.066535, 0.172231, -0.359559, -0.161729, 0.682735, -0.679863, 0.053116, -0.210306, 0.089449, 0.457067, -0.076446, 0.443101, 0.434519, 0.493740, -0.721550, -0.047476, -0.149920, 0.792890, -0.869984, 0.416676, -0.278901, -0.456933, 0.201800, 0.250265, 0.093752, 0.216085, -0.122870, 0.141153, 0.164069, -0.099821, -0.121633, 0.180234, -0.016088, -0.070337, -0.163921, -0.103767, 0.440052, 0.191798, 0.114916, 0.325931, -0.172159, 0.250953, 0.115396, -0.131392, 0.363941, -0.167835, -0.198244, -0.989684, -0.186654, 0.199121, 0.593739, -0.318832, -0.185066, -0.236732, -0.230723, 0.018697, -0.223611, -0.002621, 0.185624, -0.180204, 0.115503, 0.430932, -0.117918, 0.103355, 0.195856, -0.223646, 0.132063, 0.571766, -0.608208, -0.051812, 0.142387, -0.170185, -0.515449, 0.352781, 0.486267, -0.422757, 0.272677, -0.105689, 0.340707, -0.156664, -0.782644, 0.512138, -0.341311, -0.487717, 0.194345, -0.057030, -0.015855, 0.099853, 0.549729, -0.415887, 0.604569, 0.066785, -0.448733, -0.011270, -0.616035, -0.562425, -0.334210, -0.393114, -0.628784, -0.305269, 0.209872, -0.199347, 0.101649, 0.090523, 0.282902, -0.088015, -0.191279, -0.044561, -0.709134, 0.072914, -0.249584, 0.037448, 0.165476, 0.059152, 0.055725, -0.518436, -0.005831, -0.164648, -0.281878, 0.347298, 0.177980, -0.114527, 0.210128, 0.120374, -0.146421, 0.075994, -0.181335, 0.150211, -0.225272, -0.489089, -0.078891, -0.178676, -0.740558, 0.205851, 0.392087, -0.328261, -0.068016, -0.021789, -0.280372, 0.704844, -0.058202, 0.168101, 0.180238, 0.096060, -0.275457, 0.027325, 0.425901, -0.313618, 0.154550, 0.204825, -0.104279, 0.245843, -0.489933, -0.046835, -0.247613, 0.823351, 0.004220, 0.017303, -0.158378, 0.154119, -0.197591, -0.127734, 0.159808, -0.600171, -0.346363, 0.469721, -0.058461, -0.315804, -0.083556, 0.267933, -0.717538, -0.110205, -0.563653, 0.005439, 0.389236, 0.552098, 0.436608, -0.472080, 0.223911, -0.471215, -0.560872, -0.021037, 0.275148, 0.461694, -0.325049, 0.598732, 0.376293, -0.225930, -0.151626, 0.146455, 0.396804, 0.021290, 0.037224, 0.235271, 0.329889, 0.672245, -0.496795, -0.378117, -0.350688, 0.435732, 0.370599, 0.008810, 0.555823, 0.623420, 0.260685, -0.383603, -0.185294, 0.175743, 0.406610, -0.249284, 0.318281, 0.203903, 0.182324, -0.028281, -0.134342, 0.156111, -0.666054, -0.169002, 0.259389, -0.127781, -0.134607, 0.133519, -0.287695, -0.392834, 0.252281, -0.458701, 0.297617, 0.066121, 0.535986, -1.198022, -0.872793, -0.535140, 0.635081, -0.181788, 0.259800, 0.160934, 0.403854, -0.016975, 0.122155, 0.106455, 0.017354, 0.064465, -0.004753, 0.183455, 0.125073, 0.000588, -1.079189, -0.091745, 0.131509, -0.038783, 0.086098, -0.011477, 0.033550, -0.027044, -0.398735, -0.133224, -0.045345, -0.183940, 0.100738, 0.766663, 0.008661, -0.061123, 0.052512, 0.097162, 0.122948, -0.363722, -0.118078, -0.802726, -0.130973, -0.369868, 0.688861, 0.363402, -0.023863, 0.067200, -0.240462, 0.499130, -0.021514, -0.149011, -0.011722, -0.237259, 0.152696, 0.124860, 0.081450, 0.090567, 0.048832, 0.615275, 0.147335, -0.101912, -0.132456, 0.131634, -0.168211, 0.355089, 0.199154, -0.000686, -0.334698, 0.464978, 0.060418, 0.398211, 0.122107, 0.336332, -0.415999, 0.140270, 0.113768, -0.197597, -0.220913, -0.169208, 0.155395, 0.350888, -0.163269, -0.365437, 0.111591, 0.043267, 0.600786, -0.172549, -0.028790, 0.133079, 0.111489, -0.018018, 0.260471, -0.890617, 0.236967, 0.416443, 0.903602, -0.082193, -0.280290, 0.138442, 0.411884, -0.454041, 0.491140, -0.444857, -0.186720, -0.382473, -0.126291, 0.495247, -0.631967, -0.266918, -0.220935, 0.367287, 0.502838, 0.155025, -0.429546, -0.408211, 0.234250, -0.462584, -0.046278, -0.231486, 0.209515, 0.246387, -0.061538, 0.270009, -0.012469, -0.420804, 0.087525, -0.513991, 0.020571, 0.507510, -0.444389, -0.022836, -0.590260, 0.167235, -0.201333, 0.189617, 0.279683, -0.402719, 0.145037, 0.929912, 0.430638, -0.179808, 0.080103, 0.600420, -0.489557, 0.381116, -0.722508, -0.164676, -0.037822, -0.305011, -0.376997, 0.013216, -0.315066, 0.022070, 0.528256, 0.300673, 0.108121, 0.488978, -0.100333, -0.130812, 0.217841, -0.220755, -0.671549, -0.076320, 0.525022, 0.184758, -0.214599, 0.194860, 0.236146, -0.240089, -0.474762, -0.037878, 0.149301, -0.063512, 0.294585, 0.747633, -0.437204, 0.083148, 0.410454, 0.142592, -0.260462, 0.127561, -0.031248, 0.321641, 0.304835, -0.315456, 0.321474, -0.200811, -0.007041, -0.019529, 0.332829, 0.095737, 0.888721, -0.068599, 0.112251, 0.200350, 0.349384, 0.130674, -0.199802, 0.104813, -0.402484, 0.338873, 0.018662, -0.304823, 0.138016, 0.002506, -0.095239, -0.271009, -0.849811, -0.423410, -0.232685, -0.589317, 0.450318, -0.305014, 0.563061, -0.142598, 0.286005, 0.081525, 0.097474, 0.012287, 0.317698, -0.170248, -0.958868, 0.213176, 0.301248, 0.396288, -0.022001, 0.404562, -0.049691, -0.227430, -0.230833, 0.232825, 0.310583, 0.357731, 0.113404, 0.015757, 0.094021, 0.318617, 0.595829, -0.039896, 0.615338, -0.176179, -0.043411, 0.534391, -0.335011, 0.427954, -0.310139, -0.024028, -0.739826, -0.112875, -0.258219, 0.677319, -0.274854, -0.202554, -0.027695, 0.908598, -0.016939, 0.387993, 0.037429, -0.101158, 0.166008, 0.416612, 0.189825, -0.642134, -0.106222, 0.141566, -0.026880, 0.021668, 0.221566, 0.267000, 0.196498, -0.181309, -0.062393, 0.203500, 0.037145, -0.128068, -0.645994, 0.417619, 0.601422, 0.012565, 0.457200, -0.532447, 0.277037, -0.485728, -0.274002, 0.261037, -0.255880, -0.009387, 0.491182, 0.383511, 0.125899, -0.204434, 0.205015, 0.109285, -0.415707, 0.095736, 0.147818, 0.122518, 0.038847, 0.232760, 0.166897, 0.331865, -0.357069, 0.314145, -0.216854, -0.337515, 0.259433, 0.320100, -0.172233, -0.315187, 0.197327, 0.046211, -0.521370, 0.391666, 0.248245, -0.153588, -0.275701, -0.000683, -0.205512, 0.000457, -0.134299, 0.452796, -0.099954, 0.194279, -0.210376, -0.530722, -0.265526, -0.408304, 0.263296, 0.311573, 0.364050, 0.212423, 0.355866, -0.102873, -0.300132, -1.024923, 0.019980, 0.381418, 0.513570, -0.051673, 0.091931, 0.043775, 0.022401, 0.230052, 0.140274, -0.147261, 0.173270, 0.150905, -0.167662, 0.099411, -0.022456, -0.727629, -0.310803, -0.555541, -0.286311, -0.483686, -0.054392, 0.234199, -0.675458, -0.605178, -0.033194, 0.591152, -0.440875] + "snowflake-arctic-embed2": [-0.337318, 0.485787, -0.037816, -0.943875, -0.819299, -0.257385, -0.115470, 0.246724, 0.048614, 0.159151, -0.467606, -0.364392, 0.089869, -0.209655, 0.342226, -0.527060, 0.520997, 0.927532, -0.102562, 0.333813, -0.854380, -0.701242, -1.463815, 0.799778, 0.750539, 0.757705, -0.125063, 0.527705, -0.437741, 0.078491, 0.460214, 0.255947, -0.031090, -0.345135, 0.058851, -0.327729, -0.372813, 0.352275, -1.168406, 0.354936, 0.625492, 0.045635, -0.242759, 0.650628, 0.195748, -0.495107, -0.539670, -0.986722, -1.069306, -0.014932, -0.385889, 0.215507, 0.333816, -0.158572, 0.246042, -0.687132, 0.207916, -0.342494, -0.347905, -0.563665, 0.336679, -0.059624, -0.155887, -0.246520, 0.296986, 0.569967, 0.131530, -0.355191, -0.582369, 0.490316, -0.415379, -0.019140, -0.214617, 1.085840, 0.019224, -1.180745, -0.544194, -0.182204, -0.471391, 0.877849, 1.787677, 0.196131, 0.338737, 0.554189, 0.723178, -0.052438, -0.270815, 0.443365, 0.101404, -0.692780, 0.004322, -0.050623, -0.693687, -0.116200, 0.434660, 0.065080, -0.055940, 0.122773, -0.999912, -0.499409, -0.359269, 0.027620, -0.399372, -0.299647, -0.744792, 0.102263, 1.084825, 0.028898, 0.323312, 0.014242, 1.325412, 0.983624, -0.325036, 0.526028, -0.157539, -0.063860, 0.436522, 0.116374, 0.118433, 0.614439, -0.139657, 0.522618, 0.017510, -0.188138, -0.677374, -0.840603, 0.192689, -0.135996, -0.894670, -0.158343, -0.792459, -0.136472, -0.355442, -0.123314, -0.910940, 0.186382, 0.334950, 0.204000, 0.222174, 0.263186, 0.094970, 0.061765, 0.345430, -0.235054, 0.172441, 0.881599, 0.841009, -0.169058, -0.269911, -0.217716, 0.359628, 0.208658, 0.652820, 0.442545, -0.161419, 0.418893, 0.292317, 0.231373, -0.805518, -0.000739, 0.297150, 0.121066, -0.408190, 0.273577, 0.463801, 0.064206, 0.312172, 1.092058, -0.371314, -0.277224, -0.683628, -0.435973, 0.045403, -0.140166, 0.184559, 1.853358, 0.416705, -0.374452, 0.760777, 0.248660, -0.569295, -0.954281, -0.347827, 0.531861, -0.570648, 0.556323, 0.206901, 0.252571, -0.043244, -0.062258, 1.049511, -0.402070, -0.068134, -0.149358, -0.012464, 0.620048, 0.654902, -0.538302, -0.245287, -0.066978, -1.405453, -0.445957, 0.331479, -0.495953, 0.923955, -0.328841, -0.644721, -0.372834, 0.357546, 0.478619, -0.081360, 0.340657, -0.122412, -0.597997, 0.235506, 0.016301, -0.058082, 0.446411, -0.802173, 0.115207, -0.464422, -0.257083, -0.133011, -0.359320, 0.389579, 0.485856, -0.053931, 1.149238, -0.967310, 0.020607, -0.235731, -0.358982, -0.698047, 0.653281, 0.734305, -0.836348, 0.074222, -0.177832, -0.486657, -0.344304, -0.443823, -0.255469, -0.606071, 0.069794, 0.069820, 0.494822, -0.536611, -0.175762, -0.448531, -0.522376, -0.108621, -0.271191, -0.141843, 0.071029, 0.171164, -0.195819, -0.059490, 0.026950, -0.433273, 0.016244, 0.146567, -0.032891, -0.039686, 0.323199, -0.057771, -0.176835, 0.470351, 0.048500, 0.327727, -0.158381, 0.162835, -0.407448, -0.555830, -0.465591, 0.264512, 0.354612, 0.218764, 0.031698, -0.265124, 0.312480, 0.181667, -0.338958, 0.186351, 0.053644, 0.812065, -0.862652, -0.026800, 0.572852, 0.005986, 0.828237, 0.090118, 0.063922, 0.076976, 0.096964, 0.180304, 0.781934, -0.003830, -0.027061, 0.221362, 0.449681, 0.125572, -0.095162, 0.018868, -0.360262, -0.373733, -0.392008, -0.125284, -0.212061, 0.159567, -0.233902, -0.235149, -0.190911, -0.028427, 0.344431, -0.155667, 0.722263, -0.144527, -0.199895, 0.188895, 0.894280, 0.140612, 0.698334, -0.078967, -0.845755, 0.500688, -0.028362, -0.309373, -0.050033, 0.393043, -0.684940, -0.012917, 0.442933, -0.152553, -0.068629, -0.237759, -0.239215, 0.132807, 0.019395, 0.123185, 0.242981, 0.786300, 0.018274, 0.157075, 0.240215, 0.229825, -0.137675, -0.203565, -0.245311, -0.036812, 0.430710, -0.207664, -0.132277, 0.557027, 0.452612, -0.331802, 0.004795, 0.139062, 0.078491, -0.501776, 0.156317, -0.092398, 0.078616, -0.144665, -0.419595, 0.396099, 0.319320, -0.084284, -0.013825, 0.811091, -0.228181, -0.249798, 0.043037, -0.014254, 0.145196, -0.379182, -0.241216, 0.270687, 0.331287, 0.078576, 0.225569, 0.075139, 0.206449, -0.216213, -0.179613, -0.196487, -0.121997, 0.634396, 0.243545, -0.646855, -0.196892, 0.164843, -0.165656, -0.017864, -0.220435, -0.315971, -0.428766, -0.276434, 0.298087, 0.034363, 0.339730, -0.001861, -0.061919, -0.482472, -0.097411, -0.183378, -0.040443, -0.111079, 0.394592, 0.943151, -0.304478, 0.354390, 0.196057, 0.199277, 0.341486, 0.218786, 0.193412, 0.226260, -0.177706, -0.272467, 0.395993, -0.259079, 0.001724, 0.371750, 0.350838, 0.290101, -0.419872, -0.302239, 0.187943, -0.047100, 0.501532, 0.395721, 0.057455, 0.260134, -0.393160, 0.164219, 0.066535, 0.172231, -0.359559, -0.161729, 0.682735, -0.679863, 0.053116, -0.210306, 0.089449, 0.457067, -0.076446, 0.443101, 0.434519, 0.493740, -0.721550, -0.047476, -0.149920, 0.792890, -0.869984, 0.416676, -0.278901, -0.456933, 0.201800, 0.250265, 0.093752, 0.216085, -0.122870, 0.141153, 0.164069, -0.099821, -0.121633, 0.180234, -0.016088, -0.070337, -0.163921, -0.103767, 0.440052, 0.191798, 0.114916, 0.325931, -0.172159, 0.250953, 0.115396, -0.131392, 0.363941, -0.167835, -0.198244, -0.989684, -0.186654, 0.199121, 0.593739, -0.318832, -0.185066, -0.236732, -0.230723, 0.018697, -0.223611, -0.002621, 0.185624, -0.180204, 0.115503, 0.430932, -0.117918, 0.103355, 0.195856, -0.223646, 0.132063, 0.571766, -0.608208, -0.051812, 0.142387, -0.170185, -0.515449, 0.352781, 0.486267, -0.422757, 0.272677, -0.105689, 0.340707, -0.156664, -0.782644, 0.512138, -0.341311, -0.487717, 0.194345, -0.057030, -0.015855, 0.099853, 0.549729, -0.415887, 0.604569, 0.066785, -0.448733, -0.011270, -0.616035, -0.562425, -0.334210, -0.393114, -0.628784, -0.305269, 0.209872, -0.199347, 0.101649, 0.090523, 0.282902, -0.088015, -0.191279, -0.044561, -0.709134, 0.072914, -0.249584, 0.037448, 0.165476, 0.059152, 0.055725, -0.518436, -0.005831, -0.164648, -0.281878, 0.347298, 0.177980, -0.114527, 0.210128, 0.120374, -0.146421, 0.075994, -0.181335, 0.150211, -0.225272, -0.489089, -0.078891, -0.178676, -0.740558, 0.205851, 0.392087, -0.328261, -0.068016, -0.021789, -0.280372, 0.704844, -0.058202, 0.168101, 0.180238, 0.096060, -0.275457, 0.027325, 0.425901, -0.313618, 0.154550, 0.204825, -0.104279, 0.245843, -0.489933, -0.046835, -0.247613, 0.823351, 0.004220, 0.017303, -0.158378, 0.154119, -0.197591, -0.127734, 0.159808, -0.600171, -0.346363, 0.469721, -0.058461, -0.315804, -0.083556, 0.267933, -0.717538, -0.110205, -0.563653, 0.005439, 0.389236, 0.552098, 0.436608, -0.472080, 0.223911, -0.471215, -0.560872, -0.021037, 0.275148, 0.461694, -0.325049, 0.598732, 0.376293, -0.225930, -0.151626, 0.146455, 0.396804, 0.021290, 0.037224, 0.235271, 0.329889, 0.672245, -0.496795, -0.378117, -0.350688, 0.435732, 0.370599, 0.008810, 0.555823, 0.623420, 0.260685, -0.383603, -0.185294, 0.175743, 0.406610, -0.249284, 0.318281, 0.203903, 0.182324, -0.028281, -0.134342, 0.156111, -0.666054, -0.169002, 0.259389, -0.127781, -0.134607, 0.133519, -0.287695, -0.392834, 0.252281, -0.458701, 0.297617, 0.066121, 0.535986, -1.198022, -0.872793, -0.535140, 0.635081, -0.181788, 0.259800, 0.160934, 0.403854, -0.016975, 0.122155, 0.106455, 0.017354, 0.064465, -0.004753, 0.183455, 0.125073, 0.000588, -1.079189, -0.091745, 0.131509, -0.038783, 0.086098, -0.011477, 0.033550, -0.027044, -0.398735, -0.133224, -0.045345, -0.183940, 0.100738, 0.766663, 0.008661, -0.061123, 0.052512, 0.097162, 0.122948, -0.363722, -0.118078, -0.802726, -0.130973, -0.369868, 0.688861, 0.363402, -0.023863, 0.067200, -0.240462, 0.499130, -0.021514, -0.149011, -0.011722, -0.237259, 0.152696, 0.124860, 0.081450, 0.090567, 0.048832, 0.615275, 0.147335, -0.101912, -0.132456, 0.131634, -0.168211, 0.355089, 0.199154, -0.000686, -0.334698, 0.464978, 0.060418, 0.398211, 0.122107, 0.336332, -0.415999, 0.140270, 0.113768, -0.197597, -0.220913, -0.169208, 0.155395, 0.350888, -0.163269, -0.365437, 0.111591, 0.043267, 0.600786, -0.172549, -0.028790, 0.133079, 0.111489, -0.018018, 0.260471, -0.890617, 0.236967, 0.416443, 0.903602, -0.082193, -0.280290, 0.138442, 0.411884, -0.454041, 0.491140, -0.444857, -0.186720, -0.382473, -0.126291, 0.495247, -0.631967, -0.266918, -0.220935, 0.367287, 0.502838, 0.155025, -0.429546, -0.408211, 0.234250, -0.462584, -0.046278, -0.231486, 0.209515, 0.246387, -0.061538, 0.270009, -0.012469, -0.420804, 0.087525, -0.513991, 0.020571, 0.507510, -0.444389, -0.022836, -0.590260, 0.167235, -0.201333, 0.189617, 0.279683, -0.402719, 0.145037, 0.929912, 0.430638, -0.179808, 0.080103, 0.600420, -0.489557, 0.381116, -0.722508, -0.164676, -0.037822, -0.305011, -0.376997, 0.013216, -0.315066, 0.022070, 0.528256, 0.300673, 0.108121, 0.488978, -0.100333, -0.130812, 0.217841, -0.220755, -0.671549, -0.076320, 0.525022, 0.184758, -0.214599, 0.194860, 0.236146, -0.240089, -0.474762, -0.037878, 0.149301, -0.063512, 0.294585, 0.747633, -0.437204, 0.083148, 0.410454, 0.142592, -0.260462, 0.127561, -0.031248, 0.321641, 0.304835, -0.315456, 0.321474, -0.200811, -0.007041, -0.019529, 0.332829, 0.095737, 0.888721, -0.068599, 0.112251, 0.200350, 0.349384, 0.130674, -0.199802, 0.104813, -0.402484, 0.338873, 0.018662, -0.304823, 0.138016, 0.002506, -0.095239, -0.271009, -0.849811, -0.423410, -0.232685, -0.589317, 0.450318, -0.305014, 0.563061, -0.142598, 0.286005, 0.081525, 0.097474, 0.012287, 0.317698, -0.170248, -0.958868, 0.213176, 0.301248, 0.396288, -0.022001, 0.404562, -0.049691, -0.227430, -0.230833, 0.232825, 0.310583, 0.357731, 0.113404, 0.015757, 0.094021, 0.318617, 0.595829, -0.039896, 0.615338, -0.176179, -0.043411, 0.534391, -0.335011, 0.427954, -0.310139, -0.024028, -0.739826, -0.112875, -0.258219, 0.677319, -0.274854, -0.202554, -0.027695, 0.908598, -0.016939, 0.387993, 0.037429, -0.101158, 0.166008, 0.416612, 0.189825, -0.642134, -0.106222, 0.141566, -0.026880, 0.021668, 0.221566, 0.267000, 0.196498, -0.181309, -0.062393, 0.203500, 0.037145, -0.128068, -0.645994, 0.417619, 0.601422, 0.012565, 0.457200, -0.532447, 0.277037, -0.485728, -0.274002, 0.261037, -0.255880, -0.009387, 0.491182, 0.383511, 0.125899, -0.204434, 0.205015, 0.109285, -0.415707, 0.095736, 0.147818, 0.122518, 0.038847, 0.232760, 0.166897, 0.331865, -0.357069, 0.314145, -0.216854, -0.337515, 0.259433, 0.320100, -0.172233, -0.315187, 0.197327, 0.046211, -0.521370, 0.391666, 0.248245, -0.153588, -0.275701, -0.000683, -0.205512, 0.000457, -0.134299, 0.452796, -0.099954, 0.194279, -0.210376, -0.530722, -0.265526, -0.408304, 0.263296, 0.311573, 0.364050, 0.212423, 0.355866, -0.102873, -0.300132, -1.024923, 0.019980, 0.381418, 0.513570, -0.051673, 0.091931, 0.043775, 0.022401, 0.230052, 0.140274, -0.147261, 0.173270, 0.150905, -0.167662, 0.099411, -0.022456, -0.727629, -0.310803, -0.555541, -0.286311, -0.483686, -0.054392, 0.234199, -0.675458, -0.605178, -0.033194, 0.591152, -0.440875], + "embeddinggemma": [-0.180607, -0.005889, 0.056060, 0.003927, -0.000914, 0.039077, -0.014656, 0.043961, 0.019580, -0.035296, -0.007450, -0.005179, 0.019225, -0.019774, 0.091014, 0.019635, 0.013230, -0.058614, -0.087619, -0.020779, 0.013899, 0.020425, 0.020929, -0.013275, 0.006773, 0.035736, -0.000525, 0.041411, -0.040545, -0.020314, 0.050535, 0.011460, 0.006180, 0.004620, 0.009491, 0.063213, 0.026650, -0.078453, 0.014068, -0.007105, -0.067159, 0.070726, -0.005620, -0.001726, 0.039832, -0.008431, -0.042386, -0.074533, 0.017800, 0.028184, -0.036976, 0.010219, -0.076775, -0.004639, -0.032891, -0.027700, -0.001458, 0.025879, -0.049648, 0.054045, -0.085062, -0.058225, -0.033515, 0.008202, 0.031972, -0.016915, -0.008515, 0.009412, 0.030835, 0.290311, -0.003487, -0.034788, -0.038132, -0.043061, 0.256677, -0.011666, 0.014681, -0.028823, -0.018269, 0.039487, 0.031672, 0.019415, 0.005047, -0.028207, 0.068291, -0.014481, -0.038045, -0.002510, 0.018587, -0.037189, 0.005428, -0.023473, -0.014835, -0.017678, 0.018800, -0.063929, 0.020199, -0.006755, -0.017839, 0.010022, 0.013381, -0.012818, 0.049738, 0.103187, 0.042391, -0.026812, -0.025366, -0.020568, -0.027174, 0.016077, -0.039851, 0.020666, -0.000882, -0.005518, -0.019171, 0.019656, -0.034325, -0.011952, -0.007715, 0.016382, 0.023478, 0.008794, 0.008553, -0.029300, -0.008735, 0.039593, -0.000716, 0.020169, -0.052826, -0.029727, -0.012380, 0.010375, 0.018598, 0.022578, -0.026641, 0.029909, -0.029175, 0.023739, 0.107380, 0.020876, -0.023140, -0.044402, -0.022905, 0.007913, 0.022877, 0.023412, -0.039243, -0.043118, 0.028891, -0.026474, 0.058203, -0.009469, 0.043529, -0.001148, -0.039958, 0.014740, -0.004989, 0.036507, -0.051620, -0.031175, -0.044597, -0.012939, 0.042538, 0.040928, -0.017088, 0.067081, 0.034191, 0.039442, -0.032367, -0.025219, 0.016267, -0.075883, 0.012993, -0.009841, -0.031461, -0.003222, 0.025393, -0.014700, 0.137779, 0.059643, -0.022662, 0.071586, -0.021947, -0.016830, -0.051643, 0.079643, 0.031313, 0.028581, -0.046328, -0.030986, 0.032716, -0.026602, 0.006470, -0.039692, 0.014163, 0.026654, 0.093365, 0.032694, 0.015409, 0.056742, 0.010313, -0.042171, -0.028517, -0.032063, -0.032345, -0.007562, -0.016120, 0.027168, -0.013794, 0.007886, 0.012508, 0.041319, -0.045403, -0.017043, 0.020755, 0.023180, -0.032282, -0.001933, 0.004788, -0.008845, -0.008160, 0.060156, -0.013906, -0.023820, -0.079445, -0.008594, -0.053413, -0.022282, 0.047015, -0.032093, -0.024175, 0.047034, 0.008884, 0.004303, -0.028246, -0.051090, -0.052362, 0.023090, 0.039360, 0.022026, -0.030191, -0.013131, 0.070108, 0.026949, -0.021665, 0.004105, 0.020551, -0.040327, 0.019240, -0.016254, 0.044880, 0.002207, 0.014687, -0.025028, 0.019416, 0.024616, 0.041669, 0.006119, 0.037250, -0.028410, 0.009880, -0.040271, 0.002080, -0.025919, -0.004405, -0.000738, 0.056129, -0.012073, -0.076000, -0.029060, 0.023691, -0.032691, -0.007954, -0.007977, 0.004320, 0.034496, -0.004204, 0.010539, -0.016130, -0.017084, 0.008819, -0.045655, 0.039586, 0.072274, -0.001564, -0.018986, 0.008003, 0.014540, -0.020202, -0.047112, 0.031214, 0.016062, 0.018050, 0.010549, -0.011854, -0.018039, -0.069473, -0.009963, -0.051167, 0.009143, -0.153973, -0.004123, -0.005579, -0.040051, 0.035759, 0.006347, 0.101297, 0.023210, -0.027991, 0.018936, -0.015880, -0.005618, -0.015289, 0.004609, 0.021374, 0.039930, 0.006416, -0.023404, 0.011298, 0.049358, 0.030248, -0.003199, 0.037742, -0.005549, 0.002536, -0.009225, 0.075556, 0.025399, -0.105251, 0.051742, 0.024327, -0.009478, 0.085254, 0.022996, -0.011836, 0.017050, -0.033193, -0.024902, -0.010120, -0.008034, -0.030815, -0.036057, 0.012252, -0.011298, 0.039328, 0.059369, -0.006778, 0.022502, -0.035587, -0.017115, -0.009679, -0.038681, 0.008259, -0.032500, -0.032367, -0.006804, 0.017457, 0.009156, 0.081487, 0.022148, 0.067775, -0.093894, 0.031846, -0.020408, 0.030209, -0.025761, -0.002701, -0.010321, 0.027960, 0.033251, -0.036872, -0.050187, 0.000647, -0.006226, -0.023869, 0.016206, 0.037978, 0.010212, 0.015309, 0.010771, -0.021051, -0.025120, -0.015512, 0.014893, 0.025609, 0.020798, 0.017523, -0.003228, -0.039813, -0.000640, -0.072478, 0.023971, 0.017670, -0.000263, 0.027826, -0.016781, -0.062787, 0.033756, -0.054924, -0.016018, -0.014411, -0.023932, -0.019506, 0.053342, 0.010944, -0.042293, 0.008613, -0.034974, -0.021446, -0.009096, 0.009447, -0.047535, -0.026310, 0.052943, -0.018407, -0.067084, -0.015153, 0.022066, -0.094386, -0.020756, -0.018389, -0.007097, 0.023673, 0.029475, 0.016351, 0.056441, 0.041489, 0.022455, 0.045231, 0.016921, 0.038108, -0.050633, -0.011182, 0.043934, 0.004127, -0.002120, -0.043600, 0.029996, 0.030896, 0.053286, -0.040573, -0.056230, 0.005976, 0.023566, 0.002953, 0.024094, -0.013409, -0.004520, -0.008556, -0.013104, 0.021295, 0.055402, 0.010722, -0.001335, -0.017037, -0.005384, -0.001479, 0.011284, -0.036778, -0.025387, 0.011728, 0.031009, 0.013086, -0.033330, 0.001464, -0.042345, -0.004958, -0.047134, -0.046586, -0.009092, -0.035167, 0.028894, 0.016499, 0.027641, -0.018313, -0.030339, -0.007525, -0.002521, 0.016437, 0.038167, -0.039028, 0.018501, 0.011082, -0.038511, 0.052748, -0.026239, 0.013372, 0.048941, 0.024461, 0.028460, -0.032540, 0.030801, -0.004988, 0.019265, -0.000821, -0.003360, -0.008584, -0.010812, -0.038912, 0.021251, -0.007917, 0.009466, 0.044254, 0.025877, -0.041038, 0.021859, -0.028726, -0.025102, -0.027201, 0.021041, -0.024082, -0.019939, -0.028786, 0.035638, 0.012499, -0.056768, 0.017606, 0.013645, 0.063322, 0.019431, -0.012916, -0.014921, 0.021280, 0.038604, -0.051948, 0.014327, 0.000644, -0.037636, 0.044974, -0.034804, -0.018713, -0.015628, 0.004696, -0.059636, -0.028115, -0.026099, 0.039006, -0.008930, 0.003926, -0.002180, 0.026208, -0.013091, -0.008003, -0.021853, 0.025176, 0.002682, -0.005717, 0.015906, -0.016278, 0.059330, 0.037651, -0.019816, -0.002063, 0.078280, 0.038225, -0.007886, -0.019419, -0.013467, 0.036248, 0.022667, -0.008686, -0.018650, -0.003174, 0.005356, 0.015273, -0.042080, 0.017633, -0.010409, -0.034426, 0.015203, -0.015169, 0.030749, 0.026981, 0.009063, 0.007627, 0.091058, 0.009252, -0.058266, 0.020832, 0.024774, -0.028801, 0.026656, 0.022873, 0.016810, 0.016199, 0.007166, 0.024154, -0.028103, -0.006009, 0.022183, 0.005554, -0.028878, -0.010401, 0.013940, -0.019788, 0.003170, 0.049577, -0.023240, 0.052571, -0.011319, -0.015615, -0.054576, 0.005388, -0.000214, 0.044267, 0.015146, 0.001074, 0.064579, 0.004720, -0.008599, -0.013695, 0.012638, -0.040416, -0.000573, 0.011788, -0.003367, 0.008859, -0.018324, 0.011476, 0.018124, 0.010761, 0.046478, 0.010131, -0.050664, 0.021277, -0.018322, -0.003070, -0.019420, 0.029148, 0.012572, -0.004788, 0.040219, 0.039951, 0.043636, -0.005553, -0.006092, 0.066745, 0.027182, -0.029501, 0.014834, -0.021343, 0.023051, -0.000411, 0.026976, 0.030971, -0.004881, 0.003360, -0.006648, -0.008272, 0.041518, -0.027338, 0.001205, -0.006581, -0.024365, -0.033114, -0.024966, 0.010534, -0.012564, 0.045804, -0.004190, 0.057720, -0.022263, -0.003263, 0.040221, -0.028405, -0.004599, -0.023340, 0.005303, -0.001754, -0.057940, 0.006630, -0.015906, -0.024751, -0.005112, -0.024829, 0.034132, 0.027506, -0.011464, -0.000899, 0.065783, 0.021920, -0.007581, -0.001119, 0.025989, -0.010824, -0.017624, 0.003288, -0.036588, -0.003869, -0.009002, 0.033091, -0.091586, -0.008219, -0.033366, -0.006626, 0.005773, -0.006797, 0.028244, 0.020040, -0.012321, 0.039671, 0.017080, 0.055742, -0.003618, 0.025329, -0.000387, -0.003931, -0.011762, 0.029402, 0.019193, -0.035431, -0.012032, 0.005728, 0.017904, 0.013591, 0.010789, -0.027405, -0.014921, -0.042695, -0.036618, -0.013241, -0.041739, 0.027101, 0.030270, -0.025742, 0.034299, 0.014907, -0.028621, -0.017876, -0.010664, 0.016825, 0.028010, -0.032065, -0.031433, -0.011079, -0.015334, 0.020823, -0.020160, 0.015865, -0.003164, 0.008807, -0.026496, 0.028156, -0.040351, 0.024934, -0.084462, -0.008195, -0.022980, -0.019642, -0.000236, 0.005202, -0.016276, 0.005408, 0.056486, -0.031008, -0.032354, -0.018442, -0.006234, 0.005748, -0.000144, 0.004343, 0.045699, 0.014304, -0.011867, 0.004920, -0.057248, -0.006343, -0.005401], + "qwen3-embedding": [0.031792, 0.004926, -0.018041, -0.021473, 0.011381, -0.021165, -0.029530, -0.013660, 0.001806, 0.011227, -0.030067, 0.022565, 0.017513, -0.002146, 0.007644, 0.010086, -0.001986, -0.008106, 0.028812, 0.051614, 0.000999, 0.035312, -0.004000, 0.011976, 0.031688, 0.074720, -0.001988, -0.008885, -0.022300, 0.007263, -0.006750, -0.011187, 0.008181, 0.020549, -0.008814, 0.027741, -0.035995, -0.010632, 0.025195, 0.003310, -0.013303, 0.023147, -0.006527, 0.017887, 0.007376, -0.015949, -0.025742, 0.000312, -0.002480, 0.024913, 0.003318, -0.016577, -0.009663, -0.012305, -0.007548, -0.020042, -0.018475, 0.012711, 0.013015, 0.019632, -0.044327, 0.009160, 0.023335, -0.044934, 0.004110, -0.002380, 0.007062, -0.024249, -0.014356, -0.022298, 0.022666, -0.009599, 0.008619, 0.021663, 0.011470, 0.004400, -0.026681, -0.027311, 0.013865, 0.029518, 0.002130, -0.003086, -0.016096, -0.029406, -0.011760, -0.003154, -0.022900, -0.002128, -0.006287, -0.004423, -0.013240, 0.008787, 0.015061, 0.008619, -0.012334, -0.023233, -0.024260, -0.000038, -0.021759, -0.017202, 0.036565, 0.007832, 0.020661, -0.000709, -0.010937, -0.006529, -0.021067, -0.003493, -0.019981, -0.007152, -0.010431, 0.016528, 0.009478, -0.000387, 0.011030, -0.015702, 0.004910, 0.019820, -0.009501, 0.032242, 0.009791, -0.012693, -0.012758, 0.005951, 0.002285, -0.021048, -0.003414, 0.008490, -0.009785, 0.010615, 0.007026, -0.024024, 0.009654, -0.000423, -0.012367, -0.000683, -0.023359, 0.022525, -0.038280, 0.005439, -0.008667, 0.000479, 0.016903, 0.003972, -0.017659, 0.009390, 0.001721, 0.014011, -0.008501, -0.005130, 0.005864, -0.015588, -0.023901, -0.018212, -0.017600, -0.000808, 0.003798, -0.000374, 0.002171, -0.010798, -0.002167, -0.002688, 0.011768, 0.008376, 0.008252, -0.014139, 0.011628, -0.017354, 0.020029, -0.003110, -0.006605, 0.000088, -0.008673, 0.031119, 0.012816, -0.001554, -0.028450, -0.031961, 0.000749, 0.008979, -0.011676, 0.021030, 0.011047, 0.031832, 0.011692, 0.027716, -0.013094, -0.004017, 0.009027, -0.003426, 0.021272, -0.011953, -0.002096, 0.014148, 0.004425, 0.019694, -0.002320, 0.009787, -0.016366, -0.006968, 0.044094, -0.021135, 0.014664, 0.012024, -0.001312, -0.010924, -0.018518, 0.004067, 0.005731, -0.022986, 0.018130, 0.005514, -0.000535, 0.014868, 0.006152, -0.004057, 0.021108, 0.017294, 0.010202, -0.006595, 0.001040, -0.014503, 0.010996, -0.001199, -0.008933, -0.011412, -0.001148, -0.021028, -0.012644, -0.027791, 0.011597, -0.006358, -0.002310, -0.008983, -0.016520, -0.016006, 0.005623, -0.011567, -0.006791, 0.018220, 0.010030, -0.016808, 0.017599, -0.038462, 0.001185, 0.001272, 0.029970, -0.011697, 0.004016, -0.022518, -0.005410, -0.011339, 0.006955, -0.019577, -0.009260, -0.025437, -0.009223, 0.046833, -0.012395, 0.026334, -0.003020, -0.041108, -0.012053, 0.013952, 0.008421, 0.000389, -0.008318, -0.026117, -0.013364, 0.003995, 0.001253, -0.008553, -0.014819, 0.042401, 0.008028, 0.015358, -0.006760, -0.003456, -0.010274, -0.005063, 0.019056, -0.019282, 0.028529, -0.023047, 0.001330, -0.014517, 0.001787, 0.035838, -0.002197, 0.017273, 0.013223, 0.008261, 0.011069, 0.016115, 0.009038, 0.003824, 0.007600, -0.028759, 0.007420, -0.019551, 0.000208, -0.027524, 0.001970, 0.015536, -0.015439, -0.025726, -0.006733, 0.015017, -0.004176, -0.030268, -0.016025, -0.001417, 0.012878, -0.006500, -0.003119, -0.003401, -0.007219, -0.007370, 0.000644, 0.021524, 0.008409, 0.005200, 0.003197, 0.014018, 0.013874, 0.007392, 0.008548, -0.001694, -0.023479, 0.026971, -0.019892, -0.011915, -0.009339, 0.010411, 0.000962, 0.026920, 0.020772, -0.015630, 0.042286, 0.003036, 0.002418, -0.000559, 0.016372, -0.001013, 0.019164, 0.011091, 0.007332, -0.012521, -0.026355, -0.012465, 0.012425, -0.012866, -0.003681, 0.010110, -0.004440, -0.017833, -0.004337, 0.003432, 0.012279, -0.013508, 0.002860, 0.007560, -0.013746, -0.007328, 0.006398, 0.012368, -0.031189, 0.010435, -0.026745, -0.002065, -0.000018, 0.008437, -0.020951, -0.014613, 0.027587, 0.021053, 0.008047, -0.019996, -0.002226, -0.008008, 0.013445, -0.034107, -0.000744, -0.001821, 0.014077, 0.005022, -0.002037, -0.008170, 0.020361, -0.036807, -0.040290, -0.033997, 0.010617, 0.018125, 0.007784, 0.011251, -0.019881, -0.029746, -0.016549, 0.027708, 0.017331, 0.000739, 0.012864, 0.015012, -0.003049, 0.001765, -0.016737, -0.004086, -0.019370, 0.012912, -0.004322, 0.006763, 0.024780, -0.001206, 0.009158, 0.008418, -0.004266, 0.030131, -0.000400, -0.017726, -0.017540, 0.011032, -0.011073, -0.013453, -0.022519, -0.003229, 0.006512, -0.001383, 0.009883, 0.016059, -0.000605, -0.000309, -0.018137, 0.014978, 0.002921, -0.007613, 0.027025, -0.005092, 0.009398, 0.006602, 0.007236, 0.008204, 0.020230, 0.018121, 0.002815, 0.002668, -0.030765, -0.012687, -0.002267, 0.008824, -0.005739, 0.022980, -0.040300, -0.003685, -0.023128, -0.001817, -0.014380, -0.001766, -0.012538, 0.011029, -0.009562, -0.025560, 0.000254, 0.000899, 0.020819, 0.003545, 0.006402, -0.010389, -0.004296, -0.000360, -0.009568, 0.006516, 0.006023, 0.001291, -0.014954, 0.006374, -0.011088, -0.004016, 0.024020, 0.003261, 0.009086, -0.006118, 0.012321, 0.008060, -0.001949, 0.045665, 0.010600, -0.028078, 0.015062, 0.019561, -0.008793, 0.000384, 0.012627, -0.015218, 0.005784, -0.003044, -0.002830, -0.003675, 0.017365, -0.005626, 0.012490, 0.001138, -0.004063, -0.023211, -0.015569, -0.011751, 0.020837, -0.020561, -0.022642, -0.000206, 0.010976, 0.004722, 0.006458, -0.002802, -0.009693, -0.025396, 0.009108, 0.001791, 0.006541, 0.016408, 0.001736, -0.018632, 0.000523, -0.018195, -0.008380, 0.003091, 0.007251, -0.013442, 0.009905, -0.010768, -0.005161, 0.002064, 0.010608, 0.002720, -0.021422, -0.009019, 0.009357, -0.007045, 0.010005, 0.009786, -0.011280, 0.003003, 0.008567, -0.016222, -0.021154, 0.001371, 0.009106, 0.008682, 0.028164, 0.037620, -0.014166, 0.033103, -0.002531, -0.004949, -0.010924, -0.007954, -0.011785, -0.001748, -0.014597, 0.009884, 0.004108, 0.001241, -0.000416, 0.003360, -0.021418, -0.026198, 0.006894, 0.008989, -0.021985, 0.004533, 0.011405, -0.001827, 0.008044, 0.002529, -0.014493, 0.016014, -0.020658, 0.003807, 0.010540, -0.025505, 0.015002, 0.004699, 0.017521, 0.008660, 0.017759, -0.007729, 0.010906, -0.012483, 0.006340, -0.017246, -0.006083, 0.002357, 0.016951, -0.022541, 0.000364, -0.018440, 0.003730, -0.018185, -0.006742, 0.008023, 0.003459, -0.031610, 0.003049, -0.003019, -0.002934, 0.029219, -0.001473, -0.013225, 0.023437, 0.002153, 0.008362, 0.009142, -0.023763, -0.008043, 0.004517, 0.009636, 0.014824, -0.028260, 0.004312, 0.015419, -0.005401, -0.003108, -0.017145, 0.006375, 0.006473, 0.017673, 0.003004, -0.006814, -0.005512, -0.018296, -0.024305, 0.022902, -0.025757, -0.022487, -0.026135, 0.013664, 0.001370, 0.003182, -0.037260, 0.007060, 0.011588, 0.004182, 0.035425, 0.012125, -0.004238, 0.010359, 0.035212, 0.008152, 0.011075, 0.023878, 0.002958, 0.038817, 0.008300, 0.007776, 0.010572, -0.042451, -0.001251, 0.005246, 0.002344, -0.019186, -0.033779, -0.006243, 0.007207, 0.017790, -0.017984, -0.024683, -0.001003, 0.022494, 0.022498, -0.013629, -0.026255, -0.013596, 0.001076, 0.006961, 0.013133, -0.005664, -0.006499, -0.001609, 0.007189, -0.021156, -0.003479, -0.002400, -0.020974, 0.014524, 0.010587, -0.010552, -0.000728, 0.022545, 0.001695, 0.001498, -0.004404, -0.007288, 0.017903, 0.011703, 0.012844, 0.028733, -0.005856, -0.026446, 0.017745, 0.012850, 0.022067, 0.013617, -0.010212, -0.021234, -0.008570, -0.015652, -0.023508, 0.011418, -0.039396, 0.005391, 0.003879, 0.001210, -0.006911, 0.008865, -0.003326, -0.003076, 0.019264, 0.001549, 0.007484, -0.030370, 0.053156, 0.013863, -0.027415, -0.003470, -0.002664, 0.008749, 0.020691, 0.009630, 0.028416, -0.037981, -0.015957, -0.010788, -0.012660, -0.000779, -0.016764, 0.033517, -0.013758, -0.000528, -0.003093, -0.002753, -0.011892, -0.005444, -0.009057, 0.023202, -0.036589, 0.012229, -0.019088, 0.011596, 0.010203, -0.029219, 0.004284, 0.006076, -0.005539, 0.006054, 0.009512, 0.007094, -0.028645, -0.003598, 0.013799, -0.027507, -0.006348, 0.013886, 0.006111, -0.003856, -0.003430, -0.001100, 0.001812, 0.005712, 0.024730, -0.018796, 0.000108, -0.006207, 0.005937, 0.011734, -0.007228, -0.007973, -0.012129, 0.006572, 0.000141, 0.030832, 0.005892, 0.003501, 0.001516, 0.004694, -0.022240, 0.007386, -0.023270, 0.044361, -0.000140, 0.028047, -0.014853, -0.016221, 0.017074, -0.002851, 0.010071, -0.015005, 0.015156, 0.009846, 0.007697, 0.005352, -0.009038, 0.005556, -0.002746, 0.009233, 0.006823, -0.000160, -0.021344, -0.006151, -0.012515, 0.072906, -0.013540, -0.008361, 0.008153, -0.001799, 0.018483, 0.010785, -0.011283, -0.016609, 0.004088, 0.014252, -0.004421, -0.020900, 0.029211, 0.011621, 0.004254, -0.004932, 0.005741, 0.006653, 0.013325, -0.010694, 0.007876, -0.002466, -0.018666, -0.018410, -0.010627, -0.003349, 0.003484, -0.011489, 0.014391, 0.003229, 0.007021, 0.000133, -0.014888, -0.026584, 0.010275, 0.007855, 0.001890, 0.015709, 0.009294, 0.008799, 0.008655, -0.018378, 0.008336, 0.000331, -0.013533, 0.002439, 0.021340, -0.009806, 0.003492, -0.001372, -0.013885, -0.021650, 0.020662, -0.017006, 0.004307, 0.011045, 0.006932, 0.023574, -0.011301, -0.025923, -0.006913, 0.016671, -0.023506, -0.013017, -0.003302, -0.022934, 0.010941, 0.014406, 0.014412, -0.022246, -0.005683, 0.020179, -0.010220, 0.005060, 0.002080, 0.002767, 0.000137, -0.014047, 0.002494, -0.007142, 0.010849, -0.017285, -0.012656, -0.014468, -0.005269, 0.012875, -0.022823, 0.000353, -0.003313, -0.003364, -0.026604, 0.022048, -0.008279, -0.005318, 0.021979, 0.022471, -0.005116, -0.032529, -0.040768, -0.020011, 0.017497, -0.014278, 0.010649, 0.001603, -0.022499, 0.010507, 0.036945, -0.029510, -0.001214, -0.020313, -0.012374, 0.014613, 0.003072, 0.007042, 0.020545, -0.012042, -0.006691, 0.001797, 0.014907, -0.006485, 0.008387, -0.000148, 0.001464, 0.005346, -0.004612, -0.015713, 0.005616, -0.005177, 0.018050, -0.003931, -0.000001, -0.008391, 0.011832, -0.003223, -0.010326, -0.019975, -0.016370, 0.006447, -0.002780, -0.010225, -0.005227, -0.011900, -0.016903, 0.005408, -0.010780, -0.006199, -0.008412, -0.013894, -0.006245, -0.000687, 0.018553, 0.004978, -0.014254, -0.001509, -0.005372, 0.019807, -0.001753, 0.023208, -0.027532, 0.009226, 0.011976, -0.011693, 0.002449, -0.004840, -0.002368, -0.009941, 0.013470, 0.014675, -0.027788, -0.031636, 0.002538, 0.008076, -0.007696, -0.035465, 0.004288, 0.043763, -0.016293, -0.006316, -0.031554, 0.021771, -0.000773, 0.001582, 0.004359, 0.025712, 0.009557, -0.010360, 0.006726, -0.001341, 0.017298, -0.028810, 0.039897, 0.007975, -0.010776, -0.027927, 0.000299, -0.015461, 0.005385, -0.008718, -0.015047, 0.003341, -0.009360, 0.011141, 0.004886, 0.010366, -0.029717, -0.009179, -0.011421, 0.009537, -0.014428, -0.008828, -0.027039, -0.022218, 0.001996, -0.014867, 0.015115, 0.011442, 0.007213, -0.000849, -0.015230, 0.003767, 0.012754, 0.048488, 0.019210, 0.009058, -0.017050, 0.004827, 0.008308, 0.002984, 0.004916, 0.009659, -0.032912, -0.014253, -0.012794, -0.013872, 0.002399, 0.002207, -0.001631, -0.020052, -0.006069, -0.013495, 0.003094, 0.006192, -0.005703, 0.025562, -0.013196, -0.001365, 0.025535, -0.007324, -0.008711, -0.005778, 0.006962, -0.005504, -0.007406, -0.009825, -0.008078, 0.004535, -0.001390, -0.075147, -0.021516, 0.001119, -0.001852, 0.016273, 0.000722, -0.012491, -0.005250, 0.004081, 0.005299, 0.007435, -0.001721, -0.009242, -0.013120, -0.008459, 0.004338, -0.019973, -0.004505, 0.005025, -0.000954, -0.001090, 0.015541, -0.002063, -0.003308, 0.000588, 0.022199, -0.003141, -0.010038, -0.010964, 0.002221, 0.009131, 0.010295, 0.005334, -0.008926, 0.013424, 0.006229, -0.038063, -0.002829, 0.011684, 0.038563, 0.004651, -0.024207, -0.009563, -0.007640, -0.008146, 0.025759, 0.003030, 0.026461, 0.003975, 0.018165, 0.021594, 0.008707, 0.005547, -0.000296, 0.012815, -0.003540, -0.011585, 0.005085, -0.022372, 0.033902, -0.003263, -0.024705, -0.042109, 0.009599, 0.038620, -0.003434, 0.036212, 0.036306, -0.009963, 0.028066, 0.002020, 0.005640, -0.001960, -0.004135, -0.010166, -0.018146, 0.028150, -0.058238, 0.020195, -0.017224, -0.009405, 0.008595, -0.004518, 0.005570, 0.002576, -0.001744, -0.004112, -0.003760, 0.022951, 0.011364, -0.004922, 0.024788, 0.020602, -0.025195, 0.000685, 0.060091, 0.014272, 0.029017, 0.000804, -0.016958, 0.002612, 0.008865, -0.016550, -0.004889, -0.003536, -0.005749, -0.008546, -0.008727, -0.018071, -0.006674, 0.001521, -0.007114, 0.007189, 0.012673, 0.036561, -0.016118, 0.012443, -0.008857, -0.013068, 0.006335, -0.005951, 0.013068, -0.013533, -0.003326, 0.003575, 0.006155, -0.000159, -0.008904, 0.022211, 0.013479, 0.004807, -0.013215, 0.010392, 0.016700, 0.008846, 0.003426, -0.013167, 0.021007, -0.007006, -0.019989, -0.018935, -0.009326, -0.020080, -0.012506, -0.022386, 0.006021, -0.018697, -0.038166, -0.015170, 0.006244, -0.021284, -0.012843, 0.025328, -0.021694, -0.004130, -0.021319, 0.008720, 0.011652, -0.005452, -0.003179, 0.060741, -0.001221, 0.013110, 0.015061, 0.009986, 0.017344, 0.016174, 0.008966, 0.018068, 0.011810, 0.001908, 0.007490, -0.019028, -0.006078, -0.005390, 0.017064, -0.001505, -0.021410, 0.002280, -0.005085, -0.014639, 0.003743, -0.024898, -0.014706, 0.006912, -0.000702, 0.013239, -0.024904, 0.005201, -0.013320, 0.015443, 0.001929, -0.011712, -0.006642, -0.003515, 0.002906, 0.015133, 0.005487, -0.008241, -0.001367, 0.010693, 0.004662, -0.003180, -0.009749, 0.011601, -0.004775, 0.016301, 0.004968, 0.008186, -0.002459, -0.011336, -0.013957, 0.009293, 0.013549, -0.013098, 0.008125, 0.011986, -0.007959, 0.004094, 0.019528, -0.004365, 0.003876, -0.016059, -0.003758, 0.001789, -0.014303, -0.005796, 0.040344, -0.003670, -0.002143, 0.034586, 0.026734, -0.000698, -0.001604, 0.025056, -0.017334, 0.015160, -0.009957, 0.019010, 0.020996, 0.003243, 0.018213, 0.001208, -0.020305, 0.029187, -0.011641, -0.011657, -0.013246, -0.022376, 0.009607, 0.016580, 0.017885, 0.017881, 0.003957, -0.003618, 0.009981, 0.011575, -0.017757, 0.021916, 0.017002, 0.007537, 0.026295, 0.008167, 0.016813, -0.018921, -0.019569, 0.013323, -0.043220, -0.017883, 0.001302, -0.014914, 0.013275, 0.018250, 0.001672, 0.019226, 0.000573, 0.025074, -0.030046, 0.021336, -0.000482, 0.005240, 0.003336, -0.012777, -0.049509, 0.006522, -0.012608, -0.032835, 0.009286, 0.039234, 0.020901, -0.005669, -0.033840, -0.029497, 0.015623, 0.007640, 0.016760, 0.023920, 0.009403, 0.006438, 0.017191, 0.010603, 0.002468, 0.009694, 0.011633, 0.013090, -0.007263, -0.019116, -0.004664, 0.017793, -0.000072, 0.027779, -0.022312, 0.000518, 0.034992, -0.003028, 0.010675, 0.010700, 0.017552, 0.007521, 0.009735, -0.007287, -0.000703, 0.000518, -0.003049, -0.007723, 0.001756, 0.007471, -0.002170, 0.004427, 0.002457, -0.032252, 0.011692, -0.010095, 0.007934, -0.044980, 0.004782, 0.006804, -0.017519, -0.010744, 0.019077, 0.006702, -0.004973, -0.011182, 0.022724, -0.011923, -0.019493, -0.003578, 0.027181, -0.002335, -0.015360, 0.003101, 0.025972, 0.016819, -0.000467, -0.006585, 0.046542, 0.007835, 0.007061, -0.012299, -0.013939, -0.014781, 0.017238, 0.015341, 0.016411, -0.002476, -0.000096, 0.005822, -0.004916, 0.021976, 0.002665, -0.017337, -0.005119, -0.014974, -0.013682, -0.016552, 0.015179, 0.039968, -0.004849, -0.025470, -0.032710, -0.000492, 0.002529, -0.014715, -0.001878, 0.000461, -0.001201, -0.024754, -0.024242, 0.011683, 0.000981, -0.003004, 0.020895, -0.007614, -0.036281, -0.015148, -0.016843, -0.015740, 0.008739, 0.008076, -0.001046, 0.000070, -0.015041, -0.016683, 0.029050, -0.011142, 0.019542, 0.006395, 0.012477, 0.013558, 0.019579, 0.002287, 0.000364, -0.010508, 0.004982, 0.021573, 0.012588, 0.013008, -0.017638, -0.009140, 0.014363, -0.005592, 0.008980, -0.003731, 0.006250, -0.003962, 0.022276, 0.014456, 0.008127, -0.016733, -0.006529, 0.031901, -0.011931, 0.001025, 0.015881, -0.006654, -0.003508, 0.014181, 0.013726, 0.000726, -0.025121, -0.004030, 0.011454, 0.009010, 0.007005, -0.003615, 0.016721, -0.006014, -0.017870, 0.004426, 0.019319, 0.017792, 0.018068, 0.003201, -0.017899, -0.008119, -0.029222, -0.010760, -0.007811, 0.027540, 0.010730, 0.013286, -0.004819, -0.001144, -0.031621, 0.016147, 0.003982, -0.015991, -0.002804, 0.002829, 0.020540, -0.026591, 0.010923, 0.010854, -0.004079, 0.000688, -0.026974, -0.022422, 0.004963, 0.011102, 0.002007, 0.014523, -0.009108, 0.009736, 0.026097, -0.008907, -0.001358, -0.017569, 0.002859, 0.032701, -0.003909, -0.024595, -0.030636, 0.013448, 0.018823, 0.016989, 0.017027, -0.020685, 0.001997, 0.004204, 0.012288, 0.024839, -0.000365, 0.008064, -0.023655, -0.000214, -0.011491, -0.010244, 0.016034, 0.021091, 0.007094, -0.003997, -0.002698, 0.017125, -0.001379, 0.003748, -0.017694, 0.005844, -0.015642, 0.016894, -0.025795, -0.021323, -0.009326, 0.023170, -0.005597, -0.022418, 0.004685, -0.018928, -0.011539, -0.017003, -0.009692, 0.013745, -0.008849, -0.006986, 0.020762, 0.001902, -0.001151, 0.005636, 0.016969, -0.017418, -0.013400, -0.003218, 0.017248, -0.024777, 0.045254, -0.010008, -0.018173, 0.022667, 0.002803, 0.023494, 0.032903, -0.014638, 0.001442, 0.030762, 0.002861, 0.008452, -0.004196, -0.018943, 0.010758, 0.019940, -0.005218, -0.004911, 0.002166, 0.000062, 0.018939, -0.003565, -0.040918, -0.001216, 0.024237, -0.002687, -0.001567, -0.016337, 0.015927, 0.039783, 0.007293, -0.010945, -0.024960, -0.005193, 0.010885, 0.000692, 0.000611, -0.004277, 0.016390, 0.025758, -0.004503, 0.027956, -0.020453, -0.022293, 0.009417, 0.012242, -0.009043, 0.001688, -0.008467, 0.001545, 0.016667, 0.015859, 0.015847, -0.029128, -0.016145, -0.016548, -0.000915, -0.005255, 0.001502, 0.006229, -0.000733, -0.016100, -0.019398, 0.022031, 0.004469, 0.008908, -0.016122, 0.000040, -0.008888, 0.008074, -0.040070, -0.001359, 0.006614, 0.008660, 0.011839, -0.030364, 0.008786, -0.004480, -0.005094, 0.020516, -0.012271, 0.017133, -0.001555, 0.013039, -0.005642, 0.015864, -0.008735, 0.018597, -0.018773, 0.026437, -0.017914, 0.010521, -0.031799, 0.026542, -0.002553, -0.011440, 0.022807, -0.001484, -0.013086, -0.005393, -0.041449, 0.023232, -0.024994, -0.011003, 0.014226, -0.014660, -0.012297, 0.010081, 0.016016, 0.023430, 0.003944, -0.021434, 0.001499, 0.015885, -0.015178, 0.052111, 0.013777, 0.003943, -0.004159, -0.018207, 0.019766, -0.024061, -0.030762, -0.018855, 0.000095, -0.014928, -0.015209, 0.017462, 0.002385, 0.000187, -0.014586, -0.017039, 0.000806, 0.006072, 0.035368, -0.000529, 0.017466, 0.003279, 0.022002, 0.015390, -0.017172, -0.004862, -0.033992, -0.007625, -0.005381, 0.014700, 0.004541, 0.004763, -0.005538, -0.011130, -0.005827, 0.015927, -0.027840, 0.017271, 0.007639, 0.024618, 0.015270, 0.019440, 0.017037, 0.018614, 0.006260, 0.002318, 0.012834, -0.007415, -0.022029, -0.010439, 0.010957, 0.003316, 0.013317, -0.007644, -0.029696, -0.007906, 0.013180, -0.004121, -0.004793, -0.003164, 0.002117, 0.016471, -0.041044, 0.018250, -0.013665, -0.013043, -0.001555, -0.002679, -0.026402, -0.012257, 0.027655, -0.013110, -0.004093, -0.008459, -0.017506, 0.013686, -0.006936, 0.014952, -0.009130, -0.021595, 0.009934, -0.014897, 0.002735, 0.012240, 0.000107, 0.004421, -0.005080, -0.007473, -0.013056, 0.005994, 0.023927, 0.014086, -0.010669, 0.002883, 0.004913, 0.011917, 0.006067, 0.006099, 0.028509, 0.016327, 0.019547, 0.000761, -0.008872, -0.013328, 0.007887, 0.000593, 0.010895, -0.011474, -0.007090, 0.011083, -0.004068, -0.013910, 0.000002, -0.013572, 0.005778, -0.003331, -0.000280, -0.005848, -0.018626, -0.010224, -0.001178, 0.003822, 0.005855, 0.001250, 0.005114, -0.020260, 0.034792, 0.018608, -0.003275, 0.000991, 0.005417, -0.007322, -0.012350, -0.021752, 0.009537, 0.008009, -0.009680, -0.000582, -0.016834, -0.007484, 0.001159, -0.022297, 0.003660, -0.010565, -0.019750, -0.005773, 0.015054, 0.016563, 0.014081, 0.009023, 0.036565, -0.030304, 0.027378, -0.016617, 0.000151, -0.010887, -0.016542, 0.004438, 0.002592, 0.000920, 0.000102, -0.010652, -0.023235, 0.101960, -0.017048, 0.001956, 0.010342, -0.008566, -0.005541, -0.017047, -0.012462, 0.010076, 0.001959, 0.006444, -0.001760, 0.001517, -0.008421, -0.014456, -0.021142, -0.005687, 0.007755, -0.016494, 0.003861, -0.002703, -0.003307, -0.009360, 0.002867, 0.000226, -0.020640, 0.004909, -0.018447, 0.017833, 0.022051, 0.014006, -0.017507, -0.005500, -0.006043, -0.007814, 0.018392, -0.006371, 0.018850, 0.029652, -0.003573, -0.008146, 0.018313, -0.015838, -0.032720, -0.042324, -0.000093, -0.005138, 0.005588, 0.016665, -0.009604, -0.001978, -0.029234, -0.025235, 0.010030, -0.001410, 0.019863, -0.004580, -0.009004, -0.016924, 0.003690, -0.010201, 0.016367, 0.008306, -0.001806, 0.038056, 0.017252, 0.009558, -0.013220, 0.003652, 0.016436, 0.006446, -0.004599, 0.008749, -0.020319, 0.000831, -0.005372, 0.016846, -0.009377, -0.009748, -0.026560, 0.011980, 0.014937, 0.006341, 0.000422, 0.002159, -0.021079, 0.001828, 0.002897, 0.015790, 0.007269, -0.002133, 0.020799, 0.004535, -0.009252, 0.014515, -0.018034, 0.005088, 0.014639, -0.000818, -0.005400, -0.012085, 0.018262, 0.004450, 0.015766, 0.005318, 0.025644, -0.049883, 0.004744, 0.005378, 0.009072, 0.014824, 0.023132, 0.002685, -0.001183, -0.002213, 0.015892, 0.005347, -0.022873, 0.034731, -0.006599, -0.016648, 0.028667, 0.004957, -0.010771, 0.004812, -0.003598, -0.015015, -0.010878, 0.011263, -0.024440, -0.003584, 0.001943, -0.013649, -0.005871, -0.004335, -0.024247, 0.018355, 0.009756, 0.022101, 0.012232, 0.000029, 0.009751, -0.009421, -0.010585, 0.018912, 0.003387, 0.011882, -0.008308, -0.016522, -0.009758, -0.001156, 0.015289, 0.019122, 0.000015, 0.004118, 0.039255, -0.003367, -0.002975, -0.006581, -0.003712, 0.034320, -0.022950, -0.021703, 0.021714, 0.003876, -0.001524, 0.006148, 0.015376, -0.003583, 0.013684, 0.008504, 0.002071, -0.006866, -0.016622, 0.028972, -0.002585, -0.012830, 0.007892, 0.000639, -0.018131, -0.018077, -0.003100, -0.005005, -0.013567, 0.003568, 0.002382, -0.019491, 0.021040, 0.014864, 0.032373, -0.002519, -0.007588, 0.005639, 0.016072, -0.001837, 0.005916, 0.021606, -0.004785, 0.016915, -0.008056, -0.014667, 0.007789, -0.005898, -0.003012, -0.000263, -0.011757, 0.004057, -0.013413, -0.011619, 0.016374, -0.014115, -0.001854, 0.014490, 0.005928, 0.005582, 0.005524, 0.019696, 0.007976, -0.002337, 0.017389, 0.027090, -0.001294, -0.026454, -0.012785, -0.000151, 0.005695, 0.018820, -0.005554, -0.010554, -0.037088, -0.015285, 0.013529, -0.002270, 0.002447, -0.013967, -0.002778, 0.022457, 0.006619, -0.010586, -0.014883, -0.017480, -0.000678, 0.010898, -0.018060, 0.005616, 0.000099, 0.012023, -0.003565, -0.002615, -0.012217, -0.030788, 0.008546, 0.007993, 0.003866, -0.012082, -0.016117, 0.015401, -0.008101, 0.003709, 0.000091, 0.010832, 0.018474, 0.017259, -0.005851, -0.031973, 0.015791, 0.012643, 0.003244, 0.014998, 0.019063, -0.001472, -0.025990, 0.015169, -0.009884, 0.002544, 0.020427, -0.000470, 0.005721, 0.008325, 0.007087, -0.006803, 0.020466, -0.017335, 0.003829, -0.003448, 0.007477, -0.021420, -0.014349, 0.018811, -0.019868, -0.032201, -0.000424, 0.002255, 0.016701, -0.019483, 0.000868, -0.000312, -0.011390, 0.014417, -0.005372, -0.018477, -0.013866, -0.001135, -0.013151, -0.016340, 0.022038, 0.028332, 0.018423, -0.000885, -0.016016, -0.000506, 0.007382, 0.002883, -0.060843, -0.005289, -0.008497, 0.013998, 0.028891, 0.003624, 0.000382, 0.005699, -0.017407, 0.011960, -0.007124, -0.022642, 0.009878, 0.010962, -0.000292, -0.018771, 0.005196, 0.003887, 0.007990, 0.003359, -0.004517, 0.015622, -0.001508, -0.017210, -0.013518, -0.018791, 0.000493, 0.012015, -0.001230, -0.005306, -0.006177, -0.006319, 0.012276, 0.002216, -0.010670, -0.010702, -0.024221, -0.013020, -0.010832, -0.004789, -0.020057, 0.009258, -0.000225, -0.031841, 0.000593, -0.015819, -0.016449, -0.010948, -0.008769, -0.011786, 0.001202, 0.007093, 0.016759, -0.034051, -0.001936, -0.005886, 0.006068, 0.029942, -0.008858, -0.005155, 0.030633, 0.012225, 0.018284, 0.009850, -0.013926, 0.010475, -0.009784, 0.024091, -0.009334, 0.006966, -0.009209, -0.020398, -0.009779, -0.042508, -0.000022, -0.006571, 0.022690, -0.014969, -0.016340, -0.012293, -0.013041, 0.023558, -0.093774, -0.012834, 0.001748, -0.015414, -0.003389, 0.013077, -0.008845, 0.000491, 0.004804, 0.008608, -0.017550, 0.005078, -0.009128, 0.019334, 0.021869, -0.002114, 0.007375, -0.004183, -0.021008, -0.008093, -0.031298, 0.007937, -0.006032, -0.004382, -0.024452, 0.006999, -0.023552, 0.018541, -0.007993, 0.012050, -0.010784, -0.006336, 0.014794, -0.009498, -0.003932, -0.009129, -0.015602, -0.003046, 0.016863, -0.000131, -0.016725, -0.001143, 0.000428, 0.005294, -0.013933, -0.006820, -0.010661, 0.000711, -0.008837, -0.007209, -0.017701, 0.004061, 0.014703, 0.001840, 0.009885, -0.021338, 0.015789, -0.017123, -0.003605, 0.004870, -0.006961, -0.014616, 0.001046, -0.018480, -0.021327, 0.003602, -0.012612, -0.013908, -0.004274, -0.010954, 0.008409, -0.026547, 0.009493, 0.025824, -0.017831, -0.014135, -0.021283, 0.006764, 0.004242, -0.010836, -0.003256, -0.014130, -0.004335, -0.012133, 0.001019, 0.010675, 0.004686, 0.012308, -0.013302, -0.022552, -0.004927, -0.025277, -0.005335, 0.011381, 0.025372, 0.005224, 0.035451, 0.001242, -0.007717, -0.022296, -0.001692, 0.008255, 0.037308, 0.033247, 0.004620, -0.009462, 0.009099, -0.001126, 0.007144, 0.016147, -0.002579, -0.009192, 0.021881, -0.018137, -0.016120, 0.023954, 0.026693, 0.008363, -0.002968, -0.007836, -0.002710, 0.017337, -0.005227, 0.000648, -0.001666, 0.003518, -0.003102, 0.012371, -0.003384, 0.021508, 0.006837, 0.006620, -0.020704, -0.015379, 0.006124, -0.000815, -0.002005, 0.020740, -0.000341, 0.003189, -0.005670, 0.001827, -0.006109, 0.015929, -0.003354, 0.002359, -0.001033, 0.027072, -0.005968, 0.032495, 0.010132, -0.004043, 0.021537, 0.001836, 0.004311, 0.027786, -0.008344, -0.013441, -0.001841, 0.000048, 0.010439, 0.004804, -0.008400, -0.004262, -0.010300, -0.003906, -0.003394, -0.003375, -0.000944, -0.019556, 0.004310, 0.039962, 0.027445, -0.016895, -0.021711, -0.010279, 0.005174, -0.011083, -0.010566, 0.011690, -0.001378, 0.004322, -0.006402, 0.009215, 0.002703, -0.002887, 0.009664, 0.008476, 0.024891, -0.004103, 0.007138, -0.026648, 0.001226, 0.003641, -0.014135, -0.017579, 0.006830, -0.005211, -0.010327, 0.011596, -0.004247, 0.004704, 0.026018, -0.013983, -0.026755, 0.000465, -0.009860, 0.021312, 0.029785, -0.004714, 0.010407, -0.005684, 0.011187, 0.012351, 0.000256, 0.016947, 0.005220, -0.015514, -0.006897, 0.004795, 0.017589, -0.020394, 0.005371, -0.001247, -0.003036, -0.015652, -0.004967, -0.017463, 0.009540, -0.013814, -0.023903, -0.024999, 0.022597, 0.017431, 0.011907, 0.000741, 0.000395, 0.014146, -0.025263, 0.004919, -0.013391, -0.038441, -0.020687, 0.013218, -0.014276, 0.016215, -0.003140, 0.017917, -0.015071, 0.006642, -0.020299, 0.018196, 0.000165, -0.017600, -0.022149, -0.007831, -0.016498, -0.005095, -0.011056, 0.002061, 0.007088, 0.001588, 0.010239, -0.005201, -0.008742, -0.009925, 0.001828, -0.022987, 0.011902, 0.004328, 0.009978, -0.002052, 0.021294, 0.020016, 0.024059, -0.006338, 0.005658, -0.003601, 0.005725, -0.012885, -0.003567, -0.008227, 0.021375, 0.001915, 0.001100, 0.009236, 0.014495, 0.009794, 0.032031, -0.022061, 0.002740, -0.035699, 0.017734, -0.007188, -0.007257, 0.009609, -0.002559, 0.008155, -0.006893, -0.001755, 0.004039, -0.007587, 0.001309, -0.014647, 0.012599, -0.023093, 0.009309, -0.002869, -0.014727, -0.012037, -0.003156, -0.004096, 0.016965, 0.019384, 0.000002, 0.004630, 0.013427, -0.025145, -0.009610, -0.010228, -0.010881, -0.018150, 0.001892, 0.020061, -0.003203, -0.011193, 0.008656, -0.021301, 0.027418, -0.044232, 0.015832, 0.004668, 0.012090, 0.027355, 0.001649, -0.019923, 0.022830, -0.010257, 0.006284, 0.016087, -0.012424, -0.005967, 0.017051, -0.007909, 0.036533, 0.004785, 0.013536, 0.005001, -0.020018, -0.006978, 0.025895, 0.009551, 0.013015, 0.043835, -0.000942, 0.021205, -0.010978, -0.002646, 0.013192, -0.012961, -0.004818, -0.003524, -0.002655, -0.011494, -0.021520, 0.005068, -0.005938, -0.005708, -0.044458, 0.012405, 0.013281, 0.035163, -0.008479, 0.005200, 0.006298, -0.013604, 0.001231, 0.035308, -0.010445, -0.021337, -0.002967, -0.015408, -0.002867, -0.020344, 0.009895, 0.009303, -0.001393, -0.011840, 0.008937, -0.003710, 0.017104, -0.022924, 0.004528, -0.009281, 0.013223, 0.044217, -0.005570, -0.003403, -0.004183, -0.013926, -0.013133, 0.000399, -0.019108, 0.002664, -0.000750, -0.016516, -0.028828, 0.040569, 0.028743, -0.002339, -0.001530, 0.005358, 0.000831, 0.009249, -0.002734, 0.005316, -0.011899, 0.002943, -0.030287, 0.013423, 0.011075, -0.007595, 0.019828, 0.002101, -0.008739, -0.023729, -0.015603, 0.013206, 0.006018, 0.018934, 0.010430, 0.003697, 0.000756, -0.005511, -0.003178, -0.013694, 0.012796, 0.012422, -0.008806, -0.002501, 0.012712, 0.010711, -0.029260, 0.003970, -0.015845, 0.007692, 0.002803, -0.019664, -0.036282, -0.009719, -0.008491, 0.015094, -0.011584, 0.030892, 0.013343, -0.004266, 0.018714, -0.013045, 0.016724, -0.017056, -0.020940, 0.013137, -0.006146, -0.003954, -0.005386, -0.005451, -0.024084, -0.020846, -0.023871, 0.022407, 0.016726, 0.010994, -0.019705, -0.012944, -0.011736, -0.016800, 0.016154, 0.037629, -0.003424, -0.014073, 0.012585, -0.009890, -0.012940, -0.011151, -0.003971, -0.057809, -0.003960, 0.005835, 0.019932, -0.002561, -0.014869, -0.005486, -0.022046, 0.023892, -0.004431, 0.020073, 0.014859, 0.031115, -0.005770, 0.025947, -0.015123, 0.004376, 0.000216, 0.014397, -0.010126, 0.019828, -0.023535, -0.006688, -0.004238, 0.003923, -0.001913, 0.013896, -0.020261, -0.000211, -0.017331, -0.004555, 0.017812, 0.019288, -0.000217, 0.029406, 0.004155, 0.021335, -0.004290, -0.000817, -0.009499, -0.004489, 0.016444, -0.003451, 0.017868, 0.010835, -0.003867, 0.009663, -0.006793, -0.008544, 0.004528, -0.004728, -0.013121, 0.001463, -0.005069, -0.009838, -0.009709, -0.011030, 0.020544, 0.016891, -0.008413, 0.009241, -0.020013, 0.004533, -0.000644, -0.003671, 0.002443, 0.005531, 0.012014, -0.015547, 0.032867, -0.010345, -0.002676, 0.015909, -0.014906, -0.000076, 0.001226, 0.004505, -0.001700, 0.015689, 0.008147, -0.008414, 0.018174, -0.023157, -0.002639, 0.001052, -0.011672, 0.034863, 0.007373, 0.003746, 0.020635, -0.001293, -0.007285, -0.005697, 0.011585, 0.011641, 0.005903, -0.015321, -0.019069, 0.001474, -0.002128, -0.001428, -0.013760, 0.006829, 0.008988, -0.000711, 0.011416, 0.013476, 0.005795, 0.010170, 0.019851, -0.006154, -0.025921, 0.028450, 0.015028, -0.007077, -0.005037, 0.014004, 0.001087, 0.002007, 0.034379, 0.022603, 0.011928, 0.010591, 0.013465, -0.007109, 0.008027, 0.009830, 0.015632, -0.008051, 0.000997, 0.005566, -0.012794, -0.004673, 0.020111, 0.010783, 0.014006, -0.014628, -0.005732, 0.002201, -0.014574, 0.002813, -0.008896, 0.005612, -0.010093, 0.009163, 0.032559, -0.008701, 0.020292, -0.003993, -0.005449, 0.047966, -0.018202, -0.009047, 0.002837, 0.000065, 0.012356, -0.001373, -0.011784, 0.003937, 0.019357, 0.017805, 0.017216, 0.013104, 0.006715, -0.001720, 0.000405, 0.000307, 0.000554, 0.023089, -0.008537, 0.003480, 0.011162, 0.028809, -0.028040, 0.015474, -0.020833, -0.027994, -0.006712, 0.013394, 0.030260, 0.019467, -0.004163, 0.022857, 0.005267, -0.002363, 0.003092, 0.002819, 0.002405, 0.015851, 0.002322, -0.015501, 0.002296, -0.001550, -0.010819, -0.003535, -0.031813, -0.009316, 0.039037, 0.013429, -0.004979, 0.007779, -0.008625, 0.022815, -0.003904, -0.045269, -0.002160, 0.024076, 0.025454, 0.021919, 0.001096, 0.009633, 0.003316, -0.002369, 0.010195, -0.008795, -0.012336, -0.002107, -0.002632, 0.024025, -0.012306, -0.002433, -0.011226, 0.013108, 0.013018, -0.012350, -0.008192, 0.008925, -0.012370, -0.006352, 0.008407, 0.000015, -0.018896, -0.002139, 0.022095, -0.021460, -0.028576, 0.000539, -0.014611, 0.008801, 0.005960, -0.003063, -0.014593, -0.002300, 0.014949, 0.008666, 0.001127, 0.001032, 0.008120, 0.008157, 0.012913, -0.013515, -0.016257, -0.007906, -0.001303, 0.000938, -0.002996, 0.018525, 0.001288, 0.002526, -0.009045, -0.000344, 0.043224, 0.008138, -0.026972, -0.036458, -0.005729, -0.014122, -0.022735, -0.026821, 0.032595, 0.033623, -0.018001, 0.025555, 0.032314, 0.026786, 0.007475, -0.029497, -0.009322, -0.007920, 0.000674, -0.025910, -0.006128, 0.005208, 0.004453, -0.001005, 0.002987, 0.018911, 0.013004, 0.003351, -0.031948, -0.011428, -0.005920, 0.011624, -0.007659, -0.012838, 0.017949, 0.019967, 0.007770, 0.028680, 0.026473, 0.003621, 0.004719, -0.001039, -0.008586, -0.007310, 0.016093, 0.000288, -0.009725, 0.003192, -0.001937, 0.004935, 0.016688, 0.012862, 0.006541, 0.010781, 0.004346, -0.006409, -0.017303, -0.027135, -0.014983, -0.020721, -0.001405, 0.007737, 0.017811, 0.025829, -0.018499, 0.033446, -0.008173, 0.025652, -0.002779, -0.036197, 0.035352, -0.023758, -0.006748, -0.018221, 0.025066, 0.007794, 0.019503, 0.000028, -0.006134, -0.015939, 0.001823, -0.003823, 0.015975, 0.008037, -0.001529, 0.013992, 0.006494, 0.006369, -0.025116, -0.011833, -0.022861, 0.027657, -0.011983, 0.043682, 0.015773, 0.004072, 0.015403, -0.000675, -0.016528, 0.000549, -0.005763, -0.014040, 0.009783, 0.017142, -0.005261, 0.011948, -0.015775, 0.016749, 0.003973, -0.011806, -0.008391, 0.002960, 0.024964, -0.013062, -0.002914, 0.010767, -0.004168, -0.016841, -0.017812, -0.003733, -0.014694, -0.009252, 0.001167, -0.008773, -0.002801, -0.012193, -0.020946, -0.007018, 0.005875, 0.007153, 0.019214, -0.024508, 0.011197, -0.022353, -0.015757, -0.001978, -0.011591, 0.004546, 0.025104, -0.065326, -0.000773, -0.001352, 0.007841, -0.024962, 0.000702, -0.026619, -0.004280, -0.001047, 0.008003, -0.010382, -0.002991, 0.028899, 0.019583, 0.018180, 0.002436, 0.024709, -0.008453, 0.000247, 0.010228, -0.029503, 0.025050, 0.000883, -0.003178, -0.029948, -0.000435, -0.000472, 0.018712, -0.022477, 0.026339, -0.014619, 0.034320, -0.015215, -0.020927, -0.007180, -0.019584, -0.009599, 0.014168, -0.021145, -0.009943, 0.003515, 0.004487, 0.024646, 0.015464, 0.012207, 0.004644, -0.001743, 0.017812, -0.016363, 0.015604, 0.009810, 0.013646, -0.017065, -0.022975, 0.005958, 0.022131, 0.005337, 0.027475, -0.001793, -0.007428, 0.025136, 0.009438, -0.013006, 0.027094, 0.009010, -0.015209, -0.005203, -0.016260, 0.015376, -0.005864, -0.009721, -0.007073, -0.010791, -0.009775, -0.004087, 0.005616, 0.026532, -0.004916, -0.006672, -0.003260, -0.029856, 0.005316, 0.002688, 0.001484, -0.008383, -0.000859, 0.013535, -0.010640, 0.019049, -0.018435, -0.009252, -0.040983, 0.017240, -0.002047, -0.000440, -0.013463, 0.002644, -0.011559, -0.010040, -0.002651, 0.019927, -0.013669, 0.022540, -0.012926, 0.001236, 0.014573, 0.002343, -0.014226, 0.027290, -0.002921, 0.017484, -0.001449, 0.011943, 0.024709, 0.009541, -0.019108, 0.014974, 0.003024, 0.013132, -0.025489, 0.000146, -0.019084, -0.021294, 0.001326, 0.002435, -0.011663, -0.017603, 0.003963, -0.036307, 0.002392, 0.002865, -0.005224, 0.008638, 0.017922, 0.022342, 0.008152, 0.009601, -0.006566, 0.019362, 0.016708, 0.009775, 0.010324, -0.015168, -0.013288, 0.005146, 0.003812, 0.018278, 0.003083, 0.005744, 0.023287, 0.031175, -0.010414, 0.015014, -0.006324, -0.002789, 0.014685, -0.017423, 0.007766, -0.001774, -0.003090, 0.017457, -0.005279, -0.017417, 0.002423, 0.005983, -0.000100, -0.003857, -0.011357, 0.006177, -0.013336, 0.015153, -0.005920, 0.008527, 0.012864, -0.016132, 0.030787, 0.000947, 0.005183, -0.010514, -0.002621, 0.000255, -0.003783, -0.006133, 0.013980, -0.001777, 0.056108, 0.003873, 0.015884, 0.018199, -0.007962, -0.008492, 0.001115, 0.027627, -0.024726, -0.001219, 0.007042, -0.014110, 0.008502, 0.003677, -0.003958, -0.019703, 0.006495, -0.028127, 0.022678, 0.012794, -0.023379, -0.007287, 0.048726, 0.003756, -0.007684, 0.006204, -0.006883, -0.020040, 0.011968, 0.004371, 0.006712, 0.001578, 0.010776, 0.006446, 0.004054, -0.008785, -0.004622, 0.015532, 0.024050, -0.004696, 0.005004, 0.047457, 0.016234, -0.013847, -0.003231, 0.001491, -0.006422, -0.007664, 0.000159, 0.022566, 0.015440, 0.014559, 0.017279, -0.001534, 0.010613, 0.008598, 0.012202, 0.012841, -0.008078, -0.022893, -0.011717, 0.000308, 0.012272, 0.010973, -0.015097, 0.011852, -0.002605, -0.002470, -0.019322, 0.028891, 0.028117, -0.012933, -0.011691, -0.001118, 0.005864, 0.014090, -0.001960, -0.011202, 0.028238, 0.003765, 0.012505, 0.016722, -0.008177, -0.013111, -0.011508, -0.009641, 0.029831, -0.004879, -0.005077, -0.032950, -0.023307, -0.007336, 0.000371, 0.003904, -0.007409, -0.006993, 0.004508, -0.004245, 0.018904, -0.006684, 0.015313, 0.002921, -0.011040, -0.002182, 0.012927, 0.015248, 0.024934, -0.006216, 0.022404, 0.020427, -0.002124, -0.014849, -0.008371, -0.020697, 0.006873, -0.006633, -0.053067, 0.020209, -0.007383, -0.000910, -0.002166, -0.002127, 0.015922, -0.005251, -0.010564, -0.017466, 0.009684, 0.005759, 0.019960, 0.033270, 0.012287, -0.000100, -0.006858, 0.006734, 0.025588, 0.005950, -0.006609, 0.013516, 0.005205, 0.002285, -0.000606, 0.037977, -0.003159, -0.001248, -0.028815, 0.029236, -0.010141, -0.004224, -0.004079, 0.017874, -0.015790, 0.016889, 0.012398, 0.009094, 0.015814, 0.014839, -0.005277, -0.002516, -0.013254, 0.004981, -0.009693, -0.001848, 0.004601, 0.014224, 0.031232, -0.024458, 0.002672, -0.009351, 0.007968, 0.015294, 0.001275, 0.016084, 0.023957, -0.002103, 0.003466, 0.002882, -0.004706, -0.013566, 0.005811, 0.008504, -0.001594, 0.003923, 0.009754, 0.024187, -0.004900, 0.042902, 0.009232, -0.007876, -0.028139, 0.007731, -0.014280, -0.012051, 0.011560, 0.005460, -0.007022, 0.016041, 0.001226, 0.001036, -0.018123, 0.005008, -0.017481, -0.011131, -0.008724, 0.001174, 0.021142, 0.005457, 0.008019, 0.000416, 0.044647, -0.014745, 0.016356, -0.003489, -0.009914, -0.027851, 0.015946, 0.015048, 0.014239, -0.018476, 0.018326, 0.010209, 0.003926, -0.021824, 0.015518, 0.006414, 0.015442, -0.013744, 0.003908, -0.007205, 0.020117, 0.003079, 0.018629, -0.004714, -0.011078, 0.009998, -0.014231, 0.000272, 0.009207, -0.006673, -0.007812, -0.005450, 0.007986, -0.012454, 0.017887, -0.007598, 0.006841, 0.006608, -0.001630, -0.032937, 0.007242, -0.012077, 0.001551, -0.001544, 0.008404, -0.017311, 0.011169, -0.011868, -0.005969, 0.020850, 0.001749, -0.015342, 0.018045, -0.032569, 0.013476, 0.003299, 0.018967, -0.014875, 0.002329, -0.022181, -0.024022, -0.005146, 0.001831, 0.040513, -0.011674, -0.013500, -0.004823, 0.004701, 0.008984, 0.008307, -0.020678, -0.000369, 0.005629, 0.006012, 0.019657, 0.002078, 0.024057, -0.006275, -0.006897, -0.011263, -0.016342, 0.010246, 0.009262, -0.003426, -0.030797, -0.008025, -0.018327, -0.010448, 0.003397, 0.011466, 0.007359, 0.012101, 0.010924, -0.011682, 0.004035, -0.013045, 0.008647, 0.006065, -0.001575, -0.012142, -0.002077, -0.011405, -0.005507, 0.022302, -0.005649, -0.020798, -0.000002, 0.003886, 0.016590, -0.012743, -0.005405, 0.025257, -0.010380, 0.009266, 0.034503, 0.005400, -0.000897, -0.015368, 0.014782, -0.004936, -0.009334, -0.005339, 0.026014, -0.017054, 0.001600, 0.001218, -0.007068, -0.001957, -0.025544, -0.018504, -0.009509, -0.001614, 0.012602, 0.010126, 0.006851, 0.003219, -0.019446, -0.015196, -0.008934, -0.014628, -0.007057, -0.011605, 0.000150, -0.005216, 0.005098, 0.015251, -0.020344, 0.009604, 0.000319, -0.003701, -0.005448, 0.006905, -0.008316, 0.017366, -0.008151, -0.008407, -0.003523, -0.017969, -0.016836, -0.004794, -0.010536, 0.000728, 0.015541, 0.007297, 0.020033, 0.014789, 0.008347, -0.014012, -0.009191, -0.011634, 0.001047, -0.004078, -0.032395, -0.007056, -0.006072, 0.017237, -0.010169, -0.002082, -0.012120, -0.016639, 0.003092, 0.015185, 0.024824, 0.016023, -0.012708, -0.021427, 0.006886, -0.013677, -0.015014, -0.044346, -0.008770, -0.006189, -0.003424, 0.015988, -0.001982, -0.012779, -0.008387, 0.013574, -0.021576, 0.024866, -0.001767, 0.016229, 0.019212, 0.014702, 0.012229, -0.010593, 0.008917, 0.005296, -0.010568, 0.001254, -0.007352, 0.004297, 0.014843, -0.000107, 0.008463, -0.001418, -0.004644, -0.012905, -0.009090, 0.026492, 0.010213, 0.004233, 0.013997, 0.027515, -0.006066, -0.008489, -0.009575, 0.005357, -0.002819, 0.012562, -0.001755, -0.021962, 0.028935, 0.019488, -0.000949, 0.046604, -0.005296, -0.017023, 0.013773, 0.014090, 0.001501, -0.019164, -0.018293, -0.015383, -0.003715, -0.017315, -0.001150, 0.010238, 0.013849, 0.007795, -0.009511, -0.005155, -0.002302, -0.030114, -0.000130, -0.026576, -0.011283, -0.014823, 0.010253, -0.008892, 0.003799, 0.022423, 0.008893, 0.008317, 0.005401, -0.020923, -0.007466, 0.012005, 0.017001, -0.002136, -0.015182, -0.005745, 0.013719, 0.008295, 0.001176, 0.009355, 0.005882, -0.005260, -0.002908, 0.010056, -0.009662, 0.012986, 0.011305, 0.003803, -0.003433, 0.004537, 0.000067, 0.010801, 0.007846, 0.000285, 0.002437, 0.013732, -0.000136, 0.016116, -0.009388, -0.028867, 0.015769, -0.009709, 0.007497, 0.004109, 0.005993, -0.016153, 0.019901, -0.000855, -0.008279, 0.019159, -0.000092, -0.009310, -0.002351, -0.018273, 0.016507, -0.015380, 0.012349, 0.010352, -0.015944, -0.019335, -0.022712, 0.001215, 0.018114, 0.007969, 0.018223, 0.003234, 0.034172, 0.011018, -0.008155, -0.003510, -0.055344, -0.008135, 0.002642, -0.016795, 0.008877, -0.032482, 0.020419, -0.006696, -0.019295, 0.012625, -0.010185, 0.013550, 0.006444, 0.010389, -0.024405, 0.004697, -0.006485, -0.006578, 0.012772, -0.003774, 0.010779, -0.016861, 0.006165, 0.000787, 0.009266, -0.022250, 0.022082, 0.017113, 0.009719, 0.017075, 0.022148, -0.003797, 0.012725, 0.010799, -0.013215, 0.005114, -0.001655, 0.011833, 0.005005, -0.011579, 0.017267, 0.016029, -0.001224, -0.022074, -0.016803, -0.000424, -0.001202, -0.000810, 0.008266, -0.001548, -0.008027, -0.009839, -0.010535, 0.006481, 0.007259, 0.036436, -0.001130, -0.040727, 0.013779, 0.001592, -0.025086, 0.014909, 0.015090, 0.000607, -0.011528, -0.000361, -0.011167, 0.017288, 0.011351, 0.007367, 0.004935, 0.006236, 0.007734, 0.000988, 0.013061, -0.001021, -0.010461, -0.010896, -0.003681, 0.000912, 0.008134, 0.006001, -0.001354, -0.008708, 0.000808, 0.009372, 0.001674, 0.013925, -0.007881, -0.019357, 0.010177, 0.010353, 0.010437, 0.002275, 0.013243, 0.007057, -0.025082, -0.016305, -0.011859, 0.009043, -0.025206, -0.004534, 0.004452, -0.006219, -0.004143, 0.003171, 0.007724, 0.008987, 0.000169, 0.013302, 0.006299, 0.007297, 0.016530, -0.023153, 0.010610, -0.003937, 0.014075, 0.007115, -0.022596, 0.033165, -0.018001, 0.004484, -0.021658, -0.009335, 0.018308, 0.008660, -0.018094, -0.008676, 0.012363, 0.012372, -0.027894, 0.005208, -0.009708, -0.012411, 0.016168, -0.000753, 0.008323, -0.001104, -0.021953, 0.019505, 0.003769, -0.006093, -0.002687, -0.005854, -0.000456, 0.011746, 0.006733, -0.008888, 0.005915, 0.059429, -0.016033, 0.005489, 0.006918, 0.004521, -0.000074, -0.001815, -0.001053, -0.001441, 0.016997, -0.006283, -0.002506, -0.006677, -0.057522, -0.005242, 0.009781, 0.008025, 0.022570, -0.002895, 0.006240, -0.011600, 0.010416, 0.037326, 0.019760, -0.006887, -0.015520, -0.003295, 0.004990, -0.000308, 0.036522, 0.004534, 0.006574, 0.001989, 0.007110, 0.002790, -0.005056, -0.035095, 0.001534, -0.007598, 0.004227, 0.000755, 0.026881, 0.020648, 0.011147, -0.044771, -0.004757, -0.009377, 0.005339, -0.005997, -0.003715, 0.029131, 0.002203, 0.007175, -0.007573, -0.003617, -0.017187, -0.009481, 0.014506, 0.000882] } diff --git a/integration/utils_test.go b/integration/utils_test.go index f8ec13f3..c438aa93 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -15,6 +15,7 @@ import ( "net/http" "net/url" "os" + "os/exec" "path/filepath" "runtime" "strconv" @@ -24,7 +25,6 @@ import ( "time" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/app/lifecycle" "github.com/ollama/ollama/format" ) @@ -38,6 +38,7 @@ var ( // Note: add newer models at the top of the list to test them first ollamaEngineChatModels = []string{ + "qwen3-coder:30b", "gpt-oss:20b", "gemma3n:e2b", "mistral-small3.2:latest", @@ -46,6 +47,7 @@ var ( "qwen2.5-coder:latest", "qwen2.5vl:3b", "qwen3:0.6b", // dense + "qwen3:1.7b", // dense "qwen3:30b", // MOE "gemma3:1b", "llama3.1:latest", @@ -264,18 +266,17 @@ var ( rainbowFollowups = []string{ "Explain the physics involved in them. Be breif in your reply", "Explain the chemistry involved in them. Be breif in your reply", - "Explain the quantum mechanics involved in them. Be breif in your reply", "What are common myths related to them? Be brief in your reply", - "What are common fairytales related to them? Be brief in your reply", "Can they form if there is no rain? Be breif in your reply", "Can they form if there are no clouds? Be breif in your reply", "Do they happen on other planets? Be brief in your reply", } - rainbowExpected = []string{"water", "droplet", "mist", "glow", "refracted", "reflect", "color", "spectrum", "frequency", "end", "gold", "fortune", "blessing", "prosperity"} + rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "particles", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "shower", "sky", "shimmer", "light", "storm", "sunny", "sunburst", "phenomenon", "mars", "venus", "jupiter"} ) func init() { - lifecycle.InitLogging() + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + slog.SetDefault(logger) custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL") if custom != "" { slog.Info("setting default test model to " + custom) @@ -336,6 +337,7 @@ func GetTestEndpoint() (*api.Client, string) { var serverMutex sync.Mutex var serverReady bool +var serverLogFile string func startServer(t *testing.T, ctx context.Context, ollamaHost string) error { // Make sure the server has been built @@ -362,8 +364,9 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error { t.Setenv("OLLAMA_HOST", ollamaHost) } + logDir := t.TempDir() slog.Info("starting server", "url", ollamaHost) - done, err := lifecycle.SpawnServer(ctx, "../ollama") + done, err := SpawnServer(ctx, "../ollama", logDir) if err != nil { return fmt.Errorf("failed to start server: %w", err) } @@ -386,6 +389,36 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error { return nil } +func SpawnServer(ctx context.Context, command, logDir string) (chan int, error) { + done := make(chan int) + fp, err := os.CreateTemp(logDir, "ollama-server-*.log") + if err != nil { + return nil, fmt.Errorf("failed to create log file: %w", err) + } + serverLogFile = fp.Name() + + cmd := exec.CommandContext(ctx, command, "serve") + cmd.Stderr = fp + cmd.Stdout = fp + + go func() { + slog.Info("starting server...") + if err := cmd.Run(); err != nil { + // "signal: killed" expected + if !strings.Contains(err.Error(), "signal") { + slog.Info("failed to run server", "error", err) + } + } + var code int + if cmd.ProcessState != nil { + code = cmd.ProcessState.ExitCode() + } + slog.Info("server exited") + done <- code + }() + return done, nil +} + func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error { slog.Info("checking status of model", "model", modelName) showReq := &api.ShowRequest{Name: modelName} @@ -446,51 +479,59 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin client, testEndpoint := GetTestEndpoint() if os.Getenv("OLLAMA_TEST_EXISTING") == "" { serverProcMutex.Lock() - fp, err := os.CreateTemp("", "ollama-server-*.log") - if err != nil { - t.Fatalf("failed to generate log file: %s", err) - } - lifecycle.ServerLogFile = fp.Name() - fp.Close() if err := startServer(t, ctx, testEndpoint); err != nil { t.Fatal(err) } } + // Make sure server is online and healthy before returning + listCtx, cancel := context.WithDeadlineCause( + ctx, + time.Now().Add(120*time.Second), + fmt.Errorf("list models took too long"), + ) + defer cancel() + models, err := client.ListRunning(listCtx) + if err != nil { + t.Fatal(err) + } + if len(models.Models) > 0 { + names := make([]string, len(models.Models)) + for i, m := range models.Models { + names[i] = m.Name + } + slog.Info("currently loaded", "models", names) + } return client, testEndpoint, func() { if os.Getenv("OLLAMA_TEST_EXISTING") == "" { defer serverProcMutex.Unlock() if t.Failed() { - fp, err := os.Open(lifecycle.ServerLogFile) + fp, err := os.Open(serverLogFile) if err != nil { - slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err) + slog.Error("failed to open server log", "logfile", serverLogFile, "error", err) return } defer fp.Close() data, err := io.ReadAll(fp) if err != nil { - slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err) + slog.Error("failed to read server log", "logfile", serverLogFile, "error", err) return } slog.Warn("SERVER LOG FOLLOWS") os.Stderr.Write(data) slog.Warn("END OF SERVER") } - err := os.Remove(lifecycle.ServerLogFile) - if err != nil && !os.IsNotExist(err) { - slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err) - } } } } -func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) { +func ChatTestHelper(ctx context.Context, t *testing.T, req api.ChatRequest, anyResp []string) { client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() - if err := PullIfMissing(ctx, client, genReq.Model); err != nil { + if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatal(err) } - DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second) + DoChat(ctx, t, client, req, anyResp, 30*time.Second, 10*time.Second) } func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) []int { @@ -577,7 +618,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { KeepAlive: &api.Duration{Duration: 10 * time.Second}, }, { Model: smol, - Prompt: "how do rainbows form? Be brief but factual in your reply", + Prompt: rainbowPrompt, Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, }, { @@ -595,7 +636,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { [][]string{ {"sunlight", "scatter", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorb", "wavelength", "water", "molecule"}, {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigment", "particle", "iron oxide", "rust", "air", "water", "wet", "mixture", "mixing", "mineral", "element", "decomposed", "matter", "wavelength"}, - {"water", "droplet", "refract", "reflect", "color", "spectrum", "raindrop"}, + rainbowExpected, {"fourth", "july", "declaration", "independence"}, {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor", "fluid", "particles", "gas"}, } @@ -702,6 +743,13 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) { // Skip if the target model isn't X% GPU loaded to avoid excessive runtime func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) { + gpuPercent := getGPUPercent(ctx, t, client, model) + if gpuPercent < minPercent { + t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent)) + } +} + +func getGPUPercent(ctx context.Context, t *testing.T, client *api.Client, model string) int { models, err := client.ListRunning(ctx) if err != nil { t.Fatalf("failed to list running models: %s", err) @@ -709,8 +757,14 @@ func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, m loaded := []string{} for _, m := range models.Models { loaded = append(loaded, m.Name) - if m.Name != model { - continue + if strings.Contains(model, ":") { + if m.Name != model { + continue + } + } else if strings.Contains(m.Name, ":") { + if !strings.HasPrefix(m.Name, model+":") { + continue + } } gpuPercent := 0 switch { @@ -725,12 +779,10 @@ func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, m cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110) gpuPercent = int(100 - cpuPercent) } - if gpuPercent < minPercent { - t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent)) - } - return + return gpuPercent } - t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded)) + t.Fatalf("model %s not loaded - actually loaded: %v", model, loaded) + return 0 } func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) { diff --git a/kvcache/causal.go b/kvcache/causal.go index 31f55233..543a65a6 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -160,7 +160,15 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity if c.swaMemorySize == 0 { c.swaMemorySize = c.swaWindowSize } - if int(c.swaMemorySize) > capacity { + // We will allocate space in the cache for the stop token, which won't be part of a follow on + // sequence, so allocate an extra token of storage to ensure that we can jump back without + // causing a cache break. As an optimization, only do this when we have parallel sequences + // because the extra token will live in the batch buffer and won't get overwritten if we + // only have a single sequence. + if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 { + c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1) + } + if int(c.swaMemorySize) >= capacity { c.swaMemorySize = math.MaxInt32 } @@ -214,7 +222,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e c.curLoc, err = c.findStartLoc() } if err != nil { - slog.Warn("unable to find a kv cache slot", "cache", c) return err } @@ -288,23 +295,44 @@ func (c *Causal) updateSlidingWindow() { return } + type lowestPosition struct { + pos int32 + curBatch bool + } + // create a map of unique sequences to the lowest position in that sequence - lowestPos := make(map[int]int32) + lowestPos := make(map[int]lowestPosition) for i := range c.curPositions { seq := c.curSequences[i] - pos, ok := lowestPos[seq] + lowest, ok := lowestPos[seq] if !ok { - pos = c.curPositions[i] - } else if c.curPositions[i] < pos { - pos = c.curPositions[i] + lowest = lowestPosition{pos: c.curPositions[i], curBatch: true} + } else if c.curPositions[i] < lowest.pos { + lowest.pos = c.curPositions[i] } - lowestPos[seq] = pos + lowestPos[seq] = lowest + } + + // for any sequences are not part of this batch, clean up any tokens + // that are no longer needed after the processing of the previous + // batch + for seq, seqRange := range c.cellRanges { + if _, ok := lowestPos[seq]; !ok { + var last int32 + for i := seqRange.min; i <= seqRange.max; i++ { + if slices.Contains(c.cells[i].sequences, seq) { + last = max(last, c.cells[i].pos) + } + } + + lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false} + } } // delete any entries that are beyond the window of the oldest position in the sequence - for seq, pos := range lowestPos { + for seq, lowest := range lowestPos { oldRange, ok := c.cellRanges[seq] if !ok { continue @@ -314,13 +342,13 @@ func (c *Causal) updateSlidingWindow() { for i := oldRange.min; i <= oldRange.max; i++ { if slices.Contains(c.cells[i].sequences, seq) { - if c.cells[i].pos < pos-c.swaMemorySize { + if c.cells[i].pos < lowest.pos-c.swaMemorySize { c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) } else { newRange.min = min(newRange.min, i) newRange.max = max(newRange.max, i) } - if c.cells[i].pos >= pos-c.swaWindowSize { + if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize { c.curCellRange.min = min(c.curCellRange.min, i) c.curCellRange.max = max(c.curCellRange.max, i) } @@ -657,9 +685,11 @@ func (c *Causal) CanResume(seq int, pos int32) bool { // for sliding window, check that the window of the new sequence is contained in // the window of what we are storing + var first int32 = math.MaxInt32 var last int32 = -1 for i := seqRange.min; i <= seqRange.max; i++ { if slices.Contains(c.cells[i].sequences, seq) { + first = min(first, c.cells[i].pos) last = max(last, c.cells[i].pos) } } @@ -668,10 +698,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool { return false } - lastWindowStart := max(0, last-c.swaMemorySize) posWindowStart := max(0, pos-c.swaWindowSize) - - return posWindowStart >= lastWindowStart + return posWindowStart >= first && pos <= last+1 } func (c *Causal) shift(seq int, beginIndex, offset int32) error { diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 0d8cea79..7e4fc3b1 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -96,6 +96,86 @@ func TestSWA(t *testing.T) { testCache(t, backend, cache, tests) } +func TestSWASeparateBatches(t *testing.T) { + backend := &testBackend{} + cache := NewSWACache(1, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 2, 16, 2) + + x := float32(math.Inf(-1)) + + tests := []testCase{ + { + name: "First seq 0", + in: []float32{1, 2}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{0, 1}, + expected: []float32{1, 2}, + expectedShape: []int{1, 1, 2}, + expectedMask: []float32{ + 0, x, + 0, 0, + }, + }, + { + name: "Second seq 0", + in: []float32{3, 4}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{2, 3}, + expected: []float32{2, 3, 4}, + expectedShape: []int{1, 1, 3}, + expectedMask: []float32{ + 0, 0, x, + x, 0, 0, + }, + }, + { + name: "First seq 1", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{1, 1}, + pos: []int32{0, 1}, + expected: []float32{5, 6}, + expectedShape: []int{1, 1, 2}, + expectedMask: []float32{ + 0, x, + 0, 0, + }, + }, + { + name: "Second seq 1", + in: []float32{7, 8}, + inShape: []int{1, 1, 2}, + seqs: []int{1, 1}, + pos: []int32{2, 3}, + expected: []float32{6, 3, 4, 7, 8}, + expectedShape: []int{1, 1, 5}, + expectedMask: []float32{ + 0, x, x, 0, x, + x, x, x, 0, 0, + }, + }, + { + name: "Third seq 0", + in: []float32{9, 10}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{4, 5}, + expected: []float32{9, 10, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{ + 0, x, x, 0, + 0, 0, x, x, + }, + }, + } + + testCache(t, backend, cache, tests) +} + func TestSWAMem(t *testing.T) { backend := &testBackend{} cache := NewSWAMemCache(1, 3, nil) @@ -431,15 +511,15 @@ func TestCanResume(t *testing.T) { defer context.Close() err := cache.StartForward(context, input.Batch{ - Positions: []int32{0, 1, 2, 3}, - Sequences: []int{0, 0, 0, 0}, + Positions: []int32{0, 1, 2, 3, 4}, + Sequences: []int{0, 0, 0, 0, 0}, }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } cache.SetLayer(0) - tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) + tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5}, 1, 1, 5) cache.Put(context, tensor, tensor) // with window size 4, nothing has slid out of the window yet @@ -455,18 +535,21 @@ func TestCanResume(t *testing.T) { if !cache.CanResume(0, 3) { t.Errorf("CanResume(0, 3) = false, want true (latest position)") } + if !cache.CanResume(0, 4) { + t.Errorf("CanResume(0, 4) = false, want true (latest position)") + } - // shift window by adding position 4 + // shift window by adding position 5 err = cache.StartForward(context, input.Batch{ - Positions: []int32{4, 5}, - Sequences: []int{0, 0}, + Positions: []int32{5}, + Sequences: []int{0}, }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } cache.SetLayer(0) - tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) + tensor = context.FromFloatSlice([]float32{6}, 1, 1, 1) cache.Put(context, tensor, tensor) // only the latest position has overlapping windows @@ -503,28 +586,28 @@ func TestCanResumeSWAMem(t *testing.T) { defer context.Close() err := cache.StartForward(context, input.Batch{ - Positions: []int32{0, 1, 2, 3, 4, 5}, - Sequences: []int{0, 0, 0, 0, 0, 0}, + Positions: []int32{0, 1, 2, 3, 4, 5, 6}, + Sequences: []int{0, 0, 0, 0, 0, 0, 0}, }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } cache.SetLayer(0) - tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6) + tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7) cache.Put(context, tensor, tensor) - // shift window by adding position 6 + // shift window by adding position 7 err = cache.StartForward(context, input.Batch{ - Positions: []int32{6, 7}, - Sequences: []int{0, 0}, + Positions: []int32{7}, + Sequences: []int{0}, }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } cache.SetLayer(0) - tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2) + tensor = context.FromFloatSlice([]float32{8}, 1, 1, 1) cache.Put(context, tensor, tensor) // only the latest position has overlapping windows diff --git a/llama/build-info.cpp b/llama/build-info.cpp index 18c44961..ea711c87 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,4 @@ int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "e54d41befcc1575f4c898c5ff4ef43970cead75f"; +char const *LLAMA_COMMIT = "7049736b2dd9011bf819e298b844ebbc4b5afdc9"; char const *LLAMA_COMPILER = ""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/llama.cpp/common/common.cpp b/llama/llama.cpp/common/common.cpp index c6962d1d..b0591e84 100644 --- a/llama/llama.cpp/common/common.cpp +++ b/llama/llama.cpp/common/common.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +42,7 @@ #endif #include #include +#include #include #include #else @@ -49,6 +51,11 @@ #include #endif +#if defined(__linux__) +#include +#include +#endif + #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif @@ -557,13 +564,6 @@ std::string string_from(const struct llama_context * ctx, const std::vectorpw_dir)) { + throw std::runtime_error("Failed to find $HOME directory"); + } + + cache_directory = std::string(pw->pw_dir) + std::string("/.cache/"); +#else /* defined(__linux__) */ + throw std::runtime_error("Failed to find $HOME directory"); +#endif /* defined(__linux__) */ } #elif defined(__APPLE__) cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); @@ -914,7 +919,8 @@ struct common_init_result common_init_from_params(common_params & params) { llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); if (model == NULL) { - LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); + LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", + __func__, params.model.path.c_str()); return iparams; } @@ -924,7 +930,8 @@ struct common_init_result common_init_from_params(common_params & params) { llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { - LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); + LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", + __func__, params.model.path.c_str()); llama_model_free(model); return iparams; } @@ -971,15 +978,13 @@ struct common_init_result common_init_from_params(common_params & params) { bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL; + bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL; - if (!has_eos && !has_sep) { - LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__); + if (!has_eos && !has_sep && !has_rerank_prompt) { + LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__); ok = false; } else if (!has_eos) { LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__); - } else if (!has_sep) { - LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); - ok = false; } if (!ok) { @@ -1001,7 +1006,12 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + char buf[1024]; la.ptr = lora.get(); + llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf)); + la.task_name = buf; + llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); + la.prompt_prefix = buf; iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters } @@ -1123,6 +1133,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; mparams.use_extra_bufts = !params.no_extra_bufts; + mparams.no_host = params.no_host; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; @@ -1165,11 +1176,10 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.pooling_type = params.pooling_type; cparams.attention_type = params.attention_type; - cparams.defrag_thold = params.defrag_thold; + cparams.flash_attn_type = params.flash_attn_type; cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; - cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; cparams.op_offload = !params.no_op_offload; cparams.swa_full = params.swa_full; @@ -1565,3 +1575,56 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std return result; } + +ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) { + ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr); + const lr_opt & d = *(lr_opt *) userdata; + result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch); + result.sgd.wd = result.adamw.wd = d.wd; + return result; +} + +// TODO make all command line args case-insensitive +static inline bool eq_case_insensitive(char const* a, char const* b) { + return ! +#if defined(_MSC_VER) + _stricmp +#else + strcasecmp +#endif // defined(_MSC_VER) + (a, b); +} + +enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) { + if (eq_case_insensitive("adamw", n)) { + return GGML_OPT_OPTIMIZER_TYPE_ADAMW; + } + if (eq_case_insensitive("sgd", n)) { + return GGML_OPT_OPTIMIZER_TYPE_SGD; + } + return GGML_OPT_OPTIMIZER_TYPE_COUNT; +} + +// TODO simplify to use just log and exp +static float const k_log_2 = std::log(2.f); + +void lr_opt::init() { + if (lr_min > 0 && lr_min < lr0) { + float nhalf = std::log(lr0 / lr_min) / k_log_2; + float e = epochs; + if (decay_epochs > 0 && decay_epochs < e) { + e = decay_epochs; + } else { + decay_epochs = e; + } + scale_epoch = nhalf / e; + } +} + +float lr_opt::get_lr(float epoch) const { + float r = lr_min <= 0 ? lr0 : + epoch >= decay_epochs ? lr_min : + lr0 * std::pow(0.5f, epoch * scale_epoch); + LOG_INF("epoch %.2g lr=%.2g\n", epoch, r); + return r; +} diff --git a/llama/llama.cpp/common/common.h b/llama/llama.cpp/common/common.h index 5eab199a..a8cb630e 100644 --- a/llama/llama.cpp/common/common.h +++ b/llama/llama.cpp/common/common.h @@ -2,14 +2,17 @@ #pragma once -#include "llama-cpp.h" - #include +#include #include #include #include #include #include +#include + +#include "ggml-opt.h" +#include "llama-cpp.h" #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -31,6 +34,9 @@ struct common_adapter_lora_info { std::string path; float scale; + std::string task_name; + std::string prompt_prefix; + struct llama_adapter_lora * ptr; }; @@ -82,6 +88,7 @@ enum llama_example { LLAMA_EXAMPLE_PARALLEL, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_DIFFUSION, + LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_COUNT, }; @@ -186,10 +193,11 @@ struct common_params_sampling { }; struct common_params_model { - std::string path = ""; // model local path // NOLINT - std::string url = ""; // model url to download // NOLINT - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT + std::string path = ""; // model local path // NOLINT + std::string url = ""; // model url to download // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT + std::string docker_repo = ""; // Docker repo // NOLINT }; struct common_params_speculative { @@ -202,6 +210,7 @@ struct common_params_speculative { float p_split = 0.1f; // speculative decoding split probability float p_min = 0.75f; // minimum speculative decoding probability (greedy) std::vector> replacements; // main to speculative model replacements + std::vector tensor_buft_overrides; ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V @@ -234,14 +243,36 @@ struct common_params_diffusion { bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0 }; +// reasoning API response format (not to be confused as chat template's reasoning format) enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, - COMMON_REASONING_FORMAT_AUTO, + COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content` COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in tags in stream mode COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. - COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. + // do not extend this enum unless you absolutely have to + // in most cases, use COMMON_REASONING_FORMAT_AUTO + // see: https://github.com/ggml-org/llama.cpp/pull/15408 }; + +struct lr_opt { + float lr0 = 1e-5; // learning rate at first epoch + float lr_min = -1; + float decay_epochs = -1; // if >0, the learning rate starts at lr0 and decays to lr_min after this many epochs + float scale_epoch = 0; + float wd = 0; + unsigned epochs = 2; + + unsigned epoch; // set by optimizer outer (epochs) loop + // learning rate decay - constant LR per epoch only for now + float get_lr(float e) const; + float get_lr() const { return get_lr(epoch); } + // must call after arg parse, before get_lr + void init(); +}; + +struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata); + struct common_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size @@ -257,11 +288,10 @@ struct common_params { float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor - float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor - float yarn_beta_fast = 32.0f; // YaRN low correction dim - float yarn_beta_slow = 1.0f; // YaRN high correction dim + float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor + float yarn_beta_fast = -1.0f; // YaRN low correction dim + float yarn_beta_slow = -1.0f; // YaRN high correction dim int32_t yarn_orig_ctx = 0; // YaRN original context length - float defrag_thold = 0.1f; // KV cache defragmentation threshold // offload params std::vector devices; // devices to use for offloading @@ -283,6 +313,7 @@ struct common_params { enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings + enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention struct common_params_sampling sampling; struct common_params_speculative speculative; @@ -346,9 +377,8 @@ struct common_params { bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly - bool flash_attn = false; // flash attention bool no_perf = false; // disable performance metrics - bool ctx_shift = true; // context shift on inifinite text generation + bool ctx_shift = false; // context shift on infinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool kv_unified = false; // enable unified KV cache @@ -362,6 +392,7 @@ struct common_params { bool check_tensors = false; // validate tensor data bool no_op_offload = false; // globally disable offload host tensor operations to device bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking) + bool no_host = false; // bypass host buffer allowing extra buffers to be used bool single_turn = false; // single turn chat conversation @@ -376,6 +407,11 @@ struct common_params { bool no_mmproj = false; // explicitly disable multimodal model std::vector image; // path to image file(s) + // finetune + struct lr_opt lr; + enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW; + float val_split = 0.05f; // fraction of the data used for the validation set + // embedding bool embedding = false; // get only sentence embedding int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) @@ -384,11 +420,13 @@ struct common_params { std::string cls_sep = "\t"; // separator of classification sequences // server params - int32_t port = 8080; // server listens on this network port - int32_t timeout_read = 600; // http read timeout in seconds - int32_t timeout_write = timeout_read; // http write timeout in seconds - int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) - int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting + int32_t port = 8080; // server listens on this network port + int32_t timeout_read = 600; // http read timeout in seconds + int32_t timeout_write = timeout_read; // http write timeout in seconds + int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) + int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting + int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot + int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT @@ -396,7 +434,7 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int reasoning_budget = -1; bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response @@ -409,7 +447,7 @@ struct common_params { // "advanced" endpoints are disabled by default for better security bool webui = true; - bool endpoint_slots = false; + bool endpoint_slots = true; bool endpoint_props = false; // only control POST requests, not GET bool endpoint_metrics = false; @@ -417,7 +455,7 @@ struct common_params { std::string slot_save_path; - float slot_prompt_similarity = 0.5f; + float slot_prompt_similarity = 0.1f; // batched-bench params bool is_pp_shared = false; @@ -698,8 +736,25 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; } +// +// MoE utils +// + +const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps"; + +static std::string llm_ffn_exps_block_regex(int idx) { + return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX); +} + +static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() { + return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() }; +} + // // training utils // ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector & tokens, int64_t stride); + +// "adamw" or "sgd" (case insensitive) +enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *); diff --git a/llama/llama.cpp/common/json-schema-to-grammar.cpp b/llama/llama.cpp/common/json-schema-to-grammar.cpp index 98b8280f..f4de7e34 100644 --- a/llama/llama.cpp/common/json-schema-to-grammar.cpp +++ b/llama/llama.cpp/common/json-schema-to-grammar.cpp @@ -257,12 +257,13 @@ std::unordered_map STRING_FORMAT_RULES = { }; static bool is_reserved_name(const std::string & name) { - static std::unordered_set RESERVED_NAMES; - if (RESERVED_NAMES.empty()) { - RESERVED_NAMES.insert("root"); - for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first); - for (const auto &p : STRING_FORMAT_RULES) RESERVED_NAMES.insert(p.first); - } + static const std::unordered_set RESERVED_NAMES = [] { + std::unordered_set s; + s.insert("root"); + for (const auto & p : PRIMITIVE_RULES) s.insert(p.first); + for (const auto & p : STRING_FORMAT_RULES) s.insert(p.first); + return s; + }(); return RESERVED_NAMES.find(name) != RESERVED_NAMES.end(); } @@ -843,9 +844,10 @@ public: _build_object_rule( properties, required, name, schema.contains("additionalProperties") ? schema["additionalProperties"] : json())); - } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) { + } else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) { std::unordered_set required; std::vector> properties; + std::map enum_values; std::string hybrid_name = name; std::function add_component = [&](const json & comp_schema, bool is_required) { if (comp_schema.contains("$ref")) { @@ -857,6 +859,14 @@ public: required.insert(prop.key()); } } + } else if (comp_schema.contains("enum")) { + for (const auto & v : comp_schema["enum"]) { + const auto rule = _generate_constant_rule(v); + if (enum_values.find(rule) == enum_values.end()) { + enum_values[rule] = 0; + } + enum_values[rule] += 1; + } } else { // todo warning } @@ -870,6 +880,17 @@ public: add_component(t, true); } } + if (!enum_values.empty()) { + std::vector enum_intersection; + for (const auto & p : enum_values) { + if (p.second == schema["allOf"].size()) { + enum_intersection.push_back(p.first); + } + } + if (!enum_intersection.empty()) { + return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space"); + } + } return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { json items = schema.contains("items") ? schema["items"] : schema["prefixItems"]; diff --git a/llama/llama.cpp/common/log.cpp b/llama/llama.cpp/common/log.cpp index 52b31470..4ccdbd17 100644 --- a/llama/llama.cpp/common/log.cpp +++ b/llama/llama.cpp/common/log.cpp @@ -4,17 +4,52 @@ #include #include #include +#include +#include #include #include #include #include +#if defined(_WIN32) +# include +# include +# define isatty _isatty +# define fileno _fileno +#else +# include +#endif // defined(_WIN32) + int common_log_verbosity_thold = LOG_DEFAULT_LLAMA; void common_log_set_verbosity_thold(int verbosity) { common_log_verbosity_thold = verbosity; } +// Auto-detect if colors should be enabled based on terminal and environment +static bool common_log_should_use_colors_auto() { + // Check NO_COLOR environment variable (https://no-color.org/) + if (const char * no_color = std::getenv("NO_COLOR")) { + if (no_color[0] != '\0') { + return false; + } + } + + // Check TERM environment variable + if (const char * term = std::getenv("TERM")) { + if (std::strcmp(term, "dumb") == 0) { + return false; + } + } + + // Check if stdout and stderr are connected to a terminal + // We check both because log messages can go to either + bool stdout_is_tty = isatty(fileno(stdout)); + bool stderr_is_tty = isatty(fileno(stderr)); + + return stdout_is_tty || stderr_is_tty; +} + static int64_t t_us() { return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); } @@ -353,6 +388,11 @@ struct common_log * common_log_init() { struct common_log * common_log_main() { static struct common_log log; + static std::once_flag init_flag; + std::call_once(init_flag, [&]() { + // Set default to auto-detect colors + log.set_colors(common_log_should_use_colors_auto()); + }); return &log; } @@ -380,8 +420,19 @@ void common_log_set_file(struct common_log * log, const char * file) { log->set_file(file); } -void common_log_set_colors(struct common_log * log, bool colors) { - log->set_colors(colors); +void common_log_set_colors(struct common_log * log, log_colors colors) { + if (colors == LOG_COLORS_AUTO) { + log->set_colors(common_log_should_use_colors_auto()); + return; + } + + if (colors == LOG_COLORS_DISABLED) { + log->set_colors(false); + return; + } + + GGML_ASSERT(colors == LOG_COLORS_ENABLED); + log->set_colors(true); } void common_log_set_prefix(struct common_log * log, bool prefix) { diff --git a/llama/llama.cpp/common/log.h b/llama/llama.cpp/common/log.h index c56bb50d..f329b434 100644 --- a/llama/llama.cpp/common/log.h +++ b/llama/llama.cpp/common/log.h @@ -24,6 +24,12 @@ #define LOG_DEFAULT_DEBUG 1 #define LOG_DEFAULT_LLAMA 0 +enum log_colors { + LOG_COLORS_AUTO = -1, + LOG_COLORS_DISABLED = 0, + LOG_COLORS_ENABLED = 1, +}; + // needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower // set via common_log_set_verbosity() extern int common_log_verbosity_thold; @@ -65,10 +71,10 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch // D - debug (stderr, V = LOG_DEFAULT_DEBUG) // -void common_log_set_file (struct common_log * log, const char * file); // not thread-safe -void common_log_set_colors (struct common_log * log, bool colors); // not thread-safe -void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log -void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix +void common_log_set_file (struct common_log * log, const char * file); // not thread-safe +void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe +void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log +void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix // helper macros for logging // use these to avoid computing log arguments if the verbosity of the log is higher than the threshold diff --git a/llama/llama.cpp/common/sampling.cpp b/llama/llama.cpp/common/sampling.cpp index 9c04d35f..c69d525b 100644 --- a/llama/llama.cpp/common/sampling.cpp +++ b/llama/llama.cpp/common/sampling.cpp @@ -332,6 +332,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } if (ctx) { llama_perf_context_print(ctx); + llama_memory_breakdown_print(ctx); } } @@ -426,8 +427,29 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { // helpers -llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) { - return &gsmpl->cur_p; +llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) { + auto * res = &gsmpl->cur_p; + + if (do_sort && !res->sorted) { + // remember the selected token before sorting + const llama_token id = res->data[res->selected].id; + + std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.p > b.p; + }); + + // restore the selected token after sorting + for (size_t i = 0; i < res->size; ++i) { + if (res->data[i].id == id) { + res->selected = i; + break; + } + } + + res->sorted = true; + } + + return res; } llama_token common_sampler_last(const struct common_sampler * gsmpl) { diff --git a/llama/llama.cpp/common/sampling.h b/llama/llama.cpp/common/sampling.h index 2064421d..e198eecd 100644 --- a/llama/llama.cpp/common/sampling.h +++ b/llama/llama.cpp/common/sampling.h @@ -86,7 +86,9 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); // helpers // access the internal list of current candidate tokens -llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl); +// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability) +// the .sorted flag of the result indicates whether the returned candidates are sorted +llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort); // get the last accepted token llama_token common_sampler_last(const struct common_sampler * gsmpl); diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h index 545e957e..a0a660bf 100644 --- a/llama/llama.cpp/include/llama.h +++ b/llama/llama.cpp/include/llama.h @@ -64,8 +64,6 @@ extern "C" { typedef struct llama_memory_i * llama_memory_t; - struct llama_kv_cache; // DEPRECATED (use llama_memory instead) - typedef int32_t llama_pos; typedef int32_t llama_token; typedef int32_t llama_seq_id; @@ -181,6 +179,14 @@ extern "C" { LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1, }; + enum llama_flash_attn_type { + LLAMA_FLASH_ATTN_TYPE_AUTO = -1, + LLAMA_FLASH_ATTN_TYPE_DISABLED = 0, + LLAMA_FLASH_ATTN_TYPE_ENABLED = 1, + }; + + LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); + enum llama_split_mode { LLAMA_SPLIT_MODE_NONE = 0, // single GPU LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs @@ -200,7 +206,7 @@ extern "C" { llama_token_data * data; size_t size; int64_t selected; // this is the index in the data array (i.e. not the token id) - bool sorted; + bool sorted; // note: do not assume the data is sorted - always check this flag } llama_token_data_array; typedef bool (*llama_progress_callback)(float progress, void * user_data); @@ -290,6 +296,7 @@ extern "C" { bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) + bool no_host; // bypass host buffer allowing extra buffers to be used }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations @@ -305,6 +312,7 @@ extern "C" { enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings + enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention // ref: https://github.com/ggml-org/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model @@ -314,7 +322,7 @@ extern "C" { float yarn_beta_fast; // YaRN low correction dim float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size - float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default) + float defrag_thold; // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default) ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; @@ -331,7 +339,6 @@ extern "C" { // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU - bool flash_attn; // use flash attention [EXPERIMENTAL] bool no_perf; // measure performance timings bool op_offload; // offload host tensor operations to device bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) @@ -469,8 +476,6 @@ extern "C" { LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type - DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead"); - LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); @@ -539,6 +544,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid (like Jamba, Granite, etc.) + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); @@ -557,10 +565,32 @@ extern "C" { struct llama_model * model, const char * path_lora); + // Functions to access the adapter's GGUF metadata scalar values + // - The functions return the length of the string on success, or -1 on failure + // - The output string is always null-terminated and cleared on failure + // - When retrieving a string, an extra byte must be allocated to account for the null terminator + // - GGUF array values are not supported by these functions + + // Get metadata value as a string by key name + LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size); + + // Get the number of metadata key/value pairs + LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter); + + // Get metadata key name by index + LLAMA_API int32_t llama_adapter_meta_key_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); + + // Get metadata value as a string by index + LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); + // Manually free a LoRA adapter // Note: loaded adapters will be free when the associated model is deleted LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); + // Get the invocation tokens if the current lora is an alora + LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); + LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens (const struct llama_adapter_lora * adapter); + // The following functions operate on a llama_context, hence the naming: llama_verb_... // Add a loaded LoRA adapter to given context @@ -667,111 +697,6 @@ extern "C" { // Check if the memory supports shifting LLAMA_API bool llama_memory_can_shift(llama_memory_t mem); - // - // KV cache for self-attention (TODO: deprecate in favor of llama_memory) - // - - // Returns the number of tokens in the KV cache (slow, use only for debug) - // If a KV cell has multiple sequences assigned to it, it will be counted multiple times - DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), - "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); - - // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) - DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx), - "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); - - // Clear the KV cache - both cell info is erased and KV data is zeroed - DEPRECATED(LLAMA_API void llama_kv_self_clear( - struct llama_context * ctx), - "Use llama_memory_clear() instead"); - - // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails - // seq_id < 0 : match any sequence - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1), - "Use llama_memory_seq_rm() instead"); - - // Copy all tokens that belong to the specified sequence to another sequence - // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - DEPRECATED(LLAMA_API void llama_kv_self_seq_cp( - struct llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1), - "Use llama_memory_seq_cp() instead"); - - // Removes all tokens that do not belong to the specified sequence - DEPRECATED(LLAMA_API void llama_kv_self_seq_keep( - struct llama_context * ctx, - llama_seq_id seq_id), - "Use llama_memory_seq_keep() instead"); - - // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) - // If the KV cache is RoPEd, the KV data is updated accordingly: - // - lazily on next llama_decode() - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - DEPRECATED(LLAMA_API void llama_kv_self_seq_add( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta), - "Use llama_memory_seq_add() instead"); - - // Integer division of the positions by factor of `d > 1` - // If the KV cache is RoPEd, the KV data is updated accordingly: - // - lazily on next llama_decode() - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - DEPRECATED(LLAMA_API void llama_kv_self_seq_div( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d), - "Use llama_memory_seq_div() instead"); - - // Returns the smallest position present in the KV cache for the specified sequence - // This is typically non-zero only for SWA caches - // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache - // Return -1 if the sequence is empty - DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min( - struct llama_context * ctx, - llama_seq_id seq_id), - "Use llama_memory_seq_pos_min() instead"); - - // Returns the largest position present in the KV cache for the specified sequence - // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache - // Return -1 if the sequence is empty - DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max( - struct llama_context * ctx, - llama_seq_id seq_id), - "Use llama_memory_seq_pos_max() instead"); - - // Defragment the KV cache - // This will be applied: - // - lazily on next llama_decode() - DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx), - "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); - - // Check if the context supports KV cache shifting - DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx), - "use llama_memory_can_shift() instead"); - - // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) - DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), - "simply remove this call, updates are applied lazily on the next llama_decode()"); - // // State / sessions // @@ -870,6 +795,33 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +// for backwards-compat +#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 + +// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) +#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 + + typedef uint32_t llama_state_seq_flags; + + LLAMA_API size_t llama_state_seq_get_size_ext( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + LLAMA_API size_t llama_state_seq_get_data_ext( + struct llama_context * ctx, + uint8_t * dst, + size_t size, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + LLAMA_API size_t llama_state_seq_set_data_ext( + struct llama_context * ctx, + const uint8_t * src, + size_t size, + llama_seq_id dest_seq_id, + llama_state_seq_flags flags); + // // Decoding // @@ -1216,11 +1168,6 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); - /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. - DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), - "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); - /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// Setting k <= 0 makes this a noop LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); @@ -1390,24 +1337,25 @@ extern "C" { // // Performance utils // - // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements. + // NOTE: Used by llama.cpp examples/tools, avoid using in third-party apps. Instead, do your own performance measurements. // struct llama_perf_context_data { - double t_start_ms; - double t_load_ms; - double t_p_eval_ms; - double t_eval_ms; + // ms == milliseconds + double t_start_ms; // absolute start time + double t_load_ms; // time needed for loading the model + double t_p_eval_ms; // time needed for processing the prompt + double t_eval_ms; // time needed for generating tokens - int32_t n_p_eval; - int32_t n_eval; - int32_t n_reused; // number of times a ggml compute graph had been reused + int32_t n_p_eval; // number of prompt tokens + int32_t n_eval; // number of generated tokens + int32_t n_reused; // number of times a ggml compute graph had been reused }; struct llama_perf_sampler_data { - double t_sample_ms; + double t_sample_ms; // time needed for sampling in ms - int32_t n_sample; + int32_t n_sample; // number of sampled tokens }; LLAMA_API struct llama_perf_context_data llama_perf_context (const struct llama_context * ctx); @@ -1419,6 +1367,9 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); + // print a breakdown of per-device memory use via LLAMA_LOG: + LLAMA_API void llama_memory_breakdown_print(const struct llama_context * ctx); + // // training // @@ -1437,6 +1388,8 @@ extern "C" { ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters void * get_opt_pars_ud; // userdata for calculating optimizer parameters + + enum ggml_opt_optimizer_type optimizer_type; }; LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params); diff --git a/llama/llama.cpp/src/llama-adapter.cpp b/llama/llama.cpp/src/llama-adapter.cpp index 8d94034a..d8eef75a 100644 --- a/llama/llama.cpp/src/llama-adapter.cpp +++ b/llama/llama.cpp/src/llama-adapter.cpp @@ -6,6 +6,7 @@ #include #include +#include #include // vec @@ -163,13 +164,38 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ // check metadata { + const gguf_context * gguf_ctx = ctx_gguf.get(); + + LLAMA_LOG_INFO("%s: Dumping metadata keys/values.\n", __func__); + + // get metadata as string + for (int i = 0; i < gguf_get_n_kv(gguf_ctx); i++) { + gguf_type type = gguf_get_kv_type(gguf_ctx, i); + const std::string type_name = + type == GGUF_TYPE_ARRAY + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(gguf_ctx, i)), gguf_get_arr_n(gguf_ctx, i)) + : gguf_type_name(type); + const char * name = gguf_get_key(gguf_ctx, i); + const std::string value = gguf_kv_to_str(gguf_ctx, i); + + if (type != GGUF_TYPE_ARRAY) { + adapter.gguf_kv.emplace(name, value); + } + + const size_t MAX_VALUE_LEN = 40; + std::string print_value = value.size() > MAX_VALUE_LEN ? format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()) : value; + replace_all(print_value, "\n", "\\n"); + + LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), print_value.c_str()); + } + auto get_kv_str = [&](const std::string & key) -> std::string { - int id = gguf_find_key(ctx_gguf.get(), key.c_str()); - return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id)); + int id = gguf_find_key(gguf_ctx, key.c_str()); + return id < 0 ? "" : std::string(gguf_get_val_str(gguf_ctx, id)); }; auto get_kv_f32 = [&](const std::string & key) -> float { - int id = gguf_find_key(ctx_gguf.get(), key.c_str()); - return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id); + int id = gguf_find_key(gguf_ctx, key.c_str()); + return id < 0 ? 0.0f : gguf_get_val_f32(gguf_ctx, id); }; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); @@ -190,6 +216,26 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA)); + + // parse alora invocation sequence vector + const auto & key = llm_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS); + const int kid = gguf_find_key(ctx_gguf.get(), key.c_str()); + if (kid >= 0) { + if (gguf_get_kv_type(ctx_gguf.get(), kid) != GGUF_TYPE_ARRAY) { + throw std::runtime_error("invalid gguf type for " + key); + } + const auto arr_type = gguf_get_arr_type(ctx_gguf.get(), kid); + if (arr_type != GGUF_TYPE_UINT32) { + throw std::runtime_error("invalid gguf element type for " + key); + } + const size_t seq_len = gguf_get_arr_n(ctx_gguf.get(), kid); + const void * data = gguf_get_arr_data(ctx_gguf.get(), kid); + adapter.alora_invocation_tokens.resize(seq_len); + std::copy( + (const llama_token *)data, + (const llama_token *)data + seq_len, + adapter.alora_invocation_tokens.begin()); + } } int n_tensors = gguf_get_n_tensors(ctx_gguf.get()); @@ -383,6 +429,57 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p return nullptr; } +int32_t llama_adapter_meta_val_str(const llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size) { + const auto & it = adapter->gguf_kv.find(key); + if (it == adapter->gguf_kv.end()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + +int32_t llama_adapter_meta_count(const llama_adapter_lora * adapter) { + return (int)adapter->gguf_kv.size(); +} + +int32_t llama_adapter_meta_key_by_index(const llama_adapter_lora * adapter, int i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)adapter->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = adapter->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->first.c_str()); +} + +int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)adapter->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = adapter->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + void llama_adapter_lora_free(llama_adapter_lora * adapter) { delete adapter; } + +uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) { + if (!adapter) { + return 0; + } + return adapter->alora_invocation_tokens.size(); +} + +const llama_token * llama_adapter_get_alora_invocation_tokens(const llama_adapter_lora * adapter) { + GGML_ASSERT(adapter); + return adapter->alora_invocation_tokens.data(); +} diff --git a/llama/llama.cpp/src/llama-adapter.h b/llama/llama.cpp/src/llama-adapter.h index 65824e97..4f65247c 100644 --- a/llama/llama.cpp/src/llama-adapter.h +++ b/llama/llama.cpp/src/llama-adapter.h @@ -67,6 +67,12 @@ struct llama_adapter_lora { float alpha; + // gguf metadata + std::unordered_map gguf_kv; + + // activated lora (aLoRA) + std::vector alora_invocation_tokens; + llama_adapter_lora() = default; ~llama_adapter_lora() = default; diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index 4b285646..9f6b6ad2 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -22,6 +22,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, { LLM_ARCH_NEO_BERT, "neo-bert" }, { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -44,6 +45,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, + { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, @@ -68,6 +70,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, { LLM_ARCH_RWKV6, "rwkv6" }, @@ -91,9 +94,14 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, + { LLM_ARCH_LFM2MOE, "lfm2moe" }, { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_LLADA, "llada" }, + { LLM_ARCH_LLADA_MOE, "llada-moe" }, + { LLM_ARCH_SEED_OSS, "seed_oss" }, + { LLM_ARCH_GROVEMOE, "grovemoe" }, + { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -121,6 +129,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, + { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, @@ -129,12 +138,16 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, + { LLM_KV_EXPERT_GROUP_SCALE, "%s.expert_group_scale" }, + { LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" }, { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, + { LLM_KV_DECODER_BLOCK_COUNT, "%s.decoder_block_count" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, + { LLM_KV_ROUTER_LOGIT_SOFTCAPPING, "%s.router_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, { LLM_KV_SWIN_NORM, "%s.swin_norm" }, { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, @@ -165,20 +178,26 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, + { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, + { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -202,6 +221,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, + // sentence-transformers dense modules feature dims + { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -235,8 +259,16 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, - { LLM_KV_ADAPTER_TYPE, "adapter.type" }, - { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_ADAPTER_LORA_TASK_NAME, "adapter.lora.task_name" }, + { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" }, + { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" }, + + { LLM_KV_XIELU_ALPHA_N, "xielu.alpha_n" }, + { LLM_KV_XIELU_ALPHA_P, "xielu.alpha_p" }, + { LLM_KV_XIELU_BETA, "xielu.beta" }, + { LLM_KV_XIELU_EPS, "xielu.eps" }, // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, @@ -392,12 +424,16 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, }, @@ -576,6 +612,20 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_CLS, "cls" }, }, }, + { + LLM_ARCH_JINA_BERT_V3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + }, + }, { LLM_ARCH_BLOOM, { @@ -689,6 +739,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -1021,6 +1072,29 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" }, }, }, + { + LLM_ARCH_GEMMA_EMBEDDING, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, + { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { @@ -1534,6 +1608,31 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_NEMOTRON_H, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + // mamba(2) ssm layers + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + // attention layers + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // dense FFN + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_EXAONE, { @@ -2030,6 +2129,33 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + } + }, + { + LLM_ARCH_LFM2MOE, + { + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, + { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, + { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, } }, { @@ -2053,6 +2179,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" } }, }, + { + LLM_ARCH_APERTUS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_DREAM, { @@ -2087,6 +2232,66 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_LLADA_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_SEED_OSS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GROVEMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_CHEXPS, "blk.%d.ffn_gate_chexps" }, + { LLM_TENSOR_FFN_DOWN_CHEXPS, "blk.%d.ffn_down_chexps" }, + { LLM_TENSOR_FFN_UP_CHEXPS, "blk.%d.ffn_up_chexps" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2103,6 +2308,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, @@ -2219,6 +2426,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, // altup / laurel (gemma 3n) {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, @@ -2340,6 +2550,8 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: + case LLM_ARCH_NEMOTRON_H: return true; default: return false; @@ -2350,6 +2562,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { switch (arch) { case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: + case LLM_ARCH_LLADA_MOE: return true; default: return false; diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index 3ea994c7..dc7a362a 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -26,6 +26,7 @@ enum llm_arch { LLM_ARCH_NOMIC_BERT_MOE, LLM_ARCH_NEO_BERT, LLM_ARCH_JINA_BERT_V2, + LLM_ARCH_JINA_BERT_V3, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -48,6 +49,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, + LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, @@ -72,6 +74,7 @@ enum llm_arch { LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, + LLM_ARCH_NEMOTRON_H, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, @@ -95,9 +98,14 @@ enum llm_arch { LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, + LLM_ARCH_LFM2MOE, LLM_ARCH_DREAM, LLM_ARCH_SMALLTHINKER, LLM_ARCH_LLADA, + LLM_ARCH_LLADA_MOE, + LLM_ARCH_SEED_OSS, + LLM_ARCH_GROVEMOE, + LLM_ARCH_APERTUS, LLM_ARCH_UNKNOWN, }; @@ -125,6 +133,7 @@ enum llm_kv { LLM_KV_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -133,12 +142,16 @@ enum llm_kv { LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_GATING_FUNC, + LLM_KV_EXPERT_GROUP_SCALE, + LLM_KV_EXPERTS_PER_GROUP, LLM_KV_MOE_EVERY_N_LAYERS, LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_DECODER_BLOCK_COUNT, LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_ROUTER_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, LLM_KV_SWIN_NORM, LLM_KV_RESCALE_EVERY_N_LAYERS, @@ -169,6 +182,8 @@ enum llm_kv { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SCALE, + LLM_KV_ATTENTION_OUTPUT_SCALE, + LLM_KV_ATTENTION_TEMPERATURE_LENGTH, LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, @@ -183,6 +198,10 @@ enum llm_kv { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, LLM_KV_ROPE_SCALING_YARN_LOG_MUL, + LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, + LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, + LLM_KV_ROPE_SCALING_YARN_BETA_FAST, + LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, LLM_KV_SPLIT_NO, LLM_KV_SPLIT_COUNT, @@ -231,6 +250,9 @@ enum llm_kv { LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, + LLM_KV_ADAPTER_LORA_TASK_NAME, + LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, + LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, LLM_KV_POSNET_EMBEDDING_LENGTH, LLM_KV_POSNET_BLOCK_COUNT, @@ -242,10 +264,21 @@ enum llm_kv { LLM_KV_SHORTCONV_L_CACHE, + LLM_KV_XIELU_ALPHA_N, + LLM_KV_XIELU_ALPHA_P, + LLM_KV_XIELU_BETA, + LLM_KV_XIELU_EPS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, + + // sentence-transformers dense layers in and out features + LLM_KV_DENSE_2_FEAT_IN, + LLM_KV_DENSE_2_FEAT_OUT, + LLM_KV_DENSE_3_FEAT_IN, + LLM_KV_DENSE_3_FEAT_OUT, }; enum llm_tensor { @@ -253,6 +286,8 @@ enum llm_tensor { LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_TOKEN_TYPES, LLM_TENSOR_POS_EMBD, + LLM_TENSOR_DENSE_2_OUT, + LLM_TENSOR_DENSE_3_OUT, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, @@ -287,6 +322,9 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_CHEXPS, + LLM_TENSOR_FFN_GATE_CHEXPS, + LLM_TENSOR_FFN_UP_CHEXPS, LLM_TENSOR_FFN_EXP_PROBS_B, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, diff --git a/llama/llama.cpp/src/llama-batch.cpp b/llama/llama.cpp/src/llama-batch.cpp index 8698d89a..55d89eca 100644 --- a/llama/llama.cpp/src/llama-batch.cpp +++ b/llama/llama.cpp/src/llama-batch.cpp @@ -477,7 +477,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) { if (sequential && has_cpl) { - LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__); + LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__); return {}; } diff --git a/llama/llama.cpp/src/llama-chat.cpp b/llama/llama.cpp/src/llama-chat.cpp index 0a96a9a5..956c4e08 100644 --- a/llama/llama.cpp/src/llama-chat.cpp +++ b/llama/llama.cpp/src/llama-chat.cpp @@ -16,10 +16,10 @@ static std::string trim(const std::string & str) { size_t start = 0; size_t end = str.size(); - while (start < end && isspace(str[start])) { + while (start < end && isspace(static_cast(str[start]))) { start += 1; } - while (end > start && isspace(str[end - 1])) { + while (end > start && isspace(static_cast(str[end - 1]))) { end -= 1; } return str.substr(start, end - start); @@ -69,6 +69,8 @@ static const std::map LLM_CHAT_TEMPLATES = { { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, + { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, + { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -201,6 +203,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { return LLM_CHAT_TEMPLATE_KIMI_K2; + } else if (tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_SEED_OSS; + } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) { + return LLM_CHAT_TEMPLATE_GROK_2; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -584,7 +590,7 @@ int32_t llm_chat_apply_template( ss << message->content << "<|end_of_text|>\n"; } if (add_ass) { - ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + ss << "<|start_of_role|>assistant<|end_of_role|>"; } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template @@ -752,6 +758,28 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|im_assistant|>assistant<|im_middle|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) { + for (auto message: chat) { + std::string role(message->role); + ss << "" << role << "\n" << (role == "assistant" ? trim(message->content) : message->content) << ""; + } + if (add_ass) { + ss << "assistant\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) { + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "System: " << trim(message->content) << "<|separator|>\n\n"; + } else if (role == "user") { + ss << "Human: " << trim(message->content) << "<|separator|>\n\n"; + } else if (role == "assistant") { + ss << "Assistant: " << message->content << "<|separator|>\n\n"; + } + } + if (add_ass) { + ss << "Assistant:"; + } } else { // template not supported return -1; diff --git a/llama/llama.cpp/src/llama-chat.h b/llama/llama.cpp/src/llama-chat.h index 35a94385..5a87d9ab 100644 --- a/llama/llama.cpp/src/llama-chat.h +++ b/llama/llama.cpp/src/llama-chat.h @@ -49,6 +49,8 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, LLM_CHAT_TEMPLATE_KIMI_K2, + LLM_CHAT_TEMPLATE_SEED_OSS, + LLM_CHAT_TEMPLATE_GROK_2, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index 6ece5263..53a5e3a9 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ b/llama/llama.cpp/src/llama-context.cpp @@ -35,14 +35,12 @@ llama_context::llama_context( cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; - cparams.yarn_ext_factor = params.yarn_ext_factor; - cparams.yarn_attn_factor = params.yarn_attn_factor; - cparams.yarn_beta_fast = params.yarn_beta_fast; - cparams.yarn_beta_slow = params.yarn_beta_slow; - cparams.defrag_thold = params.defrag_thold; + cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; - cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; cparams.warmup = false; @@ -87,13 +85,15 @@ llama_context::llama_context( cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; } + cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; + // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) // ref: https://github.com/ggerganov/llama.cpp/pull/5021 - // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self + // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory if (cparams.n_batch < GGML_KQ_MASK_PAD) { LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); cparams.n_batch = GGML_KQ_MASK_PAD; @@ -103,16 +103,6 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; - { - const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows; - - if (!supports_set_rows && !cparams.kv_unified) { - LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__); - cparams.kv_unified = true; - } - } - { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; @@ -130,7 +120,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); - LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -145,11 +135,6 @@ llama_context::llama_context( __func__, n_ctx_per_seq, hparams.n_ctx_train); } - if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) { - LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n", - __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573"); - } - if (!hparams.vocab_only) { // GPU backends for (auto * dev : model.devices) { @@ -196,7 +181,7 @@ llama_context::llama_context( // graph outputs buffer { // resized during inference when a batch uses more outputs - if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) { + if (output_reserve(params.n_seq_max) < params.n_seq_max) { throw std::runtime_error("failed to reserve initial output buffer"); } @@ -285,28 +270,75 @@ llama_context::llama_context( } } - // reserve worst-case graph - if (!hparams.vocab_only && memory) { + if (!hparams.vocab_only) { + llama_memory_context_ptr mctx; + if (memory) { + LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); + mctx = memory->init_full(); + if (!mctx) { + throw std::runtime_error("failed to initialize memory module"); + } + } + + cross.v_embd.clear(); + const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + // avoid reserving graphs with zero outputs - assume one output per sequence + n_outputs = n_seqs; + LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + // resolve automatic Flash Attention use + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to split graph for Flash Attention check"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; + bool fa_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_FLASH_ATTN_EXT) { + continue; + } + ggml_backend_dev_t device_fa = ggml_backend_get_device( + ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_fa != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); + // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways + fa_device_mismatch = true; + break; + } + } + if (fa_device_mismatch) { + cparams.flash_attn = false; + LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); + if (ggml_is_quantized(params.type_v)) { + throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); + } + } else { + cparams.flash_attn = true; + LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + } + } + + // reserve worst-case graph int n_splits_pp = -1; int n_nodes_pp = -1; int n_splits_tg = -1; int n_nodes_tg = -1; - // simulate full KV cache - - const auto mctx = memory->init_full(); - if (!mctx) { - throw std::runtime_error("failed to initialize KV cache"); - } - - cross.v_embd.clear(); - // reserve pp (prompt processing) graph first so that buffers are only allocated once { auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); @@ -444,26 +476,12 @@ llama_memory_t llama_context::get_memory() const { return memory.get(); } -// deprecated -void llama_context::kv_self_defrag_sched() { - if (!memory) { - return; - } - - memory_force_optimize = true; -} - -// deprecated -bool llama_context::kv_self_update(bool optimize) { +bool llama_context::memory_update(bool optimize) { if (!memory) { return false; } { - // TODO: remove in the future - optimize |= memory_force_optimize; - memory_force_optimize = false; - const auto mctx = memory->init_update(this, optimize); switch (mctx->get_status()) { case LLAMA_MEMORY_STATUS_SUCCESS: @@ -908,12 +926,6 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - if (!supports_set_rows) { - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to - // overlap with device computation. - ggml_backend_sched_reset(sched.get()); - } - // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -996,8 +1008,8 @@ int llama_context::decode(const llama_batch & batch_inp) { bool did_optimize = false; - // handle any pending defrags/shifts - kv_self_update(false); + // handle any pending shifts/copies + memory_update(false); llama_memory_context_ptr mctx; @@ -1022,7 +1034,7 @@ int llama_context::decode(const llama_batch & batch_inp) { if (!did_optimize) { did_optimize = true; - if (kv_self_update(true)) { + if (memory_update(true)) { LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); continue; @@ -1075,7 +1087,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); if (!res) { - // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache + // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module llama_pos pos_min[LLAMA_MAX_SEQ]; for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { pos_min[s] = std::numeric_limits::max(); @@ -1092,7 +1104,7 @@ int llama_context::decode(const llama_batch & batch_inp) { continue; } - LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); + LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); memory->seq_rm(s, pos_min[s], -1); } @@ -1243,12 +1255,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); - if (!supports_set_rows) { - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to - // overlap with device computation. - ggml_backend_sched_reset(sched.get()); - } - return 0; } @@ -1362,8 +1368,9 @@ llm_graph_result * llama_context::get_gf_res_reserve() const { return static_cast(gf_res_reserve.get()); } -ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) { +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) { LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); + GGML_ASSERT(n_outputs >= 1); if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs @@ -1397,7 +1404,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u this->n_outputs = save_n_outputs; // initialize scheduler with the specified graph - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + if (split_only) { + ggml_backend_sched_split_graph(sched.get(), gf); + } else if (!ggml_backend_sched_reserve(sched.get(), gf)) { LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); return nullptr; } @@ -1437,7 +1446,9 @@ ggml_status llama_context::graph_compute( if (backend_cpu != nullptr) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); - set_threadpool_fn(backend_cpu, tp); + if (set_threadpool_fn) { + set_threadpool_fn(backend_cpu, tp); + } } // set the number of threads for all the backends @@ -1656,30 +1667,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) { } } -size_t llama_context::state_seq_get_size(llama_seq_id seq_id) { +size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { llama_io_write_dummy io; try { - return state_seq_write_data(io, seq_id); + return state_seq_write_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); return 0; } } -size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { +size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { llama_io_write_buffer io(dst, size); try { - return state_seq_write_data(io, seq_id); + return state_seq_write_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); return 0; } } -size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { +size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { llama_io_read_buffer io(src, size); try { - return state_seq_read_data(io, seq_id); + return state_seq_read_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; @@ -1777,7 +1788,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file { const size_t state_size = file.size() - file.tell(); llama_io_read_file io(&file); - const size_t nread = state_seq_read_data(io, seq_id); + const size_t nread = state_seq_read_data(io, seq_id, 0); if (!nread) { LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); return 0; @@ -1801,7 +1812,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file // save the context state using stream saving llama_io_write_file io(&file); - state_seq_write_data(io, seq_id); + state_seq_write_data(io, seq_id, 0); const size_t res = file.tell(); GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); @@ -1876,7 +1887,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } if (memory != nullptr) { - LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); memory->state_write(io); } @@ -1962,7 +1973,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { } if (memory) { - LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); + LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); memory->state_read(io); } @@ -1970,21 +1981,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { return io.n_bytes(); } -size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { +size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { GGML_UNUSED(seq_id); if (memory) { - memory->state_write(io, seq_id); + memory->state_write(io, seq_id, flags); } return io.n_bytes(); } -size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { +size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { GGML_UNUSED(seq_id); if (memory) { - memory->state_read(io, seq_id); + memory->state_read(io, seq_id, flags); } return io.n_bytes(); @@ -2015,6 +2026,21 @@ void llama_context::perf_reset() { n_reused = 0; } +std::map llama_context::memory_breakdown() const { + std::map ret; + for (const auto & buft_size : model.memory_breakdown()) { + ret[buft_size.first].model += buft_size.second; + } + for (const auto & buft_size : memory->memory_breakdown()) { + ret[buft_size.first].context += buft_size.second; + } + for (const auto & backend_ptr : backends) { + ggml_backend_t backend = backend_ptr.get(); + ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend); + } + return ret; +} + // // training // @@ -2047,7 +2073,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params opt_params.opt_period = n_batch / n_ubatch; opt_params.get_opt_pars = lopt_params.get_opt_pars; opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; - + opt_params.optimizer = lopt_params.optimizer_type; opt_ctx = ggml_opt_init(opt_params); llama_opt_param_filter param_filter = lopt_params.param_filter; @@ -2247,12 +2273,13 @@ llama_context_params llama_context_default_params() { /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, + /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, - /*.yarn_attn_factor =*/ 1.0f, - /*.yarn_beta_fast =*/ 32.0f, - /*.yarn_beta_slow =*/ 1.0f, + /*.yarn_attn_factor =*/ -1.0f, + /*.yarn_beta_fast =*/ -1.0f, + /*.yarn_beta_slow =*/ -1.0f, /*.yarn_orig_ctx =*/ 0, /*.defrag_thold =*/ -1.0f, /*.cb_eval =*/ nullptr, @@ -2263,7 +2290,6 @@ llama_context_params llama_context_default_params() { /*.abort_callback_data =*/ nullptr, /*.embeddings =*/ false, /*.offload_kqv =*/ true, - /*.flash_attn =*/ false, /*.no_perf =*/ true, /*.op_offload =*/ true, /*.swa_full =*/ true, @@ -2291,16 +2317,40 @@ llama_context * llama_init_from_model( return nullptr; } - if (params.flash_attn && model->arch == LLM_ARCH_GROK) { + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) { LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); - params.flash_attn = false; + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; } - if (ggml_is_quantized(params.type_v) && !params.flash_attn) { + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { + const uint32_t blck_size = ggml_blck_size(params.type_k); + if (model->hparams.n_embd_head_k % blck_size != 0) { + LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", + __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k); + return nullptr; + } + } + + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { + const uint32_t blck_size = ggml_blck_size(params.type_v); + if (model->hparams.n_embd_head_v % blck_size != 0) { + LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n", + __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v); + return nullptr; + } + } + + if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; } + if (params.pooling_type != model->hparams.pooling_type) { + //user-specified pooling-type is different from the model default + LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, + model->hparams.pooling_type, params.pooling_type); + } + try { auto * ctx = new llama_context(*model, params); return ctx; @@ -2342,16 +2392,6 @@ const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } -// deprecated -llama_kv_cache * llama_get_kv_self(llama_context * ctx) { - return dynamic_cast(ctx->get_memory()); -} - -// deprecated -void llama_kv_self_update(llama_context * ctx) { - ctx->kv_self_update(false); -} - enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { return ctx->pooling_type(); } @@ -2569,168 +2609,6 @@ bool llama_memory_can_shift(llama_memory_t mem) { return mem->get_can_shift(); } -// -// kv cache -// - -// deprecated -int32_t llama_kv_self_n_tokens(const llama_context * ctx) { - const auto * kv = llama_get_memory(ctx); - if (!kv) { - return 0; - } - - int32_t res = 0; - - for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) { - const llama_pos p0 = kv->seq_pos_min(s); - const llama_pos p1 = kv->seq_pos_max(s); - - if (p0 >= 0) { - res += (p1 - p0) + 1; - } - } - - return res; -} - -// deprecated -// note: this is the same as above - will be removed anyway, so it's ok -int32_t llama_kv_self_used_cells(const llama_context * ctx) { - const auto * kv = llama_get_memory(ctx); - if (!kv) { - return 0; - } - - int32_t res = 0; - - for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) { - const llama_pos p0 = kv->seq_pos_min(s); - const llama_pos p1 = kv->seq_pos_max(s); - - if (p0 >= 0) { - res += (p1 - p0) + 1; - } - } - - return res; -} - -// deprecated -void llama_kv_self_clear(llama_context * ctx) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return; - } - - llama_memory_clear(kv, true); -} - -// deprecated -bool llama_kv_self_seq_rm( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return true; - } - - return llama_memory_seq_rm(kv, seq_id, p0, p1); -} - -// deprecated -void llama_kv_self_seq_cp( - llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return; - } - - llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1); -} - -// deprecated -void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return; - } - - llama_memory_seq_keep(kv, seq_id); -} - -// deprecated -void llama_kv_self_seq_add( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return; - } - - llama_memory_seq_add(kv, seq_id, p0, p1, delta); -} - -// deprecated -void llama_kv_self_seq_div( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return; - } - - llama_memory_seq_div(kv, seq_id, p0, p1, d); -} - -// deprecated -llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return -1; - } - - return llama_memory_seq_pos_min(kv, seq_id); -} - -// deprecated -llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return -1; - } - - return llama_memory_seq_pos_max(kv, seq_id); -} - -// deprecated -void llama_kv_self_defrag(llama_context * ctx) { - // force defrag - ctx->kv_self_defrag_sched(); -} - -// deprecated -bool llama_kv_self_can_shift(const llama_context * ctx) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return false; - } - - return llama_memory_can_shift(kv); -} - // llama state API // deprecated @@ -2800,19 +2678,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const } size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { - return ctx->state_seq_get_size(seq_id); + return llama_state_seq_get_size_ext(ctx, seq_id, 0); } size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { - ctx->synchronize(); - - return ctx->state_seq_get_data(seq_id, dst, size); + return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0); } size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { + return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0); +} + +size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { + return ctx->state_seq_get_size(seq_id, flags); +} + +size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { ctx->synchronize(); - return ctx->state_seq_set_data(seq_id, src, size); + return ctx->state_seq_get_data(seq_id, dst, size, flags); +} + +size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { + ctx->synchronize(); + + return ctx->state_seq_set_data(seq_id, src, size, flags); } size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { @@ -2895,6 +2785,142 @@ void llama_perf_context_reset(llama_context * ctx) { ctx->perf_reset(); } +void llama_memory_breakdown_print(const struct llama_context * ctx) { + const std::vector & devices = ctx->get_model().devices; + + std::map memory_breakdown = ctx->memory_breakdown(); + + std::vector> table_data; + table_data.reserve(devices.size()); + const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; + const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; + const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; + + table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); + + constexpr size_t MiB = 1024 * 1024; + const std::vector desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; + + // track seen buffer types to avoid double counting: + std::set seen_buffer_types; + + // accumulative memory breakdown for each device and for host: + std::vector mb_dev(devices.size()); + llama_memory_breakdown_data mb_host; + + for (const auto & buft_mb : memory_breakdown) { + ggml_backend_buffer_type_t buft = buft_mb.first; + const llama_memory_breakdown_data & mb = buft_mb.second; + if (ggml_backend_buft_is_host(buft)) { + mb_host.model += mb.model; + mb_host.context += mb.context; + mb_host.compute += mb.compute; + seen_buffer_types.insert(buft); + continue; + } + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + if (dev) { + int i_dev = -1; + for (size_t i = 0; i < devices.size(); i++) { + if (devices[i] == dev) { + i_dev = i; + break; + } + } + if (i_dev != -1) { + mb_dev[i_dev].model += mb.model; + mb_dev[i_dev].context += mb.context; + mb_dev[i_dev].compute += mb.compute; + seen_buffer_types.insert(buft); + continue; + } + } + } + + // print memory breakdown for each device: + for (size_t i = 0; i < devices.size(); i++) { + ggml_backend_dev_t dev = devices[i]; + llama_memory_breakdown_data mb = mb_dev[i]; + + const std::string name = ggml_backend_dev_name(dev); + std::string desc = ggml_backend_dev_description(dev); + for (const std::string & prefix : desc_prefixes_strip) { + if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { + desc = desc.substr(prefix.length()); + } + } + + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + + const size_t self = mb.model + mb.context + mb.compute; + const size_t unaccounted = total - self - free; + + table_data.push_back({ + template_gpu, + " - " + name + " (" + desc + ")", + std::to_string(total / MiB), + std::to_string(free / MiB), + std::to_string(self / MiB), + std::to_string(mb.model / MiB), + std::to_string(mb.context / MiB), + std::to_string(mb.compute / MiB), + std::to_string(unaccounted / MiB)}); + } + + // print memory breakdown for host: + { + const size_t self = mb_host.model + mb_host.context + mb_host.compute; + table_data.push_back({ + template_other, + " - Host", + "", // total + "", // free + std::to_string(self / MiB), + std::to_string(mb_host.model / MiB), + std::to_string(mb_host.context / MiB), + std::to_string(mb_host.compute / MiB), + ""}); // unaccounted + } + + // print memory breakdown for all remaining buffer types: + for (const auto & buft_mb : memory_breakdown) { + ggml_backend_buffer_type_t buft = buft_mb.first; + const llama_memory_breakdown_data & mb = buft_mb.second; + if (seen_buffer_types.count(buft) == 1) { + continue; + } + const std::string name = ggml_backend_buft_name(buft); + const size_t self = mb.model + mb.context + mb.compute; + table_data.push_back({ + template_other, + " - " + name, + "", // total + "", // free + std::to_string(self / MiB), + std::to_string(mb.model / MiB), + std::to_string(mb.context / MiB), + std::to_string(mb.compute / MiB), + ""}); // unaccounted + seen_buffer_types.insert(buft); + } + + for (size_t j = 1; j < table_data[0].size(); j++) { + size_t max_len = 0; + for (const auto & td : table_data) { + max_len = std::max(max_len, td[j].length()); + } + for (auto & td : table_data) { + td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); + } + } + for (const auto & td : table_data) { + LLAMA_LOG_INFO(td[0].c_str(), + __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), + td[6].c_str(), td[7].c_str(), td[8].c_str()); + } +} + // // training // diff --git a/llama/llama.cpp/src/llama-context.h b/llama/llama.cpp/src/llama-context.h index 25c143d5..ed6d82cb 100644 --- a/llama/llama.cpp/src/llama-context.h +++ b/llama/llama.cpp/src/llama-context.h @@ -17,9 +17,17 @@ class llama_batch_allocr; class llama_io_read_i; class llama_io_write_i; +// "memory" as in abstract memory for the context struct llama_memory_i; struct llama_memory_context_i; +// "memory" as in physical memory for a buffer type, in bytes +struct llama_memory_breakdown_data { + size_t model = 0; // memory allocated for the model + size_t context = 0; // memory allocated for the context + size_t compute = 0; // memory allocated for temporary compute buffers +}; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -46,10 +54,8 @@ struct llama_context { llama_memory_t get_memory() const; - // return true of the KV cache was updated - // TODO: remove - bool kv_self_update(bool optimize); - void kv_self_defrag_sched(); + // return true if the memory was updated + bool memory_update(bool optimize); enum llama_pooling_type pooling_type() const; @@ -111,9 +117,9 @@ struct llama_context { size_t state_get_data( uint8_t * dst, size_t size); size_t state_set_data(const uint8_t * src, size_t size); - size_t state_seq_get_size(llama_seq_id seq_id); - size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size); - size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size); + size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags); + size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags); + size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags); bool state_load_file( const char * filepath, @@ -146,12 +152,15 @@ struct llama_context { llama_perf_context_data perf_get_data() const; void perf_reset(); + std::map memory_breakdown() const; + // // training // void opt_init(struct llama_model * model, struct llama_opt_params lopt_params); + // TODO: more flexible combinations of logical/physical batch size and context size void opt_epoch( ggml_opt_dataset_t dataset, ggml_opt_result_t result_train, @@ -197,7 +206,7 @@ public: ggml_status graph_compute(ggml_cgraph * gf, bool batched); // reserve a graph with a dummy ubatch of the specified size - ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false); private: llm_graph_params graph_params( @@ -212,8 +221,8 @@ private: size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); - size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id); - size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id); + size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags); + size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags); // // members @@ -229,9 +238,6 @@ private: std::unique_ptr memory; - // TODO: temporary, until the llama_kv_self_defrag() API is removed - bool memory_force_optimize = false; - // decode output (2-dimensional array: [n_outputs][n_vocab]) size_t logits_size = 0; // capacity (of floats) for logits float * logits = nullptr; @@ -287,10 +293,6 @@ private: bool has_evaluated_once = false; - // env: LLAMA_SET_ROWS (temporary) - // ref: https://github.com/ggml-org/llama.cpp/pull/14285 - bool supports_set_rows = true; - // env: LLAMA_GRAPH_REUSE_DISABLE bool graph_reuse_disable = false; diff --git a/llama/llama.cpp/src/llama-cparams.h b/llama/llama.cpp/src/llama-cparams.h index 38750aff..eae7b839 100644 --- a/llama/llama.cpp/src/llama-cparams.h +++ b/llama/llama.cpp/src/llama-cparams.h @@ -4,7 +4,7 @@ #include -#define LLAMA_MAX_SEQ 64 +#define LLAMA_MAX_SEQ 256 struct llama_cparams { uint32_t n_ctx; // context size used during inference @@ -24,7 +24,6 @@ struct llama_cparams { float yarn_attn_factor; float yarn_beta_fast; float yarn_beta_slow; - float defrag_thold; bool embeddings; bool causal_attn; diff --git a/llama/llama.cpp/src/llama-graph.cpp b/llama/llama.cpp/src/llama-graph.cpp index 053c72d6..a24853c6 100644 --- a/llama/llama.cpp/src/llama-graph.cpp +++ b/llama/llama.cpp/src/llama-graph.cpp @@ -4,8 +4,8 @@ #include "llama-batch.h" #include "llama-cparams.h" -#include "llama-kv-cache-unified.h" -#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache.h" +#include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" #include "llama-memory-recurrent.h" @@ -204,7 +204,10 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { std::vector target_pos(n_seqs_unq, -1); std::vector target_row(n_seqs_unq, -1); - bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST; + const bool last = ( + cparams.pooling_type == LLAMA_POOLING_TYPE_LAST || + (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token + ); for (int i = 0; i < n_tokens; ++i) { const llama_pos pos = ubatch->pos[i]; @@ -258,6 +261,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } +static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { + LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); + const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" : + (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" : + (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" : + (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown"; + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); + LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); + LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); + + LLAMA_LOG_DEBUG(" "); + for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { + LLAMA_LOG_DEBUG("%2d", j); + } + LLAMA_LOG_DEBUG("\n"); + + for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) { + LLAMA_LOG_DEBUG(" %2d ", i); + for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { + float val = data[i * n_kv + j]; + if (val == -INFINITY) { + LLAMA_LOG_DEBUG(" ∞"); + } else { + LLAMA_LOG_DEBUG(" 0"); + } + } + LLAMA_LOG_DEBUG("\n"); + } +} + void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; @@ -267,6 +300,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { float * data = (float *) kq_mask->data; + // [TAG_NO_CACHE_ISWA] + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement"); + for (int h = 0; h < 1; ++h) { for (int i1 = 0; i1 < n_tokens; ++i1) { const llama_seq_id s1 = ubatch->seq_id[i1][0]; @@ -277,32 +313,44 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) { const llama_seq_id s0 = ubatch->seq_id[i0][0]; + if (s0 != s1) { + continue; // skip different sequences + } + + if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { + continue; // skip future tokens for causal attention + } + + // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA] + //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) { + // continue; // skip masked tokens for SWA + //} + // TODO: reimplement this like in llama_kv_cache_unified - if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) { - if (hparams.use_alibi) { - f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); - } else { - f = 0.0f; - } - break; + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); + } else { + f = 0.0f; } } - data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f; } } } + if (debug) { + print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + } } -void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { +void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { mctx->set_input_k_idxs(self_k_idxs, ubatch); mctx->set_input_v_idxs(self_v_idxs, ubatch); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } -bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) { - const auto * mctx = static_cast(params.mctx); +bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); this->mctx = mctx; @@ -314,12 +362,10 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) res &= self_kq_mask->ne[0] == mctx->get_n_kv(); res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); - res &= mctx->get_supports_set_rows(); // TODO: tmp - return res; } -void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { +void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); @@ -331,8 +377,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } -bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) { - const auto * mctx = static_cast(params.mctx); +bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); this->mctx = mctx; @@ -350,8 +396,6 @@ bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & pa res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); - res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp - return res; } @@ -879,15 +923,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn( selection_probs = logits; } + if (arch == LLM_ARCH_GROVEMOE) { + selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] + cb(selection_probs, "ffn_moe_probs_biased", il); + } + // select experts ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) { + // TODO: Use scalar div instead when/if implemented + ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32); + selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32); + probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens); + } else { + probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens); + } + + ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); + if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) { weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens] @@ -911,6 +969,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(weights, "ffn_moe_weights_scaled", il); } + //call early so that topk-moe can be used + ggml_build_forward_expand(gf, weights); + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); if (weight_before_ffn) { @@ -1136,7 +1197,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const { } ggml_tensor * llm_graph_context::build_inp_cls() const { - auto inp = std::make_unique(cparams); + auto inp = std::make_unique(cparams, arch); auto & cur = inp->cls; @@ -1186,7 +1247,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { } ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { - const auto * mctx_cur = static_cast(mctx); + const auto * mctx_cur = static_cast(mctx); auto inp = std::make_unique(hparams, mctx_cur); @@ -1223,15 +1284,16 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * v, ggml_tensor * kq_b, ggml_tensor * kq_mask, - ggml_tensor * v_mla, ggml_tensor * sinks, - float kq_scale) const { + ggml_tensor * v_mla, + float kq_scale, + int il) const { const bool v_trans = v->nb[1] > v->nb[2]; // split the batch into streams if needed const auto n_stream = k->ne[3]; - q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream); + q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0); q = ggml_permute(ctx0, q, 0, 2, 1, 3); k = ggml_permute(ctx0, k, 0, 2, 1, 3); @@ -1260,6 +1322,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); + cb(cur, LLAMA_TENSOR_NAME_FATTN, il); ggml_flash_attn_ext_add_sinks(cur, sinks); ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32); @@ -1275,6 +1338,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( // The permutations are noops and only change how the tensor data is interpreted. cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); cur = ggml_mul_mat(ctx0, v_mla, cur); + cb(cur, "fattn_mla", il); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs. #endif @@ -1283,6 +1347,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); } else { ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); // note: this op tends to require high floating point range // while for some models F16 is enough, for others it is not, so we default to F32 here @@ -1290,38 +1355,48 @@ ggml_tensor * llm_graph_context::build_attn_mha( if (arch == LLM_ARCH_GROK) { // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 + // multiply by attn_output_multiplier // and then : // kq = 30 * tanh(kq / 30) // before the softmax below - kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f)); - kq = ggml_scale(ctx0, kq, 30); + kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping)); + cb(kq, "kq_tanh", il); + kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping); + cb(kq, "kq_scaled", il); } if (hparams.attn_soft_cap) { kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping); + cb(kq, "kq_scaled_1", il); kq = ggml_tanh (ctx0, kq); + cb(kq, "kq_tanh", il); kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping); + cb(kq, "kq_scaled_2", il); } if (kq_b) { kq = ggml_add(ctx0, kq, kq_b); + cb(kq, "kq_plus_kq_b", il); } kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); ggml_soft_max_add_sinks(kq, sinks); + cb(kq, "kq_soft_max", il); if (!v_trans) { // note: avoid this branch v = ggml_cont(ctx0, ggml_transpose(ctx0, v)); + cb(v, "v_cont", il); } ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cb(kqv, "kqv", il); // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA if (v_mla) { kqv = ggml_mul_mat(ctx0, v_mla, kqv); + cb(kqv, "kqv_mla", il); } cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); @@ -1360,6 +1435,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * sinks, ggml_tensor * v_mla, float kq_scale, int il) const { @@ -1375,13 +1451,14 @@ ggml_tensor * llm_graph_context::build_attn( // [TAG_NO_CACHE_PAD] // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams - assert(!ubatch.equal_seqs()); + // but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636 + //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq)); ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); if (wo) { @@ -1399,17 +1476,17 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -static std::unique_ptr build_attn_inp_kv_unified_impl( +static std::unique_ptr build_attn_inp_kv_impl( ggml_context * ctx0, const llama_ubatch & ubatch, const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_context * mctx_cur) { + const llama_kv_cache_context * mctx_cur) { - auto inp = std::make_unique(hparams, cparams, mctx_cur); + auto inp = std::make_unique(hparams, cparams, mctx_cur); { - GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); const auto n_kv = mctx_cur->get_n_kv(); const auto n_tokens = ubatch.n_tokens; @@ -1427,22 +1504,23 @@ static std::unique_ptr build_attn_inp_kv_unifie return inp; } -llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const auto * mctx_cur = static_cast(mctx); +llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const { + const auto * mctx_cur = static_cast(mctx); - auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur); - return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); + return (llm_graph_input_attn_kv *) res->add_input(std::move(inp)); } ggml_tensor * llm_graph_context::build_attn( - llm_graph_input_attn_kv_unified * inp, + llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * sinks, ggml_tensor * v_mla, float kq_scale, int il) const { @@ -1469,7 +1547,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); if (wo) { @@ -1488,40 +1566,15 @@ ggml_tensor * llm_graph_context::build_attn( } ggml_tensor * llm_graph_context::build_attn( - llm_graph_input_attn_kv_unified_iswa * inp, + llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, - ggml_tensor * v_mla, - float kq_scale, - int il) const { - return build_attn_with_sinks( - inp, - wo, - wo_b, - q_cur, - k_cur, - v_cur, - kq_b, - v_mla, - nullptr, - kq_scale, - il); -} - -ggml_tensor * llm_graph_context::build_attn_with_sinks( - llm_graph_input_attn_kv_unified_iswa * inp, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, - ggml_tensor * k_cur, - ggml_tensor * v_cur, - ggml_tensor * kq_b, - ggml_tensor * v_mla, ggml_tensor * sinks, + ggml_tensor * v_mla, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1561,7 +1614,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); if (wo) { @@ -1600,6 +1653,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * sinks, ggml_tensor * v_mla, float kq_scale, int il) const { @@ -1615,7 +1669,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); if (wo) { @@ -1636,10 +1690,10 @@ ggml_tensor * llm_graph_context::build_attn( // TODO: maybe separate the inner implementation into a separate function // like with the non-sliding window equivalent // once sliding-window hybrid caches are a thing. -llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { - const auto * mctx_cur = static_cast(mctx); +llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const { + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(hparams, cparams, mctx_cur); + auto inp = std::make_unique(hparams, cparams, mctx_cur); const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; @@ -1656,7 +1710,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif } { - GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); + GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA"); const auto n_kv = mctx_cur->get_swa()->get_n_kv(); @@ -1669,7 +1723,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; } - return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp)); + return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } ggml_tensor * llm_graph_context::build_rs( @@ -1792,13 +1846,30 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const auto * mctx_cur = static_cast(mctx); auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); - auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +void llm_graph_context::build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const { + if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) { + return; + } + ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; + GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd"); + + cur = ggml_mul_mat(ctx0, dense_2, cur); + cur = ggml_mul_mat(ctx0, dense_3, cur); + cb(cur, "result_embd_pooled", -1); + res->t_embd_pooled = cur; + ggml_build_forward_expand(gf, cur); +} + + void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, @@ -1843,34 +1914,32 @@ void llm_graph_context::build_pooling( case LLAMA_POOLING_TYPE_RANK: { ggml_tensor * inp_cls = build_inp_cls(); - inp = ggml_get_rows(ctx0, inp, inp_cls); + cur = ggml_get_rows(ctx0, inp, inp_cls); + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 if (cls) { - // classification head - // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 - cur = ggml_mul_mat(ctx0, cls, inp); + cur = ggml_mul_mat(ctx0, cls, cur); if (cls_b) { cur = ggml_add(ctx0, cur, cls_b); } cur = ggml_tanh(ctx0, cur); + } - // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en - // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 - if (cls_out) { - cur = ggml_mul_mat(ctx0, cls_out, cur); - if (cls_out_b) { - cur = ggml_add(ctx0, cur, cls_out_b); - } - } - } else if (cls_out) { - // Single layer classification head (direct projection) - // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476 - cur = ggml_mul_mat(ctx0, cls_out, inp); + // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 + // Single layer classification head (direct projection) + // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476 + if (cls_out) { + cur = ggml_mul_mat(ctx0, cls_out, cur); if (cls_out_b) { cur = ggml_add(ctx0, cur, cls_out_b); } - } else { - GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b"); + } + + // softmax for qwen3 reranker + if (arch == LLM_ARCH_QWEN3) { + cur = ggml_soft_max(ctx0, cur); } } break; default: diff --git a/llama/llama.cpp/src/llama-graph.h b/llama/llama.cpp/src/llama-graph.h index 6ff49de3..dc84b794 100644 --- a/llama/llama.cpp/src/llama-graph.h +++ b/llama/llama.cpp/src/llama-graph.h @@ -19,8 +19,8 @@ struct llama_cparams; struct llama_memory_context_i; -class llama_kv_cache_unified_context; -class llama_kv_cache_unified_iswa_context; +class llama_kv_cache_context; +class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; @@ -78,6 +78,11 @@ struct llm_graph_params; class llm_graph_input_i { public: + llm_graph_input_i() { + const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG"); + debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0; + } + virtual ~llm_graph_input_i() = default; virtual void set_input(const llama_ubatch * ubatch) = 0; @@ -90,6 +95,9 @@ public: GGML_UNUSED(params); return false; } +protected: + // env: LLAMA_GRAPH_INPUT_DEBUG + int debug = 0; }; using llm_graph_input_ptr = std::unique_ptr; @@ -152,7 +160,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { public: llm_graph_input_pos_bucket_kv( const llama_hparams & hparams, - const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {} + const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {} virtual ~llm_graph_input_pos_bucket_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -161,7 +169,7 @@ public: const llama_hparams hparams; - const llama_kv_cache_unified_context * mctx; + const llama_kv_cache_context * mctx; }; class llm_graph_input_out_ids : public llm_graph_input_i { @@ -198,7 +206,7 @@ public: class llm_graph_input_cls : public llm_graph_input_i { public: - llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {} + llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {} virtual ~llm_graph_input_cls() = default; void set_input(const llama_ubatch * ubatch) override; @@ -206,6 +214,7 @@ public: ggml_tensor * cls; // I32 [n_batch] const llama_cparams cparams; + const llm_arch arch; }; class llm_graph_input_rs : public llm_graph_input_i { @@ -257,17 +266,17 @@ public: const llama_cparams cparams; }; -class llm_graph_input_attn_kv_unified : public llm_graph_input_i { +class llm_graph_input_attn_kv : public llm_graph_input_i { public: - llm_graph_input_attn_kv_unified( + llm_graph_input_attn_kv( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_context * mctx) : + const llama_kv_cache_context * mctx) : hparams(hparams), cparams(cparams), mctx(mctx) { } - ~llm_graph_input_attn_kv_unified() = default; + ~llm_graph_input_attn_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -290,20 +299,20 @@ public: const llama_hparams hparams; const llama_cparams cparams; - const llama_kv_cache_unified_context * mctx; + const llama_kv_cache_context * mctx; }; -class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { +class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: - llm_graph_input_attn_kv_unified_iswa( + llm_graph_input_attn_kv_iswa( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_iswa_context * mctx) : + const llama_kv_cache_iswa_context * mctx) : hparams(hparams), cparams(cparams), mctx(mctx) { } - ~llm_graph_input_attn_kv_unified_iswa() = default; + ~llm_graph_input_attn_kv_iswa() = default; void set_input(const llama_ubatch * ubatch) override; @@ -330,7 +339,7 @@ public: const llama_hparams hparams; const llama_cparams cparams; - const llama_kv_cache_unified_iswa_context * mctx; + const llama_kv_cache_iswa_context * mctx; }; class llm_graph_input_attn_cross : public llm_graph_input_i { @@ -351,7 +360,7 @@ public: class llm_graph_input_mem_hybrid : public llm_graph_input_i { public: llm_graph_input_mem_hybrid( - std::unique_ptr inp_attn, + std::unique_ptr inp_attn, std::unique_ptr inp_rs, const llama_memory_hybrid_context * mctx) : inp_attn(std::move(inp_attn)), @@ -361,11 +370,11 @@ public: void set_input(const llama_ubatch * ubatch) override; - std::unique_ptr inp_attn; - std::unique_ptr inp_rs; + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; - llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); } - llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } const llama_memory_hybrid_context * mctx; }; @@ -680,14 +689,15 @@ struct llm_graph_context { // ggml_tensor * build_attn_mha( - ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens] - ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens] - ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) - ggml_tensor * kq_b, - ggml_tensor * kq_mask, - ggml_tensor * sinks, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] - float kq_scale) const; + ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) + ggml_tensor * kq_b, + ggml_tensor * kq_mask, + ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const; @@ -699,50 +709,39 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; - llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const; + llm_graph_input_attn_kv * build_attn_inp_kv() const; ggml_tensor * build_attn( - llm_graph_input_attn_kv_unified * inp, + llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; - llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const; + llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const; // note: if k_cur or v_cur are not provided, they will not be stored in the memory ggml_tensor * build_attn( - llm_graph_input_attn_kv_unified_iswa * inp, + llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional ggml_tensor * kq_b, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] - float kq_scale, - int il) const; - - // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this - ggml_tensor * build_attn_with_sinks( - llm_graph_input_attn_kv_unified_iswa * inp, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] - ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional - ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional - ggml_tensor * kq_b, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -756,6 +755,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -765,7 +765,7 @@ struct llm_graph_context { // // TODO: move this implementation to llama_memory_recurrent. - // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v + // this is analogous to llama_kv_cache::cpy_k / cpy_v // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in // `llama_memory_recurrent` @@ -814,6 +814,14 @@ struct llm_graph_context { ggml_tensor * cls_b, ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + + // + // dense (out) + // + + void build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const; }; // TODO: better name diff --git a/llama/llama.cpp/src/llama-hparams.cpp b/llama/llama.cpp/src/llama-hparams.cpp index 35fc054f..b6bf6bbf 100644 --- a/llama/llama.cpp/src/llama-hparams.cpp +++ b/llama/llama.cpp/src/llama-hparams.cpp @@ -1,6 +1,7 @@ #include "llama-hparams.h" #include "ggml.h" +#include void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { if (dense_first) { @@ -139,7 +140,11 @@ uint32_t llama_hparams::n_embd_s() const { } bool llama_hparams::is_recurrent(uint32_t il) const { - return recurrent_layer_arr[il]; + if (il < n_layer) { + return recurrent_layer_arr[il]; + } + + GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer); } uint32_t llama_hparams::n_pos_per_embd() const { @@ -161,3 +166,64 @@ bool llama_hparams::is_swa(uint32_t il) const { GGML_ABORT("fatal error"); } + +bool llama_hparams::has_kv(uint32_t il) const { + if (n_layer_kv_from_start >= 0) { + if (il < (uint32_t) n_layer_kv_from_start) { + return true; + } + + return false; + } + + // by default, all layers have kv + return true; +} + +uint32_t llama_hparams::n_layer_kv() const { + uint32_t res = 0; + + for (uint32_t il = 0; il < n_layer; ++il) { + if (has_kv(il)) { + res++; + } + } + + return res; +} + +bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + case LLAMA_SWA_TYPE_SYMMETRIC: + { + const int32_t half_n_swa = (int32_t) n_swa / 2; + const int32_t pos_diff = p1 - p0; + + // Mask if outside the symmetric window + if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { + return true; + } + } break; + } + + return false; +} diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index 29bd9056..80582728 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -16,9 +16,10 @@ enum llama_expert_gating_func_type { }; enum llama_swa_type { - LLAMA_SWA_TYPE_NONE = 0, - LLAMA_SWA_TYPE_STANDARD = 1, - LLAMA_SWA_TYPE_CHUNKED = 2, + LLAMA_SWA_TYPE_NONE = 0, + LLAMA_SWA_TYPE_STANDARD = 1, + LLAMA_SWA_TYPE_CHUNKED = 2, + LLAMA_SWA_TYPE_SYMMETRIC = 3, }; struct llama_hparams_posnet { @@ -41,6 +42,7 @@ struct llama_hparams { uint32_t n_embd; uint32_t n_embd_features = 0; uint32_t n_layer; + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head @@ -69,10 +71,13 @@ struct llama_hparams { uint32_t n_lora_kv = 0; uint32_t n_ff_exp = 0; uint32_t n_ff_shexp = 0; + uint32_t n_ff_chexp = 0; uint32_t n_expert_shared = 0; uint32_t n_norm_groups = 0; + uint32_t n_group_experts = 0; - float expert_weights_scale = 0.0; + float expert_group_scale = 0.05f; + float expert_weights_scale = 0.0f; bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; @@ -82,8 +87,9 @@ struct llama_hparams { float f_norm_rms_eps; float f_norm_group_eps; - float f_attn_logit_softcapping = 50.0f; - float f_final_logit_softcapping = 30.0f; + float f_attn_logit_softcapping = 50.0f; + float f_router_logit_softcapping = 30.0f; + float f_final_logit_softcapping = 30.0f; // for RWKV uint32_t rescale_every_n_layers = 0; @@ -104,6 +110,11 @@ struct llama_hparams { uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; + float yarn_ext_factor = -1.0f; + float yarn_attn_factor = 1.0f; + float yarn_beta_fast = 32.0f; + float yarn_beta_slow = 1.0f; + std::array rope_sections; // Sliding Window Attention (SWA) @@ -136,10 +147,14 @@ struct llama_hparams { float f_embedding_scale = 0.0f; float f_attention_scale = 0.0f; + // grok-2 + float f_attn_out_scale = 0.0f; + uint32_t attn_temp_length = 0; + bool causal_attn = true; bool use_alibi = false; bool attn_soft_cap = false; - bool use_kq_norm = true; + bool use_kq_norm = false; // for Classifiers uint32_t n_cls_out = 1; @@ -156,9 +171,22 @@ struct llama_hparams { uint32_t laurel_rank = 64; uint32_t n_embd_altup = 256; + // needed for sentence-transformers dense layers + uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense + uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense + uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense + uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense + + // xIELU + std::array xielu_alpha_n; + std::array xielu_alpha_p; + std::array xielu_beta; + std::array xielu_eps; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; + uint32_t dec_n_layer = 0; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -226,6 +254,16 @@ struct llama_hparams { bool n_bskcn(uint32_t n, uint32_t il) const; bool is_swa(uint32_t il) const; + + bool has_kv(uint32_t il) const; + + // number of layers for which has_kv() returns true + uint32_t n_layer_kv() const; + + // note that this function uses different SWA parameters from those in the hparams + // TODO: think of a better place for this function + // TODO: pack the SWA params in a struct? + static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1); }; static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); diff --git a/llama/llama.cpp/src/llama-impl.h b/llama/llama.cpp/src/llama-impl.h index 02b1d07f..c5163e92 100644 --- a/llama/llama.cpp/src/llama-impl.h +++ b/llama/llama.cpp/src/llama-impl.h @@ -59,3 +59,5 @@ std::string llama_format_tensor_shape(const std::vector & ne); std::string llama_format_tensor_shape(const struct ggml_tensor * t); std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); + +#define LLAMA_TENSOR_NAME_FATTN "__fattn__" diff --git a/llama/llama.cpp/src/llama-kv-cache-unified-iswa.cpp b/llama/llama.cpp/src/llama-kv-cache-iswa.cpp similarity index 56% rename from llama/llama.cpp/src/llama-kv-cache-unified-iswa.cpp rename to llama/llama.cpp/src/llama-kv-cache-iswa.cpp index 01d27fb4..facba1d0 100644 --- a/llama/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +++ b/llama/llama.cpp/src/llama-kv-cache-iswa.cpp @@ -1,4 +1,4 @@ -#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-iswa.h" #include "llama-impl.h" #include "llama-batch.h" @@ -8,10 +8,10 @@ #include // -// llama_kv_cache_unified_iswa +// llama_kv_cache_iswa // -llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( +llama_kv_cache_iswa::llama_kv_cache_iswa( const llama_model & model, ggml_type type_k, ggml_type type_v, @@ -22,9 +22,26 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, - uint32_t n_pad) : hparams(model.hparams), unified(unified) { - llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; - llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; + uint32_t n_pad, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) { + + // chain filters + const layer_filter_cb filter_base = [&](int32_t il) { + if (filter && !filter(il)) { + return false; + } + + return !model.hparams.is_swa(il); + }; + + const layer_filter_cb filter_swa = [&](int32_t il) { + if (filter && !filter(il)) { + return false; + } + + return model.hparams.is_swa(il); + }; const uint32_t size_base = kv_size; @@ -40,25 +57,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); - kv_base = std::make_unique( - model, std::move(filter_base), type_k, type_v, + kv_base = std::make_unique( + model, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE); + 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); - kv_swa = std::make_unique( - model, std::move(filter_swa), type_k, type_v, + kv_swa = std::make_unique( + model, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type); + hparams.n_swa, hparams.swa_type, filter_swa, reuse); } -void llama_kv_cache_unified_iswa::clear(bool data) { +void llama_kv_cache_iswa::clear(bool data) { kv_base->clear(data); kv_swa ->clear(data); } -bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { +bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { bool res = true; res = res & kv_base->seq_rm(seq_id, p0, p1); @@ -67,36 +84,44 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llam return res; } -void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); } -void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { +void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) { kv_base->seq_keep(seq_id); kv_swa ->seq_keep(seq_id); } -void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { +void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { kv_base->seq_add(seq_id, p0, p1, shift); kv_swa ->seq_add(seq_id, p0, p1, shift); } -void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { kv_base->seq_div(seq_id, p0, p1, d); kv_swa ->seq_div(seq_id, p0, p1, d); } -llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { +llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const { // the base cache is a superset of the SWA cache, so we can just check the SWA cache return kv_swa->seq_pos_min(seq_id); } -llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { +llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const { return kv_swa->seq_pos_max(seq_id); } -llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { +std::map llama_kv_cache_iswa::memory_breakdown() const { + std::map mb = kv_base->memory_breakdown(); + for (const auto & buft_size : kv_swa->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { GGML_UNUSED(embd_all); // first try simple split @@ -136,7 +161,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all assert(sinfos_base.size() == sinfos_swa.size()); - return std::make_unique( + return std::make_unique( this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); } while (false); @@ -172,61 +197,67 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all assert(sinfos_base.size() == sinfos_swa.size()); - return std::make_unique( + return std::make_unique( this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); } while (false); // TODO: if we fail again, we should attempt different splitting strategies // but to do that properly, we first have to refactor the batches to be more flexible - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } -llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_kv_cache_iswa::init_full() { + return std::make_unique(this); } -llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { - return std::make_unique(this, lctx, optimize); +llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); } -bool llama_kv_cache_unified_iswa::get_can_shift() const { +bool llama_kv_cache_iswa::get_can_shift() const { return kv_base->get_size() == kv_swa->get_size(); } -void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - kv_base->state_write(io, seq_id); - kv_swa ->state_write(io, seq_id); +void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + kv_base->state_write(io, seq_id, flags); + } + + kv_swa->state_write(io, seq_id, flags); } -void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - kv_base->state_read(io, seq_id); - kv_swa ->state_read(io, seq_id); +void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + kv_base->state_read(io, seq_id, flags); + } + + kv_swa->state_read(io, seq_id, flags); } -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { +llama_kv_cache * llama_kv_cache_iswa::get_base() const { return kv_base.get(); } -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { +llama_kv_cache * llama_kv_cache_iswa::get_swa() const { return kv_swa.get(); } // -// llama_kv_cache_unified_iswa_context +// llama_kv_cache_iswa_context // -llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {} +llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {} -llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv) : +llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv) : ctx_base(kv->get_base()->init_full()), ctx_swa (kv->get_swa ()->init_full()), status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv, +llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv, llama_context * lctx, bool optimize) : ctx_base(kv->get_base()->init_update(lctx, optimize)), @@ -234,21 +265,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv, +llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv, slot_info_vec_t sinfos_base, slot_info_vec_t sinfos_swa, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)), - ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)), + ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)), + ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)), status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default; +llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default; -bool llama_kv_cache_unified_iswa_context::next() { +bool llama_kv_cache_iswa_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); ctx_base->next(); @@ -261,7 +292,7 @@ bool llama_kv_cache_unified_iswa_context::next() { return true; } -bool llama_kv_cache_unified_iswa_context::apply() { +bool llama_kv_cache_iswa_context::apply() { assert(!llama_memory_status_is_fail(status)); bool res = true; @@ -272,24 +303,24 @@ bool llama_kv_cache_unified_iswa_context::apply() { return res; } -llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const { +llama_memory_status llama_kv_cache_iswa_context::get_status() const { return status; } -const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const { +const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } -const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const { +const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return static_cast(ctx_base.get()); + return static_cast(ctx_base.get()); } -const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const { +const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return static_cast(ctx_swa.get()); + return static_cast(ctx_swa.get()); } diff --git a/llama/llama.cpp/src/llama-kv-cache-unified-iswa.h b/llama/llama.cpp/src/llama-kv-cache-iswa.h similarity index 66% rename from llama/llama.cpp/src/llama-kv-cache-unified-iswa.h rename to llama/llama.cpp/src/llama-kv-cache-iswa.h index d2650dad..70ab22f0 100644 --- a/llama/llama.cpp/src/llama-kv-cache-unified-iswa.h +++ b/llama/llama.cpp/src/llama-kv-cache-iswa.h @@ -1,19 +1,19 @@ #pragma once -#include "llama-kv-cache-unified.h" +#include "llama-kv-cache.h" #include // -// llama_kv_cache_unified_iswa +// llama_kv_cache_iswa // -// utilizes two instances of llama_kv_cache_unified +// utilizes two instances of llama_kv_cache // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers -class llama_kv_cache_unified_iswa : public llama_memory_i { +class llama_kv_cache_iswa : public llama_memory_i { public: - llama_kv_cache_unified_iswa( + llama_kv_cache_iswa( const llama_model & model, ggml_type type_k, ggml_type type_v, @@ -24,9 +24,11 @@ public: uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, - uint32_t n_pad); + uint32_t n_pad, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); - ~llama_kv_cache_unified_iswa() = default; + ~llama_kv_cache_iswa() = default; // // llama_memory_i @@ -54,52 +56,54 @@ public: llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override; + std::map memory_breakdown() const override; + // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; // - // llama_kv_cache_unified_iswa specific API + // llama_kv_cache_iswa specific API // - llama_kv_cache_unified * get_base() const; - llama_kv_cache_unified * get_swa () const; + llama_kv_cache * get_base() const; + llama_kv_cache * get_swa () const; private: const llama_hparams & hparams; const bool unified; - std::unique_ptr kv_base; - std::unique_ptr kv_swa; + std::unique_ptr kv_base; + std::unique_ptr kv_swa; }; -class llama_kv_cache_unified_iswa_context : public llama_memory_context_i { +class llama_kv_cache_iswa_context : public llama_memory_context_i { public: - using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; // used for errors - llama_kv_cache_unified_iswa_context(llama_memory_status status); + llama_kv_cache_iswa_context(llama_memory_status status); // used to create a full-cache context - llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv); + llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv); // used to create an update context - llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv, + llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv, llama_context * lctx, bool optimize); // used to create a batch processing context from a batch - llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv, + llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv, slot_info_vec_t sinfos_base, slot_info_vec_t sinfos_swa, std::vector ubatches); - virtual ~llama_kv_cache_unified_iswa_context(); + virtual ~llama_kv_cache_iswa_context(); // // llama_memory_context_i @@ -112,14 +116,14 @@ public: const llama_ubatch & get_ubatch() const override; // - // llama_kv_cache_unified_iswa_context specific API + // llama_kv_cache_iswa_context specific API // - const llama_kv_cache_unified_context * get_base() const; - const llama_kv_cache_unified_context * get_swa() const; + const llama_kv_cache_context * get_base() const; + const llama_kv_cache_context * get_swa() const; private: - //llama_kv_cache_unified_iswa * kv; + //llama_kv_cache_iswa * kv; // the index of the next ubatch to process size_t i_next = 0; diff --git a/llama/llama.cpp/src/llama-kv-cache-unified.cpp b/llama/llama.cpp/src/llama-kv-cache.cpp similarity index 66% rename from llama/llama.cpp/src/llama-kv-cache-unified.cpp rename to llama/llama.cpp/src/llama-kv-cache.cpp index e539142e..736693e1 100644 --- a/llama/llama.cpp/src/llama-kv-cache-unified.cpp +++ b/llama/llama.cpp/src/llama-kv-cache.cpp @@ -1,4 +1,4 @@ -#include "llama-kv-cache-unified.h" +#include "llama-kv-cache.h" #include "llama-impl.h" #include "llama-io.h" @@ -13,36 +13,29 @@ #include // -// llama_kv_cache_unified +// llama_kv_cache // -llama_kv_cache_unified::llama_kv_cache_unified( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - bool unified, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type) : +llama_kv_cache::llama_kv_cache( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) : model(model), hparams(model.hparams), v_trans(v_trans), n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { GGML_ASSERT(kv_size % n_pad == 0); - // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE] - auto n_layer_cache = hparams.n_layer; - if (model.arch == LLM_ARCH_GEMMA3N) { - n_layer_cache = 20; - } - if (model.arch == LLM_ARCH_GLM4_MOE) { - // GLM-4.5: Only process up to last layer, skip final NextN layer - n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers; - } + const uint32_t n_layer_kv = hparams.n_layer_kv(); // create a context for each buffer type std::map ctx_map; @@ -50,7 +43,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -97,9 +90,14 @@ llama_kv_cache_unified::llama_kv_cache_unified( __func__, hparams.n_embd_v_gqa_max()); } - for (uint32_t il = 0; il < n_layer_cache; il++) { + for (uint32_t il = 0; il < hparams.n_layer; il++) { + if (!hparams.has_kv(il)) { + LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); + continue; + } + if (filter && !filter(il)) { - LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il); continue; } @@ -125,11 +123,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); @@ -147,23 +142,27 @@ llama_kv_cache_unified::llama_kv_cache_unified( layers.push_back({ il, k, v, k_stream, v_stream, }); } - // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE] - if (model.arch == LLM_ARCH_GEMMA3N) { - LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1); + if (reuse) { + LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__); - for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) { - if (filter && !filter(il)) { - LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + for (uint32_t il = 0; il < hparams.n_layer; il++) { + const int32_t il_reuse = reuse(il); + + if (il_reuse < 0) { + LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il); continue; } - const bool is_swa = hparams.is_swa(il); - const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1); + if (filter && !filter(il)) { + LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il); + continue; + } GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end()); + map_layer_ids[il] = map_layer_ids[il_reuse]; - LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa); + LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il)); } } @@ -195,21 +194,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; - - const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows; - - if (!supports_set_rows) { - // ref: https://github.com/ggml-org/llama.cpp/pull/14363 - GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support"); - } - - if (!supports_set_rows) { - LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__); - } } -void llama_kv_cache_unified::clear(bool data) { +void llama_kv_cache::clear(bool data) { for (uint32_t s = 0; s < n_stream; ++s) { v_cells[s].reset(); v_heads[s] = 0; @@ -222,13 +209,8 @@ void llama_kv_cache_unified::clear(bool data) { } } -bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); - - auto & cells = v_cells[seq_to_stream[seq_id]]; - auto & head = v_heads[seq_to_stream[seq_id]]; - - uint32_t new_head = cells.size(); +bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); if (p0 < 0) { p0 = 0; @@ -239,6 +221,11 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } if (seq_id >= 0) { + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + + uint32_t new_head = cells.size(); + for (uint32_t i = 0; i < cells.size(); ++i) { if (!cells.pos_in(i, p0, p1)) { continue; @@ -250,30 +237,42 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } } } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } } else { // match any sequence - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; + for (uint32_t s = 0; s < n_stream; ++s) { + auto & cells = v_cells[s]; + auto & head = v_heads[s]; + + uint32_t new_head = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + cells.rm(i); + + if (new_head == cells.size()) { + new_head = i; + } } - cells.rm(i); - - if (new_head == cells.size()) { - new_head = i; + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; } } } - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cells.size() && new_head < head) { - head = new_head; - } - return true; } -void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); @@ -356,7 +355,7 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id //} } -void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { +void llama_kv_cache::seq_keep(llama_seq_id seq_id) { GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -378,7 +377,7 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { } } -void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { +void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -422,7 +421,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po head = new_head != cells.size() ? new_head : 0; } -void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -455,7 +454,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po } } -llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { +llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -463,7 +462,7 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { return cells.seq_pos_min(seq_id); } -llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { +llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -471,7 +470,15 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return cells.seq_pos_max(seq_id); } -llama_memory_context_ptr llama_kv_cache_unified::init_batch( +std::map llama_kv_cache::memory_breakdown() const { + std::map ret; + for (const ggml_backend_buffer_ptr & buf_ptr : bufs) { + ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + } + return ret; +} + +llama_memory_context_ptr llama_kv_cache::init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { @@ -501,62 +508,34 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( break; } - return std::make_unique( + return std::make_unique( this, std::move(sinfos), std::move(ubatches)); } while (false); - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } -llama_memory_context_ptr llama_kv_cache_unified::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_kv_cache::init_full() { + return std::make_unique(this); } -llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { +llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) { + GGML_UNUSED(optimize); + bool do_shift = get_has_shift(); - defrag_info dinfo; - - // see if we need to defrag - if (n_stream == 1) { - // note : for now do not consider defrag for n_stream > 1 - const auto & cells = v_cells[seq_to_stream[0]]; - - bool do_defrag = optimize; - - const auto thold = lctx->get_cparams().defrag_thold; - - if (!do_defrag && thold > 0.0f) { - const auto n_kv = cells.used_max_p1(); - - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; - - if (fragmentation > thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - do_defrag = true; - } - } - - if (do_defrag) { - dinfo = defrag_prepare(lctx->graph_max_nodes()); - } - } - - return std::make_unique(this, lctx, do_shift, std::move(dinfo), std::move(sc_info)); + return std::make_unique(this, lctx, do_shift, std::move(sc_info)); } -llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector & ubatches) { - llama_kv_cache_unified::slot_info_vec_t res; +llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector & ubatches) { + llama_kv_cache::slot_info_vec_t res; struct state_t { slot_info sinfo; // slot info for the ubatch std::vector v_heads_old; // old positions of the heads, before placing the ubatch - std::vector v_cells; // copy of the old cells, before placing the ubatch + std::vector v_cells; // copy of the old cells, before placing the ubatch }; // remember the old state of the cells so we can restore it in the end @@ -565,11 +544,8 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st bool success = true; for (const auto & ubatch : ubatches) { - // non-continuous slots require support for ggml_set_rows() - const bool cont = supports_set_rows ? false : true; - // only find a suitable slot for the ubatch. don't modify the cells yet - const auto sinfo_new = find_slot(ubatch, cont); + const auto sinfo_new = find_slot(ubatch, false); if (sinfo_new.empty()) { success = false; break; @@ -617,7 +593,7 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st return res; } -bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) { +bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) { bool updated = false; auto * sched = lctx->get_sched(); @@ -687,117 +663,74 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d } } - if (!dinfo.empty()) { - LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - - // note: for now do not consider defrag for n_stream > 1 - auto & cells = v_cells[seq_to_stream[0]]; - auto & head = v_heads[seq_to_stream[0]]; - - // apply moves: - { - const auto n_kv = dinfo.ids.size(); - - for (uint32_t i = 0; i < n_kv; ++i) { - assert(dinfo.ids[i] <= n_kv); - - if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) { - continue; - } - - cells.mv(i, dinfo.ids[i]); - } - - // reset the head so we can find the first free slot during the next ubatch - head = 0; - } - - ggml_backend_sched_reset(sched); - - auto * res = lctx->get_gf_res_reserve(); - - res->reset(); - - auto * gf = build_graph_defrag(res, lctx, dinfo); - if (!ggml_backend_sched_alloc_graph(sched, gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); - return updated; - } - - res->set_inputs(nullptr); - - if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) { - LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); - return updated; - } - - updated = true; - } - return updated; } -llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { +llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const { + if (debug > 0) { - const auto & cells = v_cells[seq_to_stream[1]]; + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const auto seq_id = ubatch.seq_id_unq[s]; + const auto stream_id = seq_to_stream[seq_id]; + const auto & cells = v_cells[stream_id]; + const uint32_t head_cur = v_heads[stream_id]; - const uint32_t head_cur = v_heads[1]; + LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", + __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa); - LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", - __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa); - - if ((debug == 2 && n_swa > 0) || debug > 2) { - std::string ss; - for (uint32_t i = 0; i < cells.size(); ++i) { - if (cells.is_empty(i)) { - ss += '.'; - } else { - assert(cells.seq_count(i) >= 1); - - if (cells.seq_count(i) == 1) { - ss += std::to_string(cells.seq_get(i)); + if ((debug == 2 && n_swa > 0) || debug > 2) { + std::string ss; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.is_empty(i)) { + ss += '.'; } else { - ss += 'M'; + assert(cells.seq_count(i) >= 1); + + if (cells.seq_count(i) == 1) { + ss += std::to_string(cells.seq_get(i)); + } else { + ss += 'M'; + } + } + if (i%256 == 255) { + ss += " *"; + ss += '\n'; } } - if (i%256 == 255) { - ss += " *"; - ss += '\n'; - } - } - LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); - } - - if ((debug == 2 && n_swa > 0) || debug > 2) { - std::string ss; - for (uint32_t i = 0; i < cells.size(); ++i) { - std::string cur; - if (cells.is_empty(i)) { - cur = '.'; - } else { - cur = std::to_string(cells.pos_get(i)); - } - const int n = cur.size(); - for (int j = 0; j < 5 - n; ++j) { - cur += ' '; - } - ss += cur; - if (i%256 == 255) { - ss += " *"; - } - if (i%64 == 63) { - ss += '\n'; - } - } - LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); - } - - for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { - if (cells.seq_pos_min(s) < 0) { - continue; + LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); } - LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); + if ((debug == 2 && n_swa > 0) || debug > 2) { + std::string ss; + for (uint32_t i = 0; i < cells.size(); ++i) { + std::string cur; + if (cells.is_empty(i)) { + cur = '.'; + } else { + cur = std::to_string(cells.pos_get(i)); + } + const int n = cur.size(); + for (int j = 0; j < 5 - n; ++j) { + cur += ' '; + } + ss += cur; + if (i%256 == 255) { + ss += " *"; + } + if (i%64 == 63) { + ss += '\n'; + } + } + LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); + } + + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (cells.seq_pos_min(s) < 0) { + continue; + } + + LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); + } } } @@ -828,8 +761,8 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id); } - res.s0 = std::min(res.s0, seq_to_stream[seq_id]); - res.s1 = std::max(res.s1, seq_to_stream[seq_id]); + res.s0 = std::min(res.s0, seq_to_stream[seq_id]); + res.s1 = std::max(res.s1, seq_to_stream[seq_id]); res.strm[s] = seq_to_stream[seq_id]; res.idxs[s].reserve(n_tokens); @@ -932,7 +865,7 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ return res; } -void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { +void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -997,21 +930,21 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u } } -bool llama_kv_cache_unified::get_can_shift() const { +bool llama_kv_cache::get_can_shift() const { return true; } -uint32_t llama_kv_cache_unified::get_size() const { +uint32_t llama_kv_cache::get_size() const { const auto & cells = v_cells[seq_to_stream[0]]; return cells.size(); } -uint32_t llama_kv_cache_unified::get_n_stream() const { +uint32_t llama_kv_cache::get_n_stream() const { return n_stream; } -bool llama_kv_cache_unified::get_has_shift() const { +bool llama_kv_cache::get_has_shift() const { bool result = false; for (uint32_t s = 0; s < n_stream; ++s) { @@ -1021,11 +954,11 @@ bool llama_kv_cache_unified::get_has_shift() const { return result; } -uint32_t llama_kv_cache_unified::get_n_kv() const { +uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; - for (uint32_t s = 0; s < n_stream; ++s) { - const auto & cells = v_cells[s]; + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const auto & cells = v_cells[sinfo.strm[s]]; result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result); } @@ -1033,11 +966,7 @@ uint32_t llama_kv_cache_unified::get_n_kv() const { return result; } -bool llama_kv_cache_unified::get_supports_set_rows() const { - return supports_set_rows; -} - -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { +ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; @@ -1057,7 +986,7 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); } -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { +ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; @@ -1074,106 +1003,113 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint // note: v->nb[1] <= v->nb[2] return ggml_view_4d(ctx, v, hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] - ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] + ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); } // note: v->nb[1] > v->nb[2] return ggml_view_4d(ctx, v, n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, - ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, kv_size), // v->nb[2] - ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, kv_size), // v->nb[2] + ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); } -ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { +ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { + GGML_UNUSED(sinfo); + const int32_t ikv = map_layer_ids.at(il); - auto * k = layers[ikv].k; + ggml_tensor * k = layers[ikv].k; - const int64_t n_embd_k_gqa = k->ne[0]; - const int64_t n_tokens = k_cur->ne[2]; + const int64_t n_embd_head = k_cur->ne[0]; + const int64_t n_head = k_cur->ne[1]; + const int64_t n_tokens = k_cur->ne[2]; - k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); + const int64_t n_embd_gqa = n_embd_head*n_head; - if (k_idxs && supports_set_rows) { - if (k->ne[2] > 1) { - k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); - } + // we can merge dims 0 and 1 + // TODO: add ggml helper function for this? + GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]); - return ggml_set_rows(ctx, k, k_cur, k_idxs); + k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0); + + const int64_t n_stream = k->ne[2]; + + if (n_stream > 1) { + const int64_t kv_size = get_size(); + + assert(n_embd_gqa == k->ne[0]); + assert(kv_size == k->ne[1]); + + // merge the buffer across all streams because the idxs are global + k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream); } - // TODO: fallback to old ggml_cpy() method for backwards compatibility - // will be removed when ggml_set_rows() is adopted by all backends - - GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); - - ggml_tensor * k_view = ggml_view_1d(ctx, k, - n_tokens*n_embd_k_gqa, - ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head()); - - return ggml_cpy(ctx, k_cur, k_view); + // store the current K values into the cache + return ggml_set_rows(ctx, k, k_cur, k_idxs); } -ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const { +ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const { + GGML_UNUSED(sinfo); + const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; - const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1]; - const int64_t n_tokens = v_cur->ne[2]; + const int64_t n_embd_head = v_cur->ne[0]; + const int64_t n_head = v_cur->ne[1]; + const int64_t n_tokens = v_cur->ne[2]; - v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); + const int64_t n_embd_gqa = n_embd_head*n_head; - if (v_idxs && supports_set_rows) { - if (!v_trans) { - if (v->ne[2] > 1) { - v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); - } + // we can merge dims 0 and 1 + GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]); - return ggml_set_rows(ctx, v, v_cur, v_idxs); - } - - // [TAG_V_CACHE_VARIABLE] - if (n_embd_v_gqa < v->ne[0]) { - v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); - } - - // the row becomes a single element - ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); - - v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); - - return ggml_set_rows(ctx, v_view, v_cur, v_idxs); - } - - // TODO: fallback to old ggml_cpy() method for backwards compatibility - // will be removed when ggml_set_rows() is adopted by all backends - - GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); - - ggml_tensor * v_view = nullptr; + const int64_t n_stream = v->ne[2]; + // take this branch when FA is enabled (the V cache is not transposed) if (!v_trans) { - v_view = ggml_view_1d(ctx, v, - n_tokens*n_embd_v_gqa, - ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head()); - } else { - v_cur = ggml_transpose(ctx, v_cur); + v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0); - v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa, - (v->ne[1] )*ggml_element_size(v), - (sinfo.head())*ggml_element_size(v)); + if (n_stream > 1) { + const int64_t kv_size = get_size(); + + assert(n_embd_gqa == v->ne[0]); + assert(kv_size == v->ne[1]); + + // merge the buffer across all streams because the idxs are global + v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream); + } + + return ggml_set_rows(ctx, v, v_cur, v_idxs); } - return ggml_cpy(ctx, v_cur, v_view); + if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) { + // we can merge dims 0, 1 and 2 + v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens); + } else { + // otherwise -> make a copy to get contiguous data + v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens); + } + + // [TAG_V_CACHE_VARIABLE] + if (n_embd_gqa < v->ne[0]) { + v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0); + } + + // in this branch the v_idxs are constructed in such a way that each row is a single head element + ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v)); + + v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur)); + + return ggml_set_rows(ctx, v_view, v_cur, v_idxs); } -ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { +ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); @@ -1183,7 +1119,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con return k_idxs; } -ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { +ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; ggml_tensor * v_idxs; @@ -1199,11 +1135,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, con return v_idxs; } -void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { - if (!supports_set_rows) { - return; - } - +void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1219,11 +1151,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba } } -void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { - if (!supports_set_rows) { - return; - } - +void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1256,7 +1184,7 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba } } -void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { +void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const { GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int32_t * data = (int32_t *) dst->data; @@ -1270,7 +1198,7 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { } } -void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { +void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); @@ -1342,7 +1270,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub } } -void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { +void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { const int64_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams"); @@ -1367,7 +1295,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama } } -size_t llama_kv_cache_unified::total_size() const { +size_t llama_kv_cache::total_size() const { size_t size = 0; for (const auto & buf : bufs) { @@ -1377,7 +1305,7 @@ size_t llama_kv_cache_unified::total_size() const { return size; } -size_t llama_kv_cache_unified::size_k_bytes() const { +size_t llama_kv_cache::size_k_bytes() const { size_t size_k_bytes = 0; for (const auto & layer : layers) { @@ -1387,7 +1315,7 @@ size_t llama_kv_cache_unified::size_k_bytes() const { return size_k_bytes; } -size_t llama_kv_cache_unified::size_v_bytes() const { +size_t llama_kv_cache::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & layer : layers) { @@ -1397,7 +1325,7 @@ size_t llama_kv_cache_unified::size_v_bytes() const { return size_v_bytes; } -ggml_tensor * llama_kv_cache_unified::build_rope_shift( +ggml_tensor * llama_kv_cache::build_rope_shift( const llama_cparams & cparams, ggml_context * ctx, ggml_tensor * cur, @@ -1449,14 +1377,14 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift( class llm_graph_input_k_shift : public llm_graph_input_i { public: - llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {} virtual ~llm_graph_input_k_shift() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * k_shift; // I32 [kv_size*n_stream] - const llama_kv_cache_unified * kv_self; + const llama_kv_cache * kv_self; }; void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { @@ -1467,7 +1395,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } } -ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { +ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { auto * ctx = res->get_ctx(); auto * gf = res->get_gf(); @@ -1509,310 +1437,13 @@ ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, return gf; } -ggml_cgraph * llama_kv_cache_unified::build_graph_defrag( - llm_graph_result * res, - llama_context * lctx, - const defrag_info & dinfo) const { - auto * ctx = res->get_ctx(); - auto * gf = res->get_gf(); - - GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); - - const auto & cells = v_cells[0]; - - const auto & ids = dinfo.ids; - - const auto & cparams = lctx->get_cparams(); - -#if 0 - // CPU defrag - // - // TODO: optimizations are possible: - // - multiple threads - // - avoid copying to the host memory when already there - // - // likely not worth the effort, as we have ggml_graph based defrag - // - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - - const uint32_t kv_size = size; - - std::vector buf_k; - std::vector buf_v; - - for (uint32_t il = 0; il < n_layer; ++il) { - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); - - const size_t v_size_el = ggml_type_size(v_l[il]->type); - const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); - - buf_k.resize(k_size); - buf_v.resize(v_size); - - ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); - - // batch move [i, i+nm) to [id, id+nm) - // note: cells can move only to a lower index - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == n_kv) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < n_kv && ids[i + nm] == id + nm) { - nm++; - } - - // move keys - { - const int64_t os = i*k_size_row; - const int64_t od = id*k_size_row; - - memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); - } - - // move values (note: they are transposed) - { - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); - } - } - - i += nm - 1; - } - - ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); - } -#else - for (uint32_t i = 0; i < ids.size(); ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == ids.size()) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < ids.size() && ids[i + nm] == id + nm) { - nm++; - } - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, - n_embd_k_gqa, nm, - ggml_row_size(layer.k->type, n_embd_k_gqa), - ggml_row_size(layer.k->type, n_embd_k_gqa*i)); - - ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, - n_embd_k_gqa, nm, - ggml_row_size(layer.k->type, n_embd_k_gqa), - ggml_row_size(layer.k->type, n_embd_k_gqa*id)); - - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (cparams.flash_attn) { - // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx, layer.v, - n_embd_v_gqa, nm, - ggml_row_size(layer.v->type, n_embd_v_gqa), - ggml_row_size(layer.v->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx, layer.v, - n_embd_v_gqa, nm, - ggml_row_size(layer.v->type, n_embd_v_gqa), - ggml_row_size(layer.v->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx, layer.v, - nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, cells.size()), - ggml_row_size(layer.v->type, i)); - - view_v_dst = ggml_view_2d(ctx, layer.v, - nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, cells.size()), - ggml_row_size(layer.v->type, id)); - } - - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); - } - - i += nm - 1; - } - - //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); -#endif - - return gf; +bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const { + return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1); } -llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const { - GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); +void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + GGML_UNUSED(flags); - const auto & cells = v_cells[0]; - - const uint32_t n_layer = layers.size(); - - const uint32_t n_kv = cells.used_max_p1(); - const uint32_t n_used = cells.get_used(); - - assert(n_used <= n_kv); - - //const int64_t t_start = ggml_time_us(); - - // number of cells moved - uint32_t n_moves = 0; - - // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) - // - source view, destination view, copy operation - // - x2 for keys and values - //const uint32_t max_moves = max_nodes()/(6*n_layer); - // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); - - // determine which KV cells to move where - defrag_info res; - auto & ids = res.ids; - - ids.resize(n_kv, n_kv); - - for (uint32_t i0 = 0; i0 < n_used; ++i0) { - if (!cells.is_empty(i0)) { - ids[i0] = i0; - - continue; - } - - // found a hole - fill it with data from the end of the cache - - uint32_t nh = 1; - - // determine the size of the hole - while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { - nh++; - } - - uint32_t nf = 0; - uint32_t is = n_kv - 1; - - // starting from the end, find nh non-empty cells - for (; is > i0; --is) { - if (cells.is_empty(is) || ids[is] != n_kv) { - continue; - } - - // non-empty cell which is not yet moved - nf++; - - if (nf == nh) { - break; - } - } - - // this can only happen if `n_used` is not accurate, which would be a bug - GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); - - nf = 0; - - uint32_t i1 = is; - - // are we moving a continuous block of memory? - bool cont = false; - - // should we stop searching for the next move? - bool stop = false; - - // go back and move the nf cells to the hole - for (; i1 < n_kv; ++i1) { - if (cells.is_empty(i1) || ids[i1] != n_kv) { - if (n_moves == max_moves) { - stop = true; - break; - } - - cont = false; - continue; - } - - // this cell goes to (i0 + nf) - ids[i1] = i0 + nf; - - if (!cont) { - n_moves++; - cont = true; - } - - nf++; - - if (nf == nh) { - break; - } - } - - if (stop || n_moves == max_moves) { - break; - } - - //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); - - i0 += nh - 1; - } - - if (n_moves == 0) { - return {}; - } - - LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); - - LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); - - return res; -} - -bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - } - - return false; -} - -void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { io.write(&n_stream, sizeof(n_stream)); for (uint32_t s = 0; s < n_stream; ++s) { @@ -1863,7 +1494,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq } } -void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { +void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + GGML_UNUSED(flags); + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); uint32_t n_stream_cur; @@ -1897,7 +1530,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i } } -void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const { +void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const { const auto & cells = v_cells[cr.strm]; for (const auto & range : cr.data) { @@ -1925,7 +1558,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ } } -void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const { +void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const { const auto & cells = v_cells[cr.strm]; const uint32_t v_trans = this->v_trans ? 1 : 0; @@ -2020,7 +1653,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ } } -bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { +bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { auto & cells = v_cells[strm]; auto & head = v_heads[strm]; @@ -2117,7 +1750,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm return true; } -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { +bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { auto & cells = v_cells[strm]; auto & head = v_heads[strm]; @@ -2254,13 +1887,13 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm } // -// llama_kv_cache_unified_context +// llama_kv_cache_context // -llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {} +llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {} -llama_kv_cache_unified_context::llama_kv_cache_unified_context( - llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { +llama_kv_cache_context::llama_kv_cache_context( + llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { n_kv = kv->get_size(); const uint32_t n_stream = kv->get_n_stream(); @@ -2276,26 +1909,25 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( } } -llama_kv_cache_unified_context::llama_kv_cache_unified_context( - llama_kv_cache_unified * kv, +llama_kv_cache_context::llama_kv_cache_context( + llama_kv_cache * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo, - stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) { - if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) { + stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) { + if (!do_shift && this->sc_info.empty()) { status = LLAMA_MEMORY_STATUS_NO_UPDATE; } } -llama_kv_cache_unified_context::llama_kv_cache_unified_context( - llama_kv_cache_unified * kv, - llama_kv_cache_unified::slot_info_vec_t sinfos, +llama_kv_cache_context::llama_kv_cache_context( + llama_kv_cache * kv, + llama_kv_cache::slot_info_vec_t sinfos, std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) { } -llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; +llama_kv_cache_context::~llama_kv_cache_context() = default; -bool llama_kv_cache_unified_context::next() { +bool llama_kv_cache_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); if (++i_cur >= ubatches.size()) { @@ -2305,86 +1937,81 @@ bool llama_kv_cache_unified_context::next() { return true; } -bool llama_kv_cache_unified_context::apply() { +bool llama_kv_cache_context::apply() { assert(!llama_memory_status_is_fail(status)); // no ubatches -> this is a KV cache update if (ubatches.empty()) { - kv->update(lctx, do_shift, dinfo, sc_info); + kv->update(lctx, do_shift, sc_info); return true; } kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); - - n_kv = kv->get_n_kv(); + n_kv = kv->get_n_kv(sinfos[i_cur]); return true; } -llama_memory_status llama_kv_cache_unified_context::get_status() const { +llama_memory_status llama_kv_cache_context::get_status() const { return status; } -const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const { +const llama_ubatch & llama_kv_cache_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_cur]; } -uint32_t llama_kv_cache_unified_context::get_n_kv() const { +uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } -bool llama_kv_cache_unified_context::get_supports_set_rows() const { - return kv->get_supports_set_rows(); -} - -ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const { +ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } -ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const { +ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const { return kv->get_v(ctx, il, n_kv, sinfos[i_cur]); } -ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { +ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]); } -ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const { +ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const { return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]); } -ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { +ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { return kv->build_input_k_idxs(ctx, ubatch); } -ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { +ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { return kv->build_input_v_idxs(ctx, ubatch); } -void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const { +void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } -void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { +void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]); } -void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { +void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]); } -void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { +void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { kv->set_input_kq_mask(dst, ubatch, causal_attn); } -void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { +void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { +uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) { // the FA kernels require padding to avoid extra runtime boundary checks return cparams.flash_attn ? 256u : 32u; } diff --git a/llama/llama.cpp/src/llama-kv-cache-unified.h b/llama/llama.cpp/src/llama-kv-cache.h similarity index 78% rename from llama/llama.cpp/src/llama-kv-cache-unified.h rename to llama/llama.cpp/src/llama-kv-cache.h index 342a6759..85f0663d 100644 --- a/llama/llama.cpp/src/llama-kv-cache-unified.h +++ b/llama/llama.cpp/src/llama-kv-cache.h @@ -14,27 +14,13 @@ struct llama_model; struct llama_context; // -// llama_kv_cache_unified +// llama_kv_cache // -class llama_kv_cache_unified : public llama_memory_i { +class llama_kv_cache : public llama_memory_i { public: static uint32_t get_padding(const llama_cparams & cparams); - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - - struct defrag_info { - bool empty() const { - return ids.empty(); - } - - // contains information about which cell moves where: - // - cell i moves to ids[i] - // - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved - std::vector ids; - }; - struct stream_copy_info { bool empty() const { assert(ssrc.size() == sdst.size()); @@ -52,8 +38,8 @@ public: using idx_vec_t = std::vector; // number of streams: ns = s1 - s0 + 1 - llama_seq_id s0; - llama_seq_id s1; + uint32_t s0; + uint32_t s1; std::vector strm; // [ns] std::vector idxs; // [ns] @@ -92,21 +78,22 @@ public: using slot_info_vec_t = std::vector; - llama_kv_cache_unified( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - bool unified, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type); + llama_kv_cache( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); - ~llama_kv_cache_unified() = default; + ~llama_kv_cache() = default; // // llama_memory_i @@ -134,13 +121,15 @@ public: llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override; + std::map memory_breakdown() const override; + // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; // - // llama_kv_cache_unified specific API + // llama_kv_cache specific API // uint32_t get_size() const; @@ -152,10 +141,7 @@ public: // graph_build API // - uint32_t get_n_kv() const; - - // TODO: temporary - bool get_supports_set_rows() const; + uint32_t get_n_kv(const slot_info & sinfo) const; // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; @@ -173,7 +159,7 @@ public: // return empty vector on failure slot_info_vec_t prepare(const std::vector & ubatches); - bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info); + bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info); // find a slot of kv cells that can hold the ubatch // if cont == true, then the slot must be continuous @@ -228,10 +214,7 @@ private: // env: LLAMA_KV_CACHE_DEBUG int debug = 0; - // env: LLAMA_SET_ROWS (temporary) - // ref: https://github.com/ggml-org/llama.cpp/pull/14285 - bool supports_set_rows = true; - + // this is the SWA type of the cache - not to be confused with the model SWA type const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; std::vector ctxs; @@ -241,7 +224,7 @@ private: // note: this is not part of the KV state and it's only used to speed-up the find_slot() method std::vector v_heads; - std::vector v_cells; + std::vector v_cells; // maps from a sequence id to a stream id std::vector seq_to_stream; @@ -254,9 +237,6 @@ private: // model layer id -> KV cache layer id std::unordered_map map_layer_ids; - // return non-empty vector if cells have been moved - defrag_info defrag_prepare(int32_t n_max_nodes) const; - size_t total_size() const; size_t size_k_bytes() const; @@ -277,11 +257,6 @@ private: llm_graph_result * res, llama_context * lctx) const; - ggml_cgraph * build_graph_defrag( - llm_graph_result * res, - llama_context * lctx, - const defrag_info & dinfo) const; - struct cell_ranges_t { uint32_t strm; @@ -295,35 +270,33 @@ private: bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); }; -class llama_kv_cache_unified_context : public llama_memory_context_i { +class llama_kv_cache_context : public llama_memory_context_i { public: // some shorthands - using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; - using defrag_info = llama_kv_cache_unified::defrag_info; - using stream_copy_info = llama_kv_cache_unified::stream_copy_info; + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + using stream_copy_info = llama_kv_cache::stream_copy_info; // used for errors - llama_kv_cache_unified_context(llama_memory_status status); + llama_kv_cache_context(llama_memory_status status); // used to create a full-cache context - llama_kv_cache_unified_context( - llama_kv_cache_unified * kv); + llama_kv_cache_context( + llama_kv_cache * kv); // used to create an update context - llama_kv_cache_unified_context( - llama_kv_cache_unified * kv, + llama_kv_cache_context( + llama_kv_cache * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo, stream_copy_info sc_info); // used to create a batch procesing context from a batch - llama_kv_cache_unified_context( - llama_kv_cache_unified * kv, + llama_kv_cache_context( + llama_kv_cache * kv, slot_info_vec_t sinfos, std::vector ubatches); - virtual ~llama_kv_cache_unified_context(); + virtual ~llama_kv_cache_context(); // // llama_memory_context_i @@ -336,22 +309,27 @@ public: const llama_ubatch & get_ubatch() const override; // - // llama_kv_cache_unified_context specific API + // llama_kv_cache_context specific API // uint32_t get_n_kv() const; - // TODO: temporary - bool get_supports_set_rows() const; - // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location + // note: the heads in k_cur and v_cur should be layed out contiguously in memory + // - k_cur [n_embd_head_k, n_head_k, n_tokens] + // - k_idxs [n_tokens] + // - v_cur [n_embd_head_v, n_head_v, n_tokens] + // - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; + // create destination indices for each head of the current batch for where it would be written in the KV cache + // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but + // helps understand the implementation logic of cpy_k and cpy_v ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; @@ -365,7 +343,7 @@ public: private: llama_memory_status status; - llama_kv_cache_unified * kv; + llama_kv_cache * kv; llama_context * lctx; // @@ -374,8 +352,6 @@ private: bool do_shift = false; - defrag_info dinfo; - stream_copy_info sc_info; // diff --git a/llama/llama.cpp/src/llama-kv-cells.h b/llama/llama.cpp/src/llama-kv-cells.h index 0d0dd316..8f6bf014 100644 --- a/llama/llama.cpp/src/llama-kv-cells.h +++ b/llama/llama.cpp/src/llama-kv-cells.h @@ -11,7 +11,7 @@ // meta information about KV cells that can be part of multiple sequences at the same time // TODO: add unit tests -class llama_kv_cells_unified { +class llama_kv_cells { public: void reset() { for (uint32_t i = 0; i < pos.size(); ++i) { @@ -77,30 +77,30 @@ public: } // move cell isrc to idst (used during defrag) - void mv(uint32_t isrc, uint32_t idst) { - assert(isrc < pos.size()); - assert(idst < pos.size()); + //void mv(uint32_t isrc, uint32_t idst) { + // assert(isrc < pos.size()); + // assert(idst < pos.size()); - assert(pos[idst] == -1); - assert(pos[isrc] != -1); + // assert(pos[idst] == -1); + // assert(pos[isrc] != -1); - pos [idst] = pos [isrc]; - shift[idst] = shift[isrc]; - seq [idst] = seq [isrc]; + // pos [idst] = pos [isrc]; + // shift[idst] = shift[isrc]; + // seq [idst] = seq [isrc]; - pos [isrc] = -1; - shift[isrc] = 0; - seq [isrc].reset(); + // pos [isrc] = -1; + // shift[isrc] = 0; + // seq [isrc].reset(); - used.erase (isrc); - used.insert(idst); - } + // used.erase (isrc); + // used.insert(idst); + //} // copy the state of cells [i, i + n) (used for save/restore the state of the cells) - llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { + llama_kv_cells cp(uint32_t i, uint32_t n) const { assert(i + n <= pos.size()); - llama_kv_cells_unified res; + llama_kv_cells res; res.resize(n); @@ -117,8 +117,8 @@ public: } // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) - llama_kv_cells_unified cp(const std::vector & idxs) const { - llama_kv_cells_unified res; + llama_kv_cells cp(const std::vector & idxs) const { + llama_kv_cells res; res.resize(idxs.size()); @@ -135,7 +135,7 @@ public: } // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) - void set(uint32_t i, const llama_kv_cells_unified & other) { + void set(uint32_t i, const llama_kv_cells & other) { assert(i + other.pos.size() <= pos.size()); for (uint32_t j = 0; j < other.pos.size(); ++j) { @@ -165,7 +165,7 @@ public: } // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) - void set(const std::vector & idxs, const llama_kv_cells_unified & other) { + void set(const std::vector & idxs, const llama_kv_cells & other) { assert(idxs.size() == other.pos.size()); for (uint32_t j = 0; j < other.pos.size(); ++j) { diff --git a/llama/llama.cpp/src/llama-memory-hybrid.cpp b/llama/llama.cpp/src/llama-memory-hybrid.cpp index e98b4e35..dfb8439e 100644 --- a/llama/llama.cpp/src/llama-memory-hybrid.cpp +++ b/llama/llama.cpp/src/llama-memory-hybrid.cpp @@ -9,32 +9,29 @@ // llama_memory_hybrid::llama_memory_hybrid( - const llama_model & model, - /* attn */ - ggml_type type_k, - ggml_type type_v, - bool v_trans, - uint32_t kv_size, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type, - /* recurrent */ - ggml_type type_r, - ggml_type type_s, - uint32_t rs_size, - /* common */ - uint32_t n_seq_max, - bool offload, - bool unified, - /* layer filters */ - layer_filter_cb && filter_attn, - layer_filter_cb && filter_recr) : + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + uint32_t kv_size, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, + const layer_filter_cb & filter_recr) : hparams(model.hparams), - mem_attn(new llama_kv_cache_unified( + mem_attn(new llama_kv_cache( model, - filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } - : filter_attn, type_k, type_v, v_trans, @@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid( n_seq_max, n_pad, n_swa, - swa_type + swa_type, + filter_attn == nullptr ? + [&](int32_t il) { return !hparams.is_recurrent(il); } + : filter_attn, + nullptr )), mem_recr(new llama_memory_recurrent( model, - filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } - : filter_recr, type_r, type_s, offload, rs_size, - n_seq_max + n_seq_max, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recurrent(il); } + : filter_recr )) {} llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { @@ -72,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -165,17 +168,29 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const { return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); } -void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - mem_attn->state_write(io, seq_id); - mem_recr->state_write(io, seq_id); +std::map llama_memory_hybrid::memory_breakdown() const { + std::map mb = mem_attn->memory_breakdown(); + for (const auto & buft_size : mem_recr->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; } -void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - mem_attn->state_read(io, seq_id); - mem_recr->state_read(io, seq_id); +void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + mem_attn->state_write(io, seq_id, flags); + } + mem_recr->state_write(io, seq_id, flags); } -llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const { +void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + mem_attn->state_read(io, seq_id, flags); + } + mem_recr->state_read(io, seq_id, flags); +} + +llama_kv_cache * llama_memory_hybrid::get_mem_attn() const { return mem_attn.get(); } @@ -206,7 +221,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), + ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } @@ -244,8 +259,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const { return ubatches[i_next]; } -const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const { - return static_cast(ctx_attn.get()); +const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const { + return static_cast(ctx_attn.get()); } const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const { diff --git a/llama/llama.cpp/src/llama-memory-hybrid.h b/llama/llama.cpp/src/llama-memory-hybrid.h index c2d56cd5..558cafdf 100644 --- a/llama/llama.cpp/src/llama-memory-hybrid.h +++ b/llama/llama.cpp/src/llama-memory-hybrid.h @@ -2,7 +2,7 @@ #include "llama-batch.h" #include "llama-graph.h" -#include "llama-kv-cache-unified.h" +#include "llama-kv-cache.h" #include "llama-memory.h" #include "llama-memory-recurrent.h" @@ -13,36 +13,32 @@ // llama_memory_hybrid // -// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to +// utilizes instances of llama_memory_recurrent and llama_kv_cache to // support models where each layer may be either attention-based or recurrent class llama_memory_hybrid : public llama_memory_i { public: - - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - llama_memory_hybrid( const llama_model & model, /* attn */ - ggml_type type_k, - ggml_type type_v, - bool v_trans, - uint32_t kv_size, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type, - /* recurrent */ - ggml_type type_r, - ggml_type type_s, - uint32_t rs_size, - /* common */ - uint32_t n_seq_max, - bool offload, - bool unified, - /* layer filters */ - layer_filter_cb && filter_attn = nullptr, - layer_filter_cb && filter_recr = nullptr); + ggml_type type_k, + ggml_type type_v, + bool v_trans, + uint32_t kv_size, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, + const layer_filter_cb & filter_recr = nullptr); ~llama_memory_hybrid() = default; @@ -72,28 +68,30 @@ public: llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override; + std::map memory_breakdown() const override; + // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; // // llama_memory_hybrid specific API // - llama_kv_cache_unified * get_mem_attn() const; + llama_kv_cache * get_mem_attn() const; llama_memory_recurrent * get_mem_recr() const; private: const llama_hparams & hparams; - const std::unique_ptr mem_attn; + const std::unique_ptr mem_attn; const std::unique_ptr mem_recr; }; class llama_memory_hybrid_context : public llama_memory_context_i { public: - using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; // init failure explicit llama_memory_hybrid_context(llama_memory_status status); @@ -125,7 +123,7 @@ public: // llama_memory_hybrid_context // - const llama_kv_cache_unified_context * get_attn() const; + const llama_kv_cache_context * get_attn() const; const llama_memory_recurrent_context * get_recr() const; private: diff --git a/llama/llama.cpp/src/llama-memory-recurrent.cpp b/llama/llama.cpp/src/llama-memory-recurrent.cpp index c0c2ec08..d67f5a5f 100644 --- a/llama/llama.cpp/src/llama-memory-recurrent.cpp +++ b/llama/llama.cpp/src/llama-memory-recurrent.cpp @@ -16,13 +16,13 @@ // llama_memory_recurrent::llama_memory_recurrent( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_r, - ggml_type type_s, - bool offload, - uint32_t mem_size, - uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const llama_model & model, + ggml_type type_r, + ggml_type type_s, + bool offload, + uint32_t mem_size, + uint32_t n_seq_max, + const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; head = 0; @@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) { } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (tail_id >= 0) { const auto & cell = cells[tail_id]; // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) { + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); return false; } // invalidate tails which will be cleared @@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } else { // seq_id is negative, then the range should include everything or nothing if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n"); return false; } } @@ -359,6 +362,14 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +std::map llama_memory_recurrent::memory_breakdown() const { + std::map ret; + for (const ggml_backend_buffer_ptr & buf_ptr : bufs) { + ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + } + return ret; +} + llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { do { balloc.split_reset(); @@ -371,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -680,7 +693,9 @@ size_t llama_memory_recurrent::size_s_bytes() const { return size_s_bytes; } -void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { +void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + GGML_UNUSED(flags); + std::vector> cell_ranges; // ranges, from inclusive, to exclusive uint32_t cell_count = 0; @@ -718,7 +733,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq state_write_data(io, cell_ranges); } -void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { +void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + GGML_UNUSED(flags); + uint32_t cell_count; io.read_to(&cell_count, sizeof(cell_count)); @@ -844,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, -1); + if (cell_count == 0) { + return true; + } + llama_batch_allocr balloc(hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1); diff --git a/llama/llama.cpp/src/llama-memory-recurrent.h b/llama/llama.cpp/src/llama-memory-recurrent.h index 4d094f9a..077c6e3c 100644 --- a/llama/llama.cpp/src/llama-memory-recurrent.h +++ b/llama/llama.cpp/src/llama-memory-recurrent.h @@ -4,6 +4,7 @@ #include "llama-graph.h" #include "llama-memory.h" +#include #include #include @@ -12,21 +13,17 @@ // // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i -// see the implementation of llama_kv_cache_unified_context_i for an example how to do it +// see the implementation of llama_kv_cache_context_i for an example how to do it class llama_memory_recurrent : public llama_memory_i { public: - - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - llama_memory_recurrent( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_r, - ggml_type type_s, - bool offload, - uint32_t mem_size, - uint32_t n_seq_max); + const llama_model & model, + ggml_type type_r, + ggml_type type_s, + bool offload, + uint32_t mem_size, + uint32_t n_seq_max, + const layer_filter_cb & filter); ~llama_memory_recurrent() = default; @@ -54,6 +51,8 @@ public: llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override; + std::map memory_breakdown() const override; + bool prepare(const std::vector & ubatches); // find a contiguous slot of memory cells and emplace the ubatch there @@ -63,8 +62,8 @@ public: // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) uint32_t size = 0; // total number of cells, shared across all sequences diff --git a/llama/llama.cpp/src/llama-memory.h b/llama/llama.cpp/src/llama-memory.h index e8ba336e..4a157b91 100644 --- a/llama/llama.cpp/src/llama-memory.h +++ b/llama/llama.cpp/src/llama-memory.h @@ -2,7 +2,9 @@ #include "llama.h" +#include #include +#include struct llama_ubatch; @@ -36,8 +38,8 @@ bool llama_memory_status_is_fail(llama_memory_status status); // the interface for managing the memory context during batch processing // this interface is implemented per memory type. see: -// - llama_kv_cache_unified_context -// - llama_kv_cache_unified_iswa_context +// - llama_kv_cache_context +// - llama_kv_cache_iswa_context // ... // // the only method that should mutate the memory and the memory context is llama_memory_i::apply() @@ -64,6 +66,13 @@ using llama_memory_context_ptr = std::unique_ptr; // general concept of LLM memory // the KV cache is a type of LLM memory, but there can be other types struct llama_memory_i { + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + + // this callback is used to specify which layers should reuse memory from other layers + // return negative value to indicate that the layer il should not reuse memory + using layer_reuse_cb = std::function; + virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache @@ -77,7 +86,7 @@ struct llama_memory_i { // simulate full cache, used for allocating worst-case compute buffers virtual llama_memory_context_ptr init_full() = 0; - // prepare for any pending memory updates, such as shifts, defrags, etc. + // prepare for any pending memory updates, such as shifts, copies, etc. // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0; @@ -100,17 +109,14 @@ struct llama_memory_i { virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; + virtual std::map memory_breakdown() const = 0; + // // state write/read // - virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; - virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0; }; using llama_memory_ptr = std::unique_ptr; - -// TODO: temporary until the llama_kv_cache is removed from the public API -struct llama_kv_cache : public llama_memory_i { - virtual ~llama_kv_cache() = default; -}; diff --git a/llama/llama.cpp/src/llama-model-loader.cpp b/llama/llama.cpp/src/llama-model-loader.cpp index 7eab9b68..ee303bd5 100644 --- a/llama/llama.cpp/src/llama-model-loader.cpp +++ b/llama/llama.cpp/src/llama-model-loader.cpp @@ -465,6 +465,7 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( @@ -789,6 +790,7 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri } struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags) { + LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str()); const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 280129e1..74e1d162 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -6,8 +6,8 @@ #include "llama-cparams.h" #include "llama-model-loader.h" -#include "llama-kv-cache-unified.h" -#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache.h" +#include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" #include "llama-memory-recurrent.h" @@ -36,6 +36,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_80M: return "80M"; case LLM_TYPE_109M: return "109M"; case LLM_TYPE_137M: return "137M"; + case LLM_TYPE_140M: return "140M"; case LLM_TYPE_160M: return "160M"; case LLM_TYPE_190M: return "190M"; case LLM_TYPE_220M: return "220M"; @@ -44,12 +45,15 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_270M: return "270M"; case LLM_TYPE_335M: return "335M"; case LLM_TYPE_350M: return "350M"; + case LLM_TYPE_360M: return "360M"; case LLM_TYPE_410M: return "410M"; case LLM_TYPE_450M: return "450M"; case LLM_TYPE_475M: return "475M"; + case LLM_TYPE_558M: return "558M"; case LLM_TYPE_700M: return "700M"; case LLM_TYPE_770M: return "770M"; case LLM_TYPE_780M: return "780M"; + case LLM_TYPE_950M: return "950M"; case LLM_TYPE_0_3B: return "0.3B"; case LLM_TYPE_0_5B: return "0.5B"; case LLM_TYPE_0_6B: return "0.6B"; @@ -62,6 +66,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_1_7B: return "1.7B"; case LLM_TYPE_1_8B: return "1.8B"; case LLM_TYPE_2B: return "2B"; + case LLM_TYPE_2_6B: return "2.6B"; case LLM_TYPE_2_8B: return "2.8B"; case LLM_TYPE_2_9B: return "2.9B"; case LLM_TYPE_3B: return "3B"; @@ -83,9 +88,11 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_32B: return "32B"; case LLM_TYPE_34B: return "34B"; case LLM_TYPE_35B: return "35B"; + case LLM_TYPE_36B: return "36B"; case LLM_TYPE_40B: return "40B"; case LLM_TYPE_65B: return "65B"; case LLM_TYPE_70B: return "70B"; + case LLM_TYPE_120B: return "120B"; case LLM_TYPE_142B: return "142B"; case LLM_TYPE_236B: return "236B"; case LLM_TYPE_290B: return "290B"; @@ -107,6 +114,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_A13B: return "A13B"; + case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; @@ -303,7 +311,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; // add ACCEL buffer types @@ -324,11 +332,13 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // generally, this will be done using the first device in the list // a better approach would be to handle this on a weight-by-weight basis using the offload_op // function of the device to determine if it would benefit from being stored in a host buffer - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); - if (buft) { - buft_list.emplace_back(dev, buft); - break; + if (!no_host) { + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } } } @@ -505,9 +515,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); + std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -619,19 +633,32 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; - hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick - hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa == 0) { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope + } else { + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; + hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full + } switch (hparams.n_expert) { + case 0: { + // MobileLLM (no MoE) + switch (hparams.n_embd) { + case 2048: type = LLM_TYPE_140M; break; + case 4096: type = LLM_TYPE_360M; break; + case 6144: type = LLM_TYPE_950M; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case 16: type = LLM_TYPE_17B_16E; break; case 128: type = LLM_TYPE_17B_128E; break; default: type = LLM_TYPE_UNKNOWN; } - if (type == LLM_TYPE_17B_128E) { - hparams.use_kq_norm = false; - } + hparams.use_kq_norm = type != LLM_TYPE_17B_128E; } break; case LLM_ARCH_ARCEE: { @@ -655,10 +682,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_MINICPM: { + // Backward-compatible defaults for older MiniCPM GGUFs + hparams.f_embedding_scale = 12.0f; + hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer)); + hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f; + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + + // Optional KV reads, override defaults if present in newer GGUF exports + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /*required=*/false); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /*required=*/false); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /*required=*/false); // MiniCPM uses rope by default, unlike Granite which uses it as a switch hparams.rope_finetuned = true; @@ -682,7 +716,30 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GROK: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // defaults for old GGUFs + hparams.yarn_beta_fast = 8.0f; + hparams.f_logit_scale = 0.5773502691896257f; + hparams.f_embedding_scale = 78.38367176906169f; + hparams.f_attn_out_scale = 0.08838834764831845f; + hparams.f_attn_logit_softcapping = 30.0f; + hparams.f_router_logit_softcapping = 30.0f; + // no final_logit_softcapping in grok-1 + hparams.f_final_logit_softcapping = 0.0f; + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); switch (hparams.n_layer) { case 64: type = LLM_TYPE_314B; break; @@ -770,6 +827,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_JINA_BERT_V3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 24: + type = LLM_TYPE_558M; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: { @@ -898,6 +967,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.causal_attn = false; } break; + case LLM_ARCH_LLADA_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // diffusion language model uses non-causal attention + hparams.causal_attn = false; + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_QWEN2MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -1010,7 +1091,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; default: type = LLM_TYPE_UNKNOWN; - } + } + + // Load attention parameters + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); } break; case LLM_ARCH_GPT2: { @@ -1095,6 +1180,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { + case 18: type = LLM_TYPE_270M; break; case 26: type = LLM_TYPE_1B; break; case 34: type = LLM_TYPE_4B; break; case 48: type = LLM_TYPE_12B; break; @@ -1112,6 +1198,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(5); + hparams.n_layer_kv_from_start = 20; hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; hparams.f_attention_scale = 1.0f; @@ -1125,6 +1212,35 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GEMMA_EMBEDDING: + { + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + hparams.set_swa_pattern(6); + + hparams.causal_attn = false; // embeddings do not use causal attention + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + //applied only if model converted with --sentence-transformers-dense-modules + ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); + ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + + GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); + GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_0_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1278,6 +1394,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(4); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + switch (hparams.n_layer) { case 16: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_7B; break; @@ -1286,6 +1410,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_SEED_OSS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 64: type = LLM_TYPE_36B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_OLMOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1463,12 +1595,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { // Expert gating function (GLM-4.5 uses sigmoid) ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; } // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + switch (hparams.n_layer) { case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) @@ -1494,6 +1629,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.dec_start_token_id = dec_start_token_id; } + hparams.dec_n_layer = hparams.n_layer; + ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); + switch (hparams.n_layer) { case 6: type = LLM_TYPE_60M; break; // t5-small case 8: type = LLM_TYPE_80M; break; // flan-t5-small @@ -1542,6 +1680,27 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_NEMOTRON_H: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // A layer is recurrent IFF the n_head_kv value is set to 0 and + // the n_ff value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 56: type = LLM_TYPE_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_EXAONE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1848,7 +2007,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(2); - // TODO: switch (hparams.n_layer) + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_20B; break; + case 36: type = LLM_TYPE_120B; break; + default: type = LLM_TYPE_UNKNOWN; + } } break; case LLM_ARCH_LFM2: { @@ -1857,13 +2020,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { for (uint32_t il = 0; il < hparams.n_layer; ++il) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_350M; break; - case 1536: type = LLM_TYPE_700M; break; - case 2048: type = LLM_TYPE_1_2B; break; - default: type = LLM_TYPE_UNKNOWN; + hparams.n_layer_dense_lead = hparams.n_layer; + switch (hparams.n_ff()) { + case 4608: type = LLM_TYPE_350M; break; + case 6912: type = LLM_TYPE_700M; break; + case 8192: type = LLM_TYPE_1_2B; break; + case 10752: type = LLM_TYPE_2_6B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_LFM2MOE: + { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + + type = LLM_TYPE_8B_A1B; + } break; case LLM_ARCH_SMALLTHINKER: { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -1887,6 +2066,32 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GROVEMOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp); + ml.get_key(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); + ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_APERTUS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -1920,7 +2125,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false"); // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts); + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -2303,6 +2508,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; + case LLM_ARCH_LLADA_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe"); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } + } break; case LLM_ARCH_LLAMA4: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -2316,9 +2555,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0"); for (int i = 0; i < n_layer; ++i) { - bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0; + bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; auto & layer = layers[i]; @@ -2479,6 +2717,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -2493,12 +2732,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + if (!layer.ffn_post_norm) { + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } } } break; case LLM_ARCH_DBRX: @@ -2627,6 +2873,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_JINA_BERT_V3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); @@ -2662,24 +2909,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); } else { - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) { - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - } else { + if (arch == LLM_ARCH_NOMIC_BERT) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); } } @@ -3005,6 +3250,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // output rerank head + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -3207,17 +3455,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PLAMO2: { + // mamba parameters const uint32_t d_conv = hparams.ssm_d_conv; const uint32_t d_state = hparams.ssm_d_state; const uint32_t num_heads = hparams.ssm_dt_rank; const uint32_t intermediate_size = hparams.ssm_d_inner; - const uint32_t head_dim = intermediate_size / num_heads; - const uint32_t qk_dim = head_dim; - const uint32_t v_dim = head_dim; - const int64_t num_attention_heads = hparams.n_head(); - const int64_t q_num_heads = num_attention_heads; const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + // attention parameters + const uint32_t qk_dim = hparams.n_embd_head_k; + const uint32_t v_dim = hparams.n_embd_head_v; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -3251,6 +3499,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); } else { + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t q_num_heads = num_attention_heads; const int64_t num_key_value_heads = hparams.n_head_kv(i); const int64_t k_num_heads = num_key_value_heads; const int64_t v_num_heads = num_key_value_heads; @@ -3259,8 +3509,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t v_proj_dim = v_num_heads * v_dim; layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); } @@ -3447,6 +3697,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_GEMMA3: + case LLM_ARCH_GEMMA_EMBEDDING: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3459,6 +3710,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -3976,6 +4232,43 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_SEED_OSS: + { + const uint32_t head_dim = hparams.n_embd_head_k; + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); + + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } + } break; + case LLM_ARCH_OLMOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4319,6 +4612,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // n_layer: number of encoder_layers + // dec_n_layer: number of decoder_layers + const int dec_n_layer = hparams.dec_n_layer; + if (dec_n_layer > n_layer) { + layers.resize(dec_n_layer); + } + + // load encoder layers for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4334,6 +4635,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // load decoder layers + for (int i = 0; i < dec_n_layer; ++i) { + auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); @@ -4589,11 +4895,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); } } } @@ -4635,6 +4943,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_NEMOTRON_H: + { + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // all blocks use the attn norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.is_recurrent(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else if (hparams.n_ff(i) == 0) { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_EXAONE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5510,17 +5887,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - // ffn is same for transformer and conv layers + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -5578,6 +5973,95 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); } } break; + case LLM_ARCH_GROVEMOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for GROVEMOE"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for GROVEMOE"); + GGML_ASSERT(hparams.n_group_experts > 0 && "n_group_experts must be > 0 for GROVEMOE"); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_chexp = hparams.n_ff_chexp ? hparams.n_ff_chexp : n_embd_head_k; + const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_chexps = create_tensor(tn(LLM_TENSOR_FFN_GATE_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); + layer.ffn_down_chexps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_CHEXPS, "weight", i), {n_ff_chexp, n_embd, n_chunk_expert}, 0); + layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); + } + } break; + case LLM_ARCH_APERTUS: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // Q and K layernorms for Apertus + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -5746,6 +6230,14 @@ size_t llama_model::n_devices() const { return devices.size(); } +std::map llama_model::memory_breakdown() const { + std::map ret; + for (const ggml_backend_buffer_ptr & buf_ptr : pimpl->bufs) { + ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + } + return ret; +} + uint64_t llama_model::n_elements() const { return pimpl->n_elements; } @@ -5832,7 +6324,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_JAMBA || arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID) { + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -5903,11 +6396,18 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } - if (arch == LLM_ARCH_SMALLTHINKER) { + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } + if (arch == LLM_ARCH_GROVEMOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + } + vocab.print_info(); } @@ -6023,7 +6523,7 @@ struct llm_build_llama : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -6085,9 +6585,17 @@ struct llm_build_llama : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); + if (hparams.use_kq_norm) { + // Llama4TextL2Norm + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -6183,7 +6691,7 @@ struct llm_build_llama_iswa : public llm_graph_context { ggml_tensor * inp_attn_scale = nullptr; inp_attn_scale = build_inp_attn_scale(); - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -6192,7 +6700,8 @@ struct llm_build_llama_iswa : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0; + const bool use_rope = hparams.n_no_rope_layer_step > 0 && + (il + 1) % hparams.n_no_rope_layer_step != 0; // norm cur = build_norm(inpL, @@ -6261,7 +6770,7 @@ struct llm_build_llama_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -6362,7 +6871,7 @@ struct llm_build_deci : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -6438,7 +6947,7 @@ struct llm_build_deci : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -6518,7 +7027,7 @@ struct llm_build_baichuan : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr; - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -6570,7 +7079,7 @@ struct llm_build_baichuan : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -6640,7 +7149,7 @@ struct llm_build_xverse : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -6685,7 +7194,7 @@ struct llm_build_xverse : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -6754,7 +7263,7 @@ struct llm_build_falcon : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -6785,9 +7294,7 @@ struct llm_build_falcon : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -6808,7 +7315,7 @@ struct llm_build_falcon : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -6872,13 +7379,10 @@ struct llm_build_grok : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - // multiply by embedding_multiplier_scale of 78.38367176906169 - inpL = ggml_scale(ctx0, inpL, 78.38367176906169f); - // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -6938,7 +7442,7 @@ struct llm_build_grok : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -6946,26 +7450,22 @@ struct llm_build_grok : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // Grok - // if attn_out_norm is present then apply it before adding the input - if (model.layers[il].attn_out_norm) { - cur = build_norm(cur, - model.layers[il].attn_out_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_out_norm", il); - } + cur = build_norm(cur, + model.layers[il].attn_out_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_out_norm", il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); // feed-forward network - // MoE branch cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - cur = build_moe_ffn(cur, + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -6976,18 +7476,28 @@ struct llm_build_grok : public llm_graph_context { false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); - cb(cur, "ffn_moe_out", il); + cb(moe_out, "ffn_moe_out", il); - // Grok - // if layer_out_norm is present then apply it before adding the input - // Idea: maybe ffn_out_norm is a better name - if (model.layers[il].layer_out_norm) { - cur = build_norm(cur, - model.layers[il].layer_out_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "layer_out_norm", il); + if (model.layers[il].ffn_up) { + ggml_tensor * ffn_out = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(ffn_out, "ffn_out", il); + + cur = ggml_scale(ctx0, ggml_add(ctx0, ffn_out, moe_out), std::sqrt(2) / 2); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; } + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -7010,10 +7520,14 @@ struct llm_build_grok : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // Grok - // multiply logits by output_multiplier_scale of 0.5773502691896257 + cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); - cur = ggml_scale(ctx0, cur, 0.5773502691896257f); + // final logit soft-capping + if (hparams.f_final_logit_softcapping) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } cb(cur, "result_output", -1); res->t_logits = cur; @@ -7038,7 +7552,7 @@ struct llm_build_dbrx : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -7065,9 +7579,7 @@ struct llm_build_dbrx : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -7087,7 +7599,7 @@ struct llm_build_dbrx : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -7162,7 +7674,7 @@ struct llm_build_starcoder : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); cb(pos, "pos_embd", -1); @@ -7187,13 +7699,9 @@ struct llm_build_starcoder : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7201,7 +7709,7 @@ struct llm_build_starcoder : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -7267,7 +7775,7 @@ struct llm_build_refact : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -7300,7 +7808,7 @@ struct llm_build_refact : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -7409,35 +7917,43 @@ struct llm_build_bert : public llm_graph_context { cb(cur, "bqkv", il); } - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } if (model.layers[il].attn_q_norm) { + Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); } if (model.layers[il].attn_k_norm) { + Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // RoPE - if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -7457,7 +7973,7 @@ struct llm_build_bert : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } @@ -7496,7 +8012,7 @@ struct llm_build_bert : public llm_graph_context { 0.0f, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); - } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, @@ -7579,9 +8095,7 @@ struct llm_build_neo_bert : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // RoPE Qcur = ggml_rope_ext( @@ -7602,7 +8116,7 @@ struct llm_build_neo_bert : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } @@ -7663,7 +8177,7 @@ struct llm_build_bloom : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); inpL = build_norm(inpL, model.tok_norm, @@ -7688,13 +8202,9 @@ struct llm_build_bloom : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7702,7 +8212,7 @@ struct llm_build_bloom : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -7770,7 +8280,7 @@ struct llm_build_mpt : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); if (model.pos_embd) { // inp_pos - contains the positions @@ -7810,46 +8320,36 @@ struct llm_build_mpt : public llm_graph_context { cb(cur, "wqkv_clamped", il); } - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // Q/K Layernorm if (model.layers[il].attn_q_norm) { + Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens); + Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); - cb(Qcur, "Qcur", il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); - cb(Kcur, "Kcur", il); - } else { - Qcur = ggml_cont(ctx0, Qcur); - cb(Qcur, "Qcur", il); - Kcur = ggml_cont(ctx0, Kcur); - cb(Kcur, "Kcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -7919,7 +8419,7 @@ struct llm_build_stablelm : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -7995,7 +8495,7 @@ struct llm_build_stablelm : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8071,7 +8571,7 @@ struct llm_build_qwen : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8091,11 +8591,9 @@ struct llm_build_qwen : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd)); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -8116,7 +8614,7 @@ struct llm_build_qwen : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8186,7 +8684,7 @@ struct llm_build_qwen2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8236,7 +8734,7 @@ struct llm_build_qwen2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8350,8 +8848,9 @@ struct llm_build_dream : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, - nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8450,8 +8949,9 @@ struct llm_build_llada : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, - 1.0f / sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8511,7 +9011,7 @@ struct llm_build_qwen2vl : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); @@ -8564,7 +9064,7 @@ struct llm_build_qwen2vl : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8632,7 +9132,7 @@ struct llm_build_qwen2moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8691,7 +9191,7 @@ struct llm_build_qwen2moe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8791,7 +9291,7 @@ struct llm_build_qwen3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8844,7 +9344,7 @@ struct llm_build_qwen3 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8912,7 +9412,7 @@ struct llm_build_qwen3moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8965,7 +9465,7 @@ struct llm_build_qwen3moe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9042,7 +9542,7 @@ struct llm_build_phi2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9068,21 +9568,17 @@ struct llm_build_phi2 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9105,7 +9601,7 @@ struct llm_build_phi2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9171,13 +9667,13 @@ struct llm_build_phi3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - using inp_attn_type = std::conditional_t; + using inp_attn_type = std::conditional_t; inp_attn_type * inp_attn = nullptr; if constexpr (iswa) { - inp_attn = build_attn_inp_kv_unified_iswa(); + inp_attn = build_attn_inp_kv_iswa(); } else { - inp_attn = build_attn_inp_kv_unified(); + inp_attn = build_attn_inp_kv(); } ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9206,21 +9702,17 @@ struct llm_build_phi3 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9242,7 +9734,7 @@ struct llm_build_phi3 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9329,7 +9821,7 @@ struct llm_build_plamo : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9376,7 +9868,7 @@ struct llm_build_plamo : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9445,7 +9937,7 @@ struct llm_build_gpt2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); cb(pos, "pos_embd", -1); @@ -9470,21 +9962,17 @@ struct llm_build_gpt2 : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9555,7 +10043,7 @@ struct llm_build_codeshell : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9576,9 +10064,7 @@ struct llm_build_codeshell : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -9598,7 +10084,7 @@ struct llm_build_codeshell : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9668,7 +10154,7 @@ struct llm_build_orion : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9727,7 +10213,7 @@ struct llm_build_orion : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9795,7 +10281,7 @@ struct llm_build_internlm2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9854,7 +10340,7 @@ struct llm_build_internlm2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9931,7 +10417,7 @@ struct llm_build_minicpm3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10042,7 +10528,7 @@ struct llm_build_minicpm3 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10126,7 +10612,7 @@ struct llm_build_gemma : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10172,7 +10658,7 @@ struct llm_build_gemma : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10242,7 +10728,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10287,7 +10773,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10376,7 +10862,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context { ggml_tensor * inp_pos = build_inp_pos(); // TODO: is causal == true correct? might need some changes - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10429,7 +10915,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10501,7 +10987,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { const int64_t n_embd_altup; const int64_t n_altup; const int i_altup_act; - const int n_layer_kv = 20; // number of layers having KV [KV_REUSE] const int n_layer_sparsity = 10; // number of layers using activation sparsity const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) @@ -10527,7 +11012,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * inp_pos = build_inp_pos(); // TODO: is causal == true correct? might need some changes - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer] ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs()); @@ -10551,8 +11036,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { // this block is made to be closely resemble Gemma3p5DecoderLayer on python code - const bool has_kv = (il < n_layer_kv); - const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); @@ -10572,7 +11055,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens] // self-attention - if (has_kv) { + if (hparams.has_kv(il)) { // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); @@ -10610,9 +11093,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } else { - // no KV layers + // reuse KV cache of earlier layers ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -10628,7 +11111,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } cur = build_norm(cur, @@ -10906,8 +11389,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0 cb(all_coefs, "all_coefs", il); - all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup] - all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] + all_coefs = ggml_transpose(ctx0, all_coefs); // [n_tokens, n_altup] + all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1); ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] @@ -10918,6 +11401,137 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { } }; +struct llm_build_gemma_embedding_iswa : public llm_graph_context { + llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + if (ubatch.token) { + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + } + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA] + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315 + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + // TODO: move up next to build_starcoder struct llm_build_starcoder2 : public llm_graph_context { llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { @@ -10934,7 +11548,7 @@ struct llm_build_starcoder2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10993,7 +11607,7 @@ struct llm_build_starcoder2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -11299,6 +11913,7 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + cb(y, "mamba2_y_add_d", il); y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // grouped RMS norm @@ -11420,7 +12035,9 @@ struct llm_build_jamba : public llm_graph_context_mamba { cb(Vcur, "Vcur", il); // No RoPE :) - cur = build_attn(inp_hybrid->get_attn(), model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_hybrid->get_attn(), + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -11503,7 +12120,7 @@ struct llm_build_command_r : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -11578,7 +12195,7 @@ struct llm_build_command_r : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -11650,7 +12267,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -11713,7 +12330,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -11785,7 +12402,7 @@ struct llm_build_olmo : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -11844,7 +12461,7 @@ struct llm_build_olmo : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -11898,6 +12515,7 @@ struct llm_build_olmo : public llm_graph_context { } }; +template struct llm_build_olmo2 : public llm_graph_context { llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -11913,7 +12531,14 @@ struct llm_build_olmo2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -11946,17 +12571,36 @@ struct llm_build_olmo2 : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( + const bool is_swa = hparams.is_swa(il); + + if (is_swa) { + // For sliding window layers, Olmo3 use regular rope with no yarn rope scaling. + // This is achieved here by setting freq_scale and attn_factor to 1. + // We also set ext_factor to 0 to avoid a few unnecessary computations. + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_ext( + Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -11964,7 +12608,7 @@ struct llm_build_olmo2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12042,7 +12686,7 @@ struct llm_build_olmoe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -12097,7 +12741,133 @@ struct llm_build_olmoe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_llada_moe : public llm_graph_context { + llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12168,7 +12938,7 @@ struct llm_build_openelm : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -12230,7 +13000,7 @@ struct llm_build_openelm : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12299,7 +13069,7 @@ struct llm_build_gptneox : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -12320,9 +13090,7 @@ struct llm_build_gptneox : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -12342,7 +13110,7 @@ struct llm_build_gptneox : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12445,7 +13213,7 @@ struct llm_build_arctic : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -12492,7 +13260,7 @@ struct llm_build_arctic : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12583,7 +13351,7 @@ struct llm_build_deepseek : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -12647,7 +13415,7 @@ struct llm_build_deepseek : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12760,7 +13528,7 @@ struct llm_build_deepseek2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -12875,7 +13643,7 @@ struct llm_build_deepseek2 : public llm_graph_context { // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); } else { ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); cb(kv, "kv", il); @@ -12909,7 +13677,7 @@ struct llm_build_deepseek2 : public llm_graph_context { // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } @@ -13007,7 +13775,7 @@ struct llm_build_bitnet : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13076,7 +13844,7 @@ struct llm_build_bitnet : public llm_graph_context { cur = build_attn(inp_attn, NULL, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cur = build_norm(cur, model.layers[il].attn_sub_norm, NULL, @@ -13199,7 +13967,7 @@ struct llm_build_t5_enc : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo_enc, nullptr, - Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -13271,12 +14039,14 @@ struct llm_build_t5_dec : public llm_graph_context { const int64_t n_outputs_enc = embd_enc->ne[1]; - auto * inp_attn_self = build_attn_inp_kv_unified(); + auto * inp_attn_self = build_attn_inp_kv(); auto * inp_attn_cross = build_attn_inp_cross(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + const int64_t dec_n_layer = hparams.dec_n_layer; + + for (int il = 0; il < dec_n_layer; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -13305,7 +14075,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_self, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -13337,7 +14107,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_cross, model.layers[il].wo_cross, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); //ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); @@ -13367,7 +14137,7 @@ struct llm_build_t5_dec : public llm_graph_context { //cb(cur, "kqv_out", il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == dec_n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); } @@ -13388,8 +14158,8 @@ struct llm_build_t5_dec : public llm_graph_context { model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, - model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ, il); cb(cur, "ffn_out", il); } @@ -13436,7 +14206,7 @@ struct llm_build_jais : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13455,21 +14225,17 @@ struct llm_build_jais : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -13534,7 +14300,7 @@ struct llm_build_chatglm : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13568,6 +14334,7 @@ struct llm_build_chatglm : public llm_graph_context { } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -13577,11 +14344,9 @@ struct llm_build_chatglm : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -13601,7 +14366,7 @@ struct llm_build_chatglm : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -13667,7 +14432,7 @@ struct llm_build_glm4 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13702,6 +14467,7 @@ struct llm_build_glm4 : public llm_graph_context { } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -13711,11 +14477,9 @@ struct llm_build_glm4 : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -13734,7 +14498,7 @@ struct llm_build_glm4 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -13817,7 +14581,7 @@ struct llm_build_glm4_moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13883,7 +14647,7 @@ struct llm_build_glm4_moe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_transformer_layers - 1 && inp_out_ids) { @@ -13977,7 +14741,7 @@ struct llm_build_nemotron : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -14037,7 +14801,7 @@ struct llm_build_nemotron : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -14091,6 +14855,139 @@ struct llm_build_nemotron : public llm_graph_context { } }; +struct llm_build_nemotron_h : public llm_graph_context_mamba { + llm_build_nemotron_h( + const llama_model & model, + const llm_graph_params & params) : + llm_graph_context_mamba(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + ggml_build_forward_expand(gf, inpL); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (hparams.is_recurrent(il)) { + // ssm layer // + cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); + } else if (hparams.n_ff(il) == 0) { + // attention layer // + cur = build_attention_layer(cur, inp->get_attn(), model, n_embd_head, il); + } else { + cur = build_ffn_layer(cur, model, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // add residual + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "nemotron_h_block_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + ggml_tensor * build_attention_layer( + ggml_tensor * cur, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il) { + + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + return cur; + } + + ggml_tensor * build_ffn_layer( + ggml_tensor * cur, + const llama_model & model, + const int il) { + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + return cur; + } +}; + struct llm_build_exaone : public llm_graph_context { llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14106,7 +15003,7 @@ struct llm_build_exaone : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -14168,7 +15065,7 @@ struct llm_build_exaone : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -14238,13 +15135,13 @@ struct llm_build_exaone4 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - using inp_attn_type = std::conditional_t; + using inp_attn_type = std::conditional_t; inp_attn_type * inp_attn = nullptr; if constexpr (iswa) { - inp_attn = build_attn_inp_kv_unified_iswa(); + inp_attn = build_attn_inp_kv_iswa(); } else { - inp_attn = build_attn_inp_kv_unified(); + inp_attn = build_attn_inp_kv(); } ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -14299,7 +15196,7 @@ struct llm_build_exaone4 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } @@ -15127,7 +16024,7 @@ struct llm_build_granite : public llm_graph_context { inp_pos = build_inp_pos(); } - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -15178,12 +16075,12 @@ struct llm_build_granite : public llm_graph_context { } ggml_tensor * build_attention_layer( - ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv_unified * inp_attn, - const llama_model & model, - const int64_t n_embd_head, - const int il) { + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il) { // compute Q and K and (optionally) RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -15234,7 +16131,7 @@ struct llm_build_granite : public llm_graph_context { const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; } @@ -15397,12 +16294,12 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } ggml_tensor * build_attention_layer( - ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv_unified * inp_attn, - const llama_model & model, - const int64_t n_embd_head, - const int il) { + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il) { // compute Q and K and (optionally) RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -15453,16 +16350,16 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; } ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il) { + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il) { // For Granite architectures - scale residual if (hparams.f_residual_scale) { @@ -15553,7 +16450,7 @@ struct llm_build_solar : public llm_graph_context { struct ggml_tensor * inp_pos = build_inp_pos(); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -15640,7 +16537,7 @@ struct llm_build_solar : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -15718,7 +16615,7 @@ struct llm_build_chameleon : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -15797,7 +16694,7 @@ struct llm_build_chameleon : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -16049,7 +16946,7 @@ struct llm_build_plm : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -16153,7 +17050,7 @@ struct llm_build_plm : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -16214,7 +17111,7 @@ struct llm_build_bailingmoe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -16276,7 +17173,7 @@ struct llm_build_bailingmoe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -16363,7 +17260,7 @@ struct llm_build_dots1 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -16416,7 +17313,7 @@ struct llm_build_dots1 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -16513,7 +17410,7 @@ struct llm_build_ernie4_5 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -16571,7 +17468,7 @@ struct llm_build_ernie4_5 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -16643,7 +17540,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -16704,7 +17601,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } @@ -16857,7 +17754,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba { ggml_tensor * attn_out = build_attn(inp->get_attn(), model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); cur = build_norm(inpL, @@ -17017,7 +17914,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba { private: ggml_tensor * build_plamo2_attn_layer( - llm_graph_input_attn_kv_unified * inp, + llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, const llama_model & model, @@ -17033,22 +17930,21 @@ private: const int64_t n_embd_head_q = hparams.n_embd_head_k; const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_head_v = hparams.n_embd_head_v; + int32_t n_head = hparams.n_head(il); int32_t n_head_kv = hparams.n_head_kv(il); const int64_t q_offset = 0; const int64_t k_offset = n_embd_head_q * n_head; const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; - ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv))); + ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -17067,7 +17963,9 @@ private: ext_factor, attn_factor, beta_fast, beta_slow ); - cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il); + cur = build_attn(inp, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il); } cb(cur, "attn_out", il); @@ -17114,15 +18012,13 @@ private: cb(zx, "mamba_in_proj", il); // {8192, 5, 1, 1} -> {8192, 1, 5, 1} zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); - zx = ggml_cont(ctx0, zx); - zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); + zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); cb(zx, "mamba_in_proj_out", il); // split into z and x // => {head_dim * n_heads, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx)); - x = ggml_cont(ctx0, x); - x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); + x = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); // x = ggml_permute(ctx0, x, 0, 2, 1, 3); cb(x, "mamba_x_split", il); @@ -17252,7 +18148,7 @@ struct llm_build_arcee : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -17316,7 +18212,7 @@ struct llm_build_arcee : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -17387,7 +18283,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); @@ -17461,7 +18357,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -17548,7 +18444,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); @@ -17621,7 +18517,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -17686,7 +18582,7 @@ struct llm_build_smollm3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -17751,7 +18647,7 @@ struct llm_build_smollm3 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -17818,7 +18714,7 @@ struct llm_build_openai_moe_iswa : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -17873,9 +18769,9 @@ struct llm_build_openai_moe_iswa : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn_with_sinks(inp_attn, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].attn_sinks, 1.0f/sqrtf(float(n_rot)), il); + Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il); cb(cur, "attn_out", il); } @@ -17951,6 +18847,8 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const bool is_moe_layer = il >= static_cast(hparams.n_layer_dense_lead); + auto * prev_cur = cur; cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); @@ -17965,15 +18863,23 @@ struct llm_build_lfm2 : public llm_graph_context { } cur = ggml_add(ctx0, prev_cur, cur); - cur = ggml_add(ctx0, cur, build_feed_forward(cur, il)); + + auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(ffn_norm_out, "model.layers.{}.ffn_norm", il); + + ggml_tensor * ffn_out = is_moe_layer ? + build_moe_feed_forward(ffn_norm_out, il) : + build_dense_feed_forward(ffn_norm_out, il); + cb(ffn_norm_out, "model.layers.{}.ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); } cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "model.embedding_norm", -1); res->t_embd = cur; - // lm_head is tied with embeddings - cur = build_lora_mm(model.tok_embd, cur); + cur = build_lora_mm(model.output, cur); cb(cur, "lm_head", -1); res->t_logits = cur; @@ -17981,29 +18887,38 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - ggml_tensor * build_feed_forward(ggml_tensor * cur, - int il) const { - cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "model.layers.{}.ffn_norm", il); + ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, + int il) const { + return build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + static_cast(hparams.expert_gating_func), + il); + } + ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, + int il) const { GGML_ASSERT(!model.layers[il].ffn_up_b); GGML_ASSERT(!model.layers[il].ffn_gate_b); GGML_ASSERT(!model.layers[il].ffn_down_b); - cur = build_ffn(cur, + return build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "model.layers.{}.feed_forward.w2", il); - - return cur; } - ggml_tensor * build_attn_block(ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv_unified * inp_attn, - int il) const { + ggml_tensor * build_attn_block(ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv * inp_attn, + int il) const { GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); auto const n_embd_head = hparams.n_embd_head_v; auto const n_head_kv = hparams.n_head_kv(il); @@ -18038,7 +18953,7 @@ struct llm_build_lfm2 : public llm_graph_context { ); cur = build_attn(inp_attn, model.layers[il].wo, NULL, - q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + q, k, v, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "model.layers.{}.self_attn.out_proj", il); @@ -18115,6 +19030,137 @@ struct llm_build_lfm2 : public llm_graph_context { } }; +struct llm_build_seed_oss : public llm_graph_context { + llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + template struct llm_build_smallthinker : public llm_graph_context{ llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ @@ -18131,13 +19177,13 @@ struct llm_build_smallthinker : public llm_graph_context{ // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - using inp_attn_type = std::conditional_t; + using inp_attn_type = std::conditional_t; inp_attn_type * inp_attn = nullptr; if constexpr (iswa) { - inp_attn = build_attn_inp_kv_unified_iswa(); + inp_attn = build_attn_inp_kv_iswa(); } else { - inp_attn = build_attn_inp_kv_unified(); + inp_attn = build_attn_inp_kv(); } ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -18182,7 +19228,7 @@ struct llm_build_smallthinker : public llm_graph_context{ cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -18236,6 +19282,291 @@ struct llm_build_smallthinker : public llm_graph_context{ } }; +struct llm_build_grovemoe : public llm_graph_context { + llm_build_grovemoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * probs = build_lora_mm(model.layers[il].ffn_gate_inp, cur); // [n_expert, n_tokens] + cb(probs, "ffn_moe_logits", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + nullptr, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, probs); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + // TODO: Only do the expert selection and weights once + moe_out = + build_moe_ffn(cur, + nullptr, + model.layers[il].ffn_up_chexps, + model.layers[il].ffn_gate_chexps, + model.layers[il].ffn_down_chexps, + nullptr, + n_chunk_expert, n_expert_used > n_chunk_expert ? n_chunk_expert : n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, probs); + cb(moe_out, "ffn_adj_moe_out", il); + + cur = ggml_add(ctx0, cur, ggml_scale(ctx0, moe_out, hparams.expert_group_scale)); + cb(cur, "ffn_final_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_apertus : public llm_graph_context { + llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); + cb(Vcur, "Vcur_pos", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network with xIELU activation + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // Up projection + ggml_tensor * up = build_lora_mm(model.layers[il].ffn_up, cur); + cb(up, "ffn_up", il); + + float alpha_n_val = hparams.xielu_alpha_n[il]; + float alpha_p_val = hparams.xielu_alpha_p[il]; + float beta_val = hparams.xielu_beta[il]; + float eps_val = hparams.xielu_eps[il]; + + // Apply xIELU activation + ggml_tensor * activated = ggml_xielu(ctx0, up, alpha_n_val, alpha_p_val, beta_val, eps_val); + cb(activated, "ffn_xielu", il); + + // Down projection + cur = build_lora_mm(model.layers[il].ffn_down, activated); + cb(cur, "ffn_down", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -18244,12 +19575,15 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // switch statement case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: + //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA] case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: + case LLM_ARCH_LLADA_MOE: { res = nullptr; } break; @@ -18260,14 +19594,31 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, if (llm_arch_is_recurrent(arch)) { res = new llama_memory_recurrent( *this, - nullptr, GGML_TYPE_F32, GGML_TYPE_F32, cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max), - cparams.n_seq_max); + cparams.n_seq_max, + nullptr); } else if (llm_arch_is_hybrid(arch)) { - const auto padding = llama_kv_cache_unified::get_padding(cparams); + + // The main difference between hybrid architectures is the + // layer filters, so pick the right one here + llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; + llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; + if (arch == LLM_ARCH_FALCON_H1) { + filter_attn = [&](int32_t) { return true; }; + filter_recr = [&](int32_t) { return true; }; + } else if (arch == LLM_ARCH_NEMOTRON_H) { + filter_attn = [&](int32_t il) { + return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + filter_recr = [&](int32_t il) { + return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + } + + const auto padding = llama_kv_cache::get_padding(cparams); cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); @@ -18286,10 +19637,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, - /* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr, - /* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr); + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); } else { - const auto padding = llama_kv_cache_unified::get_padding(cparams); + const auto padding = llama_kv_cache::get_padding(cparams); uint32_t n_ctx_per_stream = cparams.n_ctx; @@ -18306,10 +19657,22 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + llama_memory_i::layer_reuse_cb reuse = nullptr; + + if (arch == LLM_ARCH_GEMMA3N) { + reuse = [&](int32_t il) { + if (il >= (int32_t) hparams.n_layer_kv_from_start) { + return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); + } + + return -1; + }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); - res = new llama_kv_cache_unified_iswa( + res = new llama_kv_cache_iswa( *this, params.type_k, params.type_v, @@ -18320,13 +19683,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, n_ctx_per_stream, cparams.n_seq_max, cparams.n_ubatch, - padding); + padding, + nullptr, + reuse); } else { GGML_ASSERT(!hparams.is_swa_any()); - res = new llama_kv_cache_unified( + res = new llama_kv_cache( *this, - nullptr, params.type_k, params.type_v, !cparams.flash_attn, @@ -18336,7 +19700,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, padding, hparams.n_swa, - hparams.swa_type); + hparams.swa_type, + nullptr, + nullptr); } } } @@ -18355,7 +19721,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_LLAMA4: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { + llm = std::make_unique(*this, params); + } else { + llm = std::make_unique(*this, params); + } } break; case LLM_ARCH_DECI: { @@ -18383,6 +19753,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: { @@ -18422,6 +19793,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_LLADA_MOE: + { + llm = std::make_unique(*this, params); + } + break; case LLM_ARCH_QWEN2VL: { llm = std::make_unique(*this, params); @@ -18495,6 +19871,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GEMMA_EMBEDDING: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_STARCODER2: { llm = std::make_unique(*this, params); @@ -18530,7 +19910,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_OLMO2: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_OLMOE: { @@ -18599,6 +19983,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_NEMOTRON_H: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_EXAONE: { llm = std::make_unique(*this, params); @@ -18657,6 +20045,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_SEED_OSS: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_DOTS1: { llm = std::make_unique(*this, params); @@ -18694,6 +20086,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { llm = std::make_unique(*this, params); } break; @@ -18705,6 +20098,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique>(*this, params); } } break; + case LLM_ARCH_GROVEMOE: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_APERTUS: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -18712,9 +20113,16 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // if the gguf model was converted with --sentence-transformers-dense-modules + // there will be two additional dense projection layers + // dense linear projections are applied after pooling + // TODO: move reranking logic here and generalize + llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + return llm->res->get_gf(); } + // // interface implementation // @@ -18723,7 +20131,7 @@ llama_model_params llama_model_default_params() { llama_model_params result = { /*.devices =*/ nullptr, /*.tensor_buft_overrides =*/ nullptr, - /*.n_gpu_layers =*/ 0, + /*.n_gpu_layers =*/ 999, /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, @@ -18735,13 +20143,9 @@ llama_model_params llama_model_default_params() { /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, + /*.no_host =*/ false, }; -#ifdef GGML_USE_METAL - // note: we usually have plenty of VRAM, so by default offload all layers to the GPU - result.n_gpu_layers = 999; -#endif - return result; } @@ -18833,6 +20237,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_NEMOTRON_H: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -18873,6 +20278,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GROK: case LLM_ARCH_DBRX: case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_STABLELM: @@ -18883,6 +20289,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: + case LLM_ARCH_LLADA_MOE: case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_PHI2: @@ -18894,6 +20301,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: + case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: @@ -18908,8 +20316,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_SMALLTHINKER: case LLM_ARCH_GLM4_MOE: + case LLM_ARCH_SEED_OSS: + case LLM_ARCH_GROVEMOE: + case LLM_ARCH_APERTUS: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -19020,6 +20432,10 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } +bool llama_model_is_hybrid(const llama_model * model) { + return llm_arch_is_hybrid(model->arch); +} + bool llama_model_is_diffusion(const llama_model * model) { return llm_arch_is_diffusion(model->arch); } diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h index 09964533..ec3fbd33 100644 --- a/llama/llama.cpp/src/llama-model.h +++ b/llama/llama.cpp/src/llama-model.h @@ -7,6 +7,7 @@ #include "llama-memory.h" #include "llama-vocab.h" +#include #include #include #include @@ -28,6 +29,7 @@ enum llm_type { LLM_TYPE_80M, LLM_TYPE_109M, LLM_TYPE_137M, + LLM_TYPE_140M, LLM_TYPE_160M, LLM_TYPE_190M, LLM_TYPE_220M, @@ -36,12 +38,15 @@ enum llm_type { LLM_TYPE_270M, LLM_TYPE_335M, LLM_TYPE_350M, + LLM_TYPE_360M, LLM_TYPE_410M, LLM_TYPE_450M, LLM_TYPE_475M, + LLM_TYPE_558M, LLM_TYPE_700M, LLM_TYPE_770M, LLM_TYPE_780M, + LLM_TYPE_950M, LLM_TYPE_0_3B, LLM_TYPE_0_5B, LLM_TYPE_0_6B, @@ -54,6 +59,7 @@ enum llm_type { LLM_TYPE_1_7B, LLM_TYPE_1_8B, LLM_TYPE_2B, + LLM_TYPE_2_6B, LLM_TYPE_2_8B, LLM_TYPE_2_9B, LLM_TYPE_3B, @@ -76,9 +82,11 @@ enum llm_type { LLM_TYPE_32B, LLM_TYPE_34B, LLM_TYPE_35B, + LLM_TYPE_36B, LLM_TYPE_40B, LLM_TYPE_65B, LLM_TYPE_70B, + LLM_TYPE_120B, LLM_TYPE_142B, LLM_TYPE_236B, LLM_TYPE_290B, @@ -100,6 +108,7 @@ enum llm_type { LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_A13B, + LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_106B_A12B, // GLM-4.5-Air @@ -268,6 +277,11 @@ struct llama_layer { struct ggml_tensor * ffn_down_shexp = nullptr; struct ggml_tensor * ffn_up_shexp = nullptr; + // ff adjugate experts (chexps) + struct ggml_tensor * ffn_gate_chexps = nullptr; + struct ggml_tensor * ffn_down_chexps = nullptr; + struct ggml_tensor * ffn_up_chexps = nullptr; + // ff bias struct ggml_tensor * ffn_gate_b = nullptr; struct ggml_tensor * ffn_down_b = nullptr; // b2 @@ -368,6 +382,12 @@ struct llama_layer { // openai-moe struct ggml_tensor * attn_sinks = nullptr; + // xIELU activation parameters for Apertus + struct ggml_tensor * ffn_act_alpha_n = nullptr; + struct ggml_tensor * ffn_act_alpha_p = nullptr; + struct ggml_tensor * ffn_act_beta = nullptr; + struct ggml_tensor * ffn_act_eps = nullptr; + struct ggml_tensor * bskcn_tv = nullptr; struct llama_layer_posnet posnet; @@ -421,6 +441,12 @@ struct llama_model { std::vector layers; + //Dense linear projections for SentenceTransformers models like embeddinggemma + // For Sentence Transformers models structure see + // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; + llama_model_params params; // gguf metadata @@ -449,10 +475,12 @@ struct llama_model { std::string desc() const; - size_t size() const; + size_t size() const; // file size size_t n_tensors() const; size_t n_devices() const; + std::map memory_breakdown() const; + // total number of parameters in the model uint64_t n_elements() const; diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index 1d0361cc..97228b2a 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -725,7 +725,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // attention layers have a non-zero number of kv heads int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); if (llama_model_has_encoder(&model)) { - n_attn_layer *= 3; + // now n_attn_layer is the number of attention layers in the encoder + // for each decoder block, there are 2 attention layers + n_attn_layer += 2 * model.hparams.dec_n_layer; } GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected"); } @@ -920,7 +922,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: new_type = tensor->type; new_data = tensor->data; new_size = ggml_nbytes(tensor); - LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0); + LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0); } else { const int64_t nelements = ggml_nelements(tensor); @@ -1037,8 +1039,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } close_ofstream(); - LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); - LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + LLAMA_LOG_INFO("%s: model size = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0); + LLAMA_LOG_INFO("%s: quant size = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0); if (qs.n_fallback > 0) { LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n", diff --git a/llama/llama.cpp/src/llama-sampling.cpp b/llama/llama.cpp/src/llama-sampling.cpp index 11f93f42..da34526b 100644 --- a/llama/llama.cpp/src/llama-sampling.cpp +++ b/llama/llama.cpp/src/llama-sampling.cpp @@ -128,6 +128,89 @@ struct ring_buffer { std::vector data; }; +// writes result in res, does not mutate cur +static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector & res) { + static const auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + + constexpr int nbuckets = 128; + constexpr float bucket_low = -10.0f; + constexpr float bucket_high = 10.0f; + constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); + constexpr float bucket_inter = -bucket_low * bucket_scale; + + std::vector bucket_idx; + std::vector histo(nbuckets, 0); + + std::vector bucket_ptrs; + + bucket_idx.reserve(cur.size); + + for (int i = 0; i < (int)cur.size; ++i) { + const float val = cur.data[i].logit; + int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + ib = std::max(0, std::min(nbuckets - 1, ib)); + bucket_idx.push_back(ib); + ++histo[ib]; + } + int nhave = 0; + int ib = nbuckets - 1; + for ( ; ib >= 0; --ib) { + nhave += histo[ib]; + if (nhave >= npartial) { + break; + } + } + res.resize(nhave); + auto * ptr = res.data(); + bucket_ptrs.reserve(nbuckets - ib); + for (int j = nbuckets - 1; j >= ib; --j) { + bucket_ptrs.push_back(ptr); + ptr += histo[j]; + } + for (int i = 0; i < (int)cur.size; ++i) { + int j = bucket_idx[i]; + if (j >= ib) { + *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i]; + } + } + + ptr = res.data(); + int ndone = 0; + for (int j = nbuckets - 1; j > ib; --j) { + std::sort(ptr, ptr + histo[j], comp); + ptr += histo[j]; + ndone += histo[j]; + } + std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp); +} + +// reduces the size of cur_p to npartial, keeping only the top npartial elements +static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) { + static const auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + + if (npartial <= 128) { + std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp); + + cur_p->size = npartial; + cur_p->sorted = true; + + return; + } + + std::vector tmp; + + llama_token_data_array_partial_sort(*cur_p, npartial, tmp); + + std::copy(tmp.data(), tmp.data() + npartial, cur_p->data); + + cur_p->size = npartial; + cur_p->sorted = true; +} + static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) { // iterator for the probabilities #ifdef __GNUC__ @@ -200,18 +283,21 @@ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) } } -static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { +static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) { GGML_ASSERT(cur_p->size > 0); - // Sort the logits in descending order - if (!cur_p->sorted) { - std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - cur_p->sorted = true; + // Sort the logits in descending order if requested + if (do_sort && !cur_p->sorted) { + llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size); } float max_l = cur_p->data[0].logit; + if (!cur_p->sorted) { + for (size_t i = 1; i < cur_p->size; ++i) { + max_l = std::max(max_l, cur_p->data[i].logit); + } + } + float cum_sum = 0.0f; for (size_t i = 0; i < cur_p->size; ++i) { @@ -226,7 +312,6 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { } static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) { - // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast // if (k >= (int32_t)cur_p->size) { // return; // } @@ -239,64 +324,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) // Sort scores in descending order if (!cur_p->sorted) { - auto comp = [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }; - if (k <= 128) { - std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp); - } else { - constexpr int nbuckets = 128; - constexpr float bucket_low = -10.0f; - constexpr float bucket_high = 10.0f; - constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); - constexpr float bucket_inter = -bucket_low * bucket_scale; - - std::vector bucket_idx(cur_p->size); - std::vector histo(nbuckets, 0); - - for (int i = 0; i < (int)cur_p->size; ++i) { - const float val = cur_p->data[i].logit; - int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); - ib = std::max(0, std::min(nbuckets - 1, ib)); - bucket_idx[i] = ib; - ++histo[ib]; - } - int nhave = 0; - int ib = nbuckets - 1; - for ( ; ib >= 0; --ib) { - nhave += histo[ib]; - if (nhave >= k) { - break; - } - } - std::vector tmp_tokens(nhave); - auto * ptr = tmp_tokens.data(); - std::vector bucket_ptrs; - bucket_ptrs.reserve(nbuckets - ib); - for (int j = nbuckets - 1; j >= ib; --j) { - bucket_ptrs.push_back(ptr); - ptr += histo[j]; - } - for (int i = 0; i < (int)cur_p->size; ++i) { - int j = bucket_idx[i]; - if (j >= ib) { - *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i]; - } - } - - ptr = tmp_tokens.data(); - int ndone = 0; - for (int j = nbuckets - 1; j > ib; --j) { - std::sort(ptr, ptr + histo[j], comp); - ptr += histo[j]; - ndone += histo[j]; - } - std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); - - std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data)); - - } - cur_p->sorted = true; + llama_token_data_array_partial_sort_inplace(cur_p, k); } cur_p->size = k; @@ -576,9 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl* static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_dist *) smpl->ctx; - llama_sampler_softmax_impl(cur_p); + // edge cases + if (cur_p->size == 0) { + cur_p->selected = -1; + return; + } + + cur_p->selected = 0; + + if (cur_p->size == 1) { + cur_p->data[0].p = 1.0f; + return; + } + + // max logit for numerical stability + float max_l = cur_p->data[0].logit; + if (!cur_p->sorted) { + for (size_t i = 1; i < cur_p->size; ++i) { + max_l = std::max(max_l, cur_p->data[i].logit); + } + } + + // apply softmax to obtain the probabilities + double sum_cum = 0.0f; + for (size_t i = 0; i < cur_p->size; ++i) { + float p = expf(cur_p->data[i].logit - max_l); + cur_p->data[i].p = p; + sum_cum += p; + } + +#if 1 + // sample from the obtained probabilities and normalize the probs in a single pass + // this is ~3x faster on Mac with full gpt-oss vocab than the version below + // + std::uniform_real_distribution dist(0.0f, 1.0f); + const double rnd = dist(ctx->rng); + + double sum_run = 0.0f; + const double sum_tgt = sum_cum*rnd; + + bool found = false; + for (size_t i = 0; i < cur_p->size; ++i) { + if (!found) { + // accumulate probs until we reach the target sum + sum_run += cur_p->data[i].p; + if (sum_run >= sum_tgt) { + cur_p->selected = i; + found = true; + } + } + + // normalize probs + cur_p->data[i].p /= sum_cum; + } + + // fallback to the last token (don't think this can happen) + assert(found); + if (!found) { + cur_p->selected = cur_p->size - 1; + } +#else + // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= sum_cum; + } cur_p->selected = llama_sample_dist(cur_p, ctx->rng); +#endif } static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { @@ -626,32 +718,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { ); } -// softmax - -static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) { - return "softmax"; -} - -static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { - llama_sampler_softmax_impl(cur_p); -} - -static struct llama_sampler_i llama_sampler_softmax_i = { - /* .name = */ llama_sampler_softmax_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_softmax_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, -}; - -struct llama_sampler * llama_sampler_init_softmax() { - return llama_sampler_init( - /* .iface = */ &llama_sampler_softmax_i, - /* .ctx = */ nullptr - ); -} - // top-k struct llama_sampler_top_k { @@ -663,7 +729,7 @@ static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl } static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_top_k *) smpl->ctx; + auto * ctx = (llama_sampler_top_k *) smpl->ctx; llama_sampler_top_k_impl(cur_p, ctx->k); } @@ -699,6 +765,8 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) { struct llama_sampler_top_p { const float p; const size_t min_keep; + + std::vector buf_sort; }; static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { @@ -706,20 +774,35 @@ static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl } static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_top_p *) smpl->ctx; + auto * ctx = (llama_sampler_top_p *) smpl->ctx; if (ctx->p >= 1.0f) { return; } - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, false); + + size_t k = cur_p->size; + auto * pdata = cur_p->data; + + auto & buf_sort = ctx->buf_sort; + + // if not sorted, try adaptive top-k sorting + if (!cur_p->sorted && cur_p->size > 1024) { + k = std::min(256, cur_p->size); + llama_token_data_array_partial_sort(*cur_p, k, buf_sort); + pdata = buf_sort.data(); + } else if (!cur_p->sorted) { + // small candidates -> sort inplace + llama_token_data_array_partial_sort_inplace(cur_p, k); + } // Compute the cumulative probabilities float cum_sum = 0.0f; size_t last_idx = cur_p->size; for (size_t i = 0; i < cur_p->size; ++i) { - cum_sum += cur_p->data[i].p; + cum_sum += pdata[i].p; // Check if the running sum is at least p or if we have kept at least min_keep tokens // we set the last index to i+1 to indicate that the current iterate should be included in the set @@ -727,9 +810,21 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d last_idx = i + 1; break; } + + // we exceeded the current top-k heuristic -> increase k and continue + if (!cur_p->sorted && i == k - 1) { + k = cur_p->size; + llama_token_data_array_partial_sort(*cur_p, k, buf_sort); + pdata = buf_sort.data(); + } } // Resize the output vector to keep only the top-p tokens + if (!cur_p->sorted) { + std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data); + cur_p->sorted = true; + } + cur_p->size = last_idx; } @@ -757,6 +852,7 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { /* .ctx = */ new llama_sampler_top_p { /* .p = */ p, /* .min_keep = */ min_keep, + /* .buf_sort = */ {}, } ); } @@ -773,7 +869,7 @@ static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl } static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_min_p *) smpl->ctx; + auto * ctx = (llama_sampler_min_p *) smpl->ctx; if (ctx->p <= 0.0f || !cur_p->size) { return; @@ -799,7 +895,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d // if we have enough values the operation was a success if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) { - memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); + std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data); cur_p->size = filtered_tokens.size(); min_p_applied = true; } @@ -809,10 +905,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d if (!min_p_applied) { // Sort the logits in descending order if (!cur_p->sorted) { - std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - cur_p->sorted = true; + llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size); } const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max @@ -869,7 +962,7 @@ static const char * llama_sampler_typical_name(const struct llama_sampler * /*sm } static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_typical *) smpl->ctx; + auto * ctx = (llama_sampler_typical *) smpl->ctx; // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr @@ -878,7 +971,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token } // Compute the softmax of logits and calculate entropy - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); float entropy = 0.0f; for (size_t i = 0; i < cur_p->size; ++i) { @@ -1012,7 +1105,7 @@ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*s } static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx; + auto * ctx = (llama_sampler_temp_ext *) smpl->ctx; if (ctx->delta > 0) { const float min_temp = std::max(0.0f, ctx->temp - ctx->delta); const float max_temp = ctx->temp + ctx->delta; @@ -1027,7 +1120,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke // Calculate maximum possible entropy float max_entropy = -logf(1.0f / cur_p->size); - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -1121,7 +1214,7 @@ struct llama_sampler_xtc { const uint32_t seed; uint32_t seed_cur; - std::mt19937 rng; + std::mt19937 rng; }; static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) { @@ -1139,17 +1232,20 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data std::uniform_real_distribution distribution(0.0f, 1.0f); float chance = distribution(ctx->rng); - if (chance > ctx->probability) return; + if (chance > ctx->probability) { + return; + } - // in case it's not sorted/recalculated yet - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); int pos_last = 0; for (size_t i = 0; i < cur_p->size; ++i) { if (cur_p->data[i].p >= ctx->threshold) { pos_last = i; - } else break; + } else { + break; + } } if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) { @@ -1221,7 +1317,7 @@ struct llama_sampler_mirostat { float mu; - std::mt19937 rng; + std::mt19937 rng; }; static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { @@ -1231,7 +1327,7 @@ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*s static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_mirostat *) smpl->ctx; - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -1250,7 +1346,8 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat); llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); - llama_sampler_softmax_impl(cur_p); + + llama_sampler_softmax_impl(cur_p, true); const int idx = llama_sample_dist(cur_p, ctx->rng); @@ -1336,7 +1433,7 @@ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); // Truncate the words with surprise values greater than mu cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { @@ -1348,7 +1445,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t } // Normalize the probabilities of the remaining words - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); const int idx = llama_sample_dist(cur_p, ctx->rng); @@ -1540,7 +1637,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0"); } trigger_pattern += ")[\\s\\S]*"; - auto trigger_pattern_c = trigger_pattern.c_str(); + const auto * trigger_pattern_c = trigger_pattern.c_str(); trigger_patterns = &trigger_pattern_c; num_trigger_patterns = 1; } @@ -1748,7 +1845,7 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * } static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; + auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; if (ctx->n <= 0.0f || cur_p->size <= 1) { return; @@ -1780,13 +1877,14 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t } float std = valid_count > 0 ? sqrt(acc/valid_count) : 0; - //apply mask + // apply mask for (size_t i = 0; i < cur_p->size; ++i) { if (cur_p->data[i].logit < max - (ctx->n * std)) { cur_p->data[i].logit = -INFINITY; } } - llama_sampler_softmax_impl(cur_p); + + llama_sampler_softmax_impl(cur_p, true); } static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) { @@ -1991,7 +2089,9 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat { const int last = last_n_repeat - 1; - int rt = 0, lt = 0; + + int rt = 0; + int lt = 0; for (int k = 1; k < last_n_repeat; ++k) { if (k > rt) { @@ -2135,8 +2235,8 @@ static struct llama_sampler_i llama_sampler_dry_i = { /* .free = */ llama_sampler_dry_free, }; -struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { - int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); +struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0); std::unordered_multimap> processed_breakers; const int MAX_CHAR_LEN = 40; const int MAX_SEQ_LEN = 20; @@ -2169,7 +2269,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, return llama_sampler_init( /* .iface = */ &llama_sampler_dry_i, /* .ctx = */ new llama_sampler_dry { - /* .total_context_size = */ context_size, + /* .total_context_size = */ n_ctx_train, /* .dry_multiplier = */ dry_multiplier, /* .dry_base = */ dry_base, /* .dry_allowed_length = */ dry_allowed_length, @@ -2308,7 +2408,7 @@ static const char * llama_sampler_infill_name(const struct llama_sampler * /*smp static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_infill *) smpl->ctx; - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); #if defined(GGML_DEBUG_SAMPLER_INFILL) #define LOG_DBG_CUR LLAMA_LOG_DEBUG @@ -2441,8 +2541,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_ if (n_non_eog == 0) { cur_p->size = 1; cur_p->data[0].id = ctx->vocab->token_eot(); + if (cur_p->data[0].id == LLAMA_TOKEN_NULL) { + cur_p->data[0].id = ctx->vocab->token_eos(); + } cur_p->data[0].logit = 1.0f; + GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL); + return; } diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp index fa388b03..217ede47 100644 --- a/llama/llama.cpp/src/llama-vocab.cpp +++ b/llama/llama.cpp/src/llama-vocab.cpp @@ -347,6 +347,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: case LLAMA_VOCAB_PRE_TYPE_TRILLION: + case LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -434,6 +435,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GROK_2: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1763,7 +1771,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx); const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); -#ifdef IS_BIG_ENDIAN +#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ // correct endiannes of data in precompiled_charsmap binary blob uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0]; *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); @@ -1944,7 +1952,12 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; clean_spaces = false; } else if ( - tokenizer_pre == "bailingmoe") { + tokenizer_pre == "granite-docling") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING; + clean_spaces = false; + } else if ( + tokenizer_pre == "bailingmoe" || + tokenizer_pre == "llada-moe") { pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; clean_spaces = false; } else if ( @@ -1963,6 +1976,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; clean_spaces = false; + } else if ( + tokenizer_pre == "grok-2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2; + clean_spaces = false; } else { LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -2144,6 +2161,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end|>" || t.first == "" || t.first == "<|endoftext|>" + || t.first == "<|end_of_text|>" // granite || t.first == "" || t.first == "_" || t.first == "<|end▁of▁sentence|>" // DeepSeek @@ -2331,7 +2349,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // @ngxson : quick hack for gpt-oss, always render these tokens for (const auto & t : token_to_id) { - if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>") { + if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") { id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED; } } @@ -2378,6 +2396,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (has_return && has_call && has_end) { special_eog_ids.erase(end_id); + id_to_token[end_id].attr = LLAMA_TOKEN_ATTR_USER_DEFINED; LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__); } } @@ -2459,7 +2478,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // set attributes by model/tokenizer/architecture name if (false || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"}) - || _contains_any(general_arch, {"nomic-bert-moe"}) + || _contains_any(general_arch, {"nomic-bert-moe", "jina-bert-v3"}) ) { if (token_to_id.count("") == 0) { LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__); diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h index 61b81242..5e468675 100644 --- a/llama/llama.cpp/src/llama-vocab.h +++ b/llama/llama.cpp/src/llama-vocab.h @@ -8,45 +8,47 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, }; struct LLM_KV; diff --git a/llama/llama.cpp/src/llama.cpp b/llama/llama.cpp/src/llama.cpp index 34906cdb..d821a96a 100644 --- a/llama/llama.cpp/src/llama.cpp +++ b/llama/llama.cpp/src/llama.cpp @@ -25,6 +25,18 @@ // interface implementation // +const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type) { + switch (flash_attn_type) { + case LLAMA_FLASH_ATTN_TYPE_AUTO: + return "auto"; + case LLAMA_FLASH_ATTN_TYPE_DISABLED: + return "disabled"; + case LLAMA_FLASH_ATTN_TYPE_ENABLED: + return "enabled"; + } + GGML_ABORT("fatal error"); +} + struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { /*.no_perf =*/ true, @@ -47,6 +59,7 @@ bool llama_supports_mlock(void) { bool llama_supports_gpu_offload(void) { return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr || + ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU) != nullptr || llama_supports_rpc(); } @@ -71,7 +84,9 @@ void llama_numa_init(enum ggml_numa_strategy numa) { GGML_ASSERT(dev && "CPU backend is not loaded"); auto * reg = ggml_backend_dev_backend_reg(dev); auto * numa_init_fn = (decltype(ggml_numa_init) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_numa_init"); - numa_init_fn(numa); + if (numa_init_fn) { + numa_init_fn(numa); + } } } @@ -170,8 +185,13 @@ static struct llama_model * llama_model_load_from_file_impl( model->devices.push_back(*dev); } } else { + // default device selection + + // build list of available devices + std::vector gpus; + std::vector igpus; std::vector rpc_servers; - // use all available devices + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { ggml_backend_dev_t dev = ggml_backend_dev_get(i); switch (ggml_backend_dev_type(dev)) { @@ -180,19 +200,51 @@ static struct llama_model * llama_model_load_from_file_impl( // skip CPU backends since they are handled separately break; - case GGML_BACKEND_DEVICE_TYPE_GPU: + case GGML_BACKEND_DEVICE_TYPE_GPU: { ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); if (ggml_backend_reg_name(reg) == std::string("RPC")) { rpc_servers.push_back(dev); } else { - model->devices.push_back(dev); + // check if there is already a GPU with the same device id + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + auto it = std::find_if(gpus.begin(), gpus.end(), [&props](ggml_backend_dev_t d) { + ggml_backend_dev_props d_props; + ggml_backend_dev_get_props(d, &d_props); + if (props.device_id && d_props.device_id) { + return strcmp(props.device_id, d_props.device_id) == 0; + } + return false; + }); + + if (it != gpus.end()) { + LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", + __func__, + ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + props.device_id ? props.device_id : "unknown id", + ggml_backend_dev_name(*it), ggml_backend_dev_description(*it)); + } else { + gpus.push_back(dev); + } } break; + } + + case GGML_BACKEND_DEVICE_TYPE_IGPU: + igpus.push_back(dev); + break; } } - // add RPC servers at the front of the list - if (!rpc_servers.empty()) { - model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end()); + + // add RPC servers at the front of the list to minimize network transfers + model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end()); + + // add GPUs + model->devices.insert(model->devices.end(), gpus.begin(), gpus.end()); + + // add integrated GPUs only if no other devices were found + if (model->devices.empty()) { + model->devices.insert(model->devices.end(), igpus.begin(), igpus.end()); } } @@ -213,9 +265,14 @@ static struct llama_model * llama_model_load_from_file_impl( } for (auto * dev : model->devices) { - size_t free, total; // NOLINT - ggml_backend_dev_memory(dev, &free, &total); - LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024); + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + size_t memory_free, memory_total; + ggml_backend_dev_memory(dev, &memory_free, &memory_total); + LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__, + ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + props.device_id ? props.device_id : "unknown id", + memory_free/1024/1024); } const int status = llama_model_load(path_model, splits, *model, params); diff --git a/llama/llama.cpp/src/unicode.h b/llama/llama.cpp/src/unicode.h index 0a5fa2a7..5bd1362f 100644 --- a/llama/llama.cpp/src/unicode.h +++ b/llama/llama.cpp/src/unicode.h @@ -4,6 +4,7 @@ #include #include +// TODO: reimplement this structure in endian-independent way struct unicode_cpt_flags { enum { UNDEFINED = 0x0001, @@ -15,6 +16,10 @@ struct unicode_cpt_flags { SYMBOL = 0x0040, // regex: \p{S} CONTROL = 0x0080, // regex: \p{C} MASK_CATEGORIES = 0x00FF, + WHITESPACE = 0x0100, + LOWERCASE = 0x0200, + UPPERCASE = 0x0400, + NFD = 0x0800, }; // codepoint type @@ -34,11 +39,49 @@ struct unicode_cpt_flags { // decode from uint16 inline unicode_cpt_flags(const uint16_t flags = 0) { +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ *reinterpret_cast(this) = flags; +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + is_undefined = (flags & UNDEFINED) ? 1 : 0; + is_number = (flags & NUMBER) ? 1 : 0; + is_letter = (flags & LETTER) ? 1 : 0; + is_separator = (flags & SEPARATOR) ? 1 : 0; + is_accent_mark = (flags & ACCENT_MARK) ? 1 : 0; + is_punctuation = (flags & PUNCTUATION) ? 1 : 0; + is_symbol = (flags & SYMBOL) ? 1 : 0; + is_control = (flags & CONTROL) ? 1 : 0; + is_whitespace = (flags & WHITESPACE) ? 1 : 0; + is_lowercase = (flags & LOWERCASE) ? 1 : 0; + is_uppercase = (flags & UPPERCASE) ? 1 : 0; + is_nfd = (flags & NFD) ? 1 : 0; +#else +#error Unexpected or undefined __BYTE_ORDER__ +#endif } inline uint16_t as_uint() const { +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ return *reinterpret_cast(this); +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + uint16_t result = + is_undefined * UNDEFINED + + is_number * NUMBER + + is_letter * LETTER + + is_separator * SEPARATOR + + is_accent_mark * ACCENT_MARK + + is_punctuation * PUNCTUATION + + is_symbol * SYMBOL + + is_control * CONTROL + + is_whitespace * WHITESPACE + + is_lowercase * LOWERCASE + + is_uppercase * UPPERCASE + + is_nfd * NFD + ; + + return result; +#else +#error Unexpected or undefined __BYTE_ORDER__ +#endif } inline uint16_t category_flag() const { diff --git a/llama/llama.cpp/tools/mtmd/clip-impl.h b/llama/llama.cpp/tools/mtmd/clip-impl.h index c8822dcf..7a752385 100644 --- a/llama/llama.cpp/tools/mtmd/clip-impl.h +++ b/llama/llama.cpp/tools/mtmd/clip-impl.h @@ -31,6 +31,7 @@ // vision-specific #define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" @@ -44,6 +45,7 @@ #define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" #define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" // audio-specific #define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins" @@ -81,6 +83,7 @@ #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" #define TN_IMAGE_NEWLINE "model.image_newline" #define TN_MM_INP_NORM "mm.input_norm.weight" +#define TN_MM_INP_NORM_B "mm.input_norm.bias" #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 #define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3 #define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3 @@ -132,6 +135,8 @@ enum projector_type { PROJECTOR_TYPE_QWEN2A, PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_VOXTRAL, + PROJECTOR_TYPE_LFM2, + PROJECTOR_TYPE_KIMIVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -152,6 +157,8 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_QWEN2A, "qwen2a"}, { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"}, { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, + { PROJECTOR_TYPE_LFM2, "lfm2"}, + { PROJECTOR_TYPE_KIMIVL, "kimivl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/llama/llama.cpp/tools/mtmd/clip.cpp b/llama/llama.cpp/tools/mtmd/clip.cpp index f4f69cfc..6699b75a 100644 --- a/llama/llama.cpp/tools/mtmd/clip.cpp +++ b/llama/llama.cpp/tools/mtmd/clip.cpp @@ -183,7 +183,9 @@ struct clip_hparams { int32_t projection_dim; int32_t n_head; int32_t n_layer; - int32_t proj_scale_factor = 0; // idefics3 + // idefics3 + int32_t preproc_image_size = 0; + int32_t proj_scale_factor = 0; float image_mean[3]; float image_std[3]; @@ -214,6 +216,7 @@ struct clip_hparams { // legacy bool has_llava_projector = false; int minicpmv_version = 0; + int32_t minicpmv_query_num = 0; // MiniCPM-V query number }; struct clip_layer { @@ -277,6 +280,7 @@ struct clip_model { // LLaVA projection ggml_tensor * mm_input_norm_w = nullptr; + ggml_tensor * mm_input_norm_b = nullptr; ggml_tensor * mm_0_w = nullptr; ggml_tensor * mm_0_b = nullptr; ggml_tensor * mm_2_w = nullptr; @@ -417,6 +421,7 @@ struct clip_ctx { } if (!backend) { backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr); + backend = backend ? backend : ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU, nullptr); } } @@ -500,11 +505,17 @@ struct clip_graph { ggml_cgraph * build_siglip() { ggml_tensor * inp = build_inp(); + + ggml_tensor * learned_pos_embd = model.position_embeddings; + if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) { + learned_pos_embd = resize_position_embeddings(); + } + ggml_tensor * cur = build_vit( inp, n_patches, NORM_TYPE_NORMAL, hparams.ffn_op, - model.position_embeddings, + learned_pos_embd, nullptr); if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) { @@ -513,8 +524,8 @@ struct clip_graph { const int patches_per_image = n_patches_x; const int kernel_size = hparams.proj_scale_factor; - cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); - cur = ggml_reshape_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size); + cur = ggml_transpose(ctx0, cur); + cur = ggml_cont_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size); // doing a pool2d to reduce the number of output tokens cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0); @@ -531,29 +542,27 @@ struct clip_graph { cur); } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) { + // pixel_shuffle // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 - const int scale_factor = model.hparams.proj_scale_factor; - const int n_embd = cur->ne[0]; - const int seq = cur->ne[1]; - const int bsz = 1; // batch size, always 1 for now since we don't support batching - const int height = std::sqrt(seq); - const int width = std::sqrt(seq); - GGML_ASSERT(scale_factor != 0); - cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), - n_embd * scale_factor * scale_factor, - height / scale_factor, - width / scale_factor, - bsz); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur), - n_embd * scale_factor * scale_factor, - seq / (scale_factor * scale_factor), - bsz); - + cur = build_patch_merge_permute(cur, scale_factor); cur = ggml_mul_mat(ctx0, model.projection, cur); + + } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) { + // pixel unshuffle block + const int scale_factor = model.hparams.proj_scale_factor; + cur = build_patch_merge_permute(cur, scale_factor); + + // projection + cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm + cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); + cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + cur = ggml_add(ctx0, cur, model.mm_1_b); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + cur = ggml_add(ctx0, cur, model.mm_2_b); } else { GGML_ABORT("SigLIP: Unsupported projector type"); } @@ -681,15 +690,15 @@ struct clip_graph { auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); inp = ggml_add(ctx0, inp, inp_1); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] - inp = ggml_reshape_4d( + inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_cont_4d( ctx0, inp, n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); inp = ggml_reshape_4d( ctx0, inp, n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); - inp = ggml_reshape_3d( + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_3d( ctx0, inp, n_embd, n_patches_x * n_patches_y, batch_size); } @@ -879,21 +888,8 @@ struct clip_graph { int n_embd = clip_n_mmproj_embd(ctx); const int d_head = 128; int n_head = n_embd/d_head; - int num_query = 96; - if (ctx->model.hparams.minicpmv_version == 2) { - // MiniCPM-V 2.5 - num_query = 96; - } else if (ctx->model.hparams.minicpmv_version == 3) { - // MiniCPM-V 2.6 - num_query = 64; - } else if (ctx->model.hparams.minicpmv_version == 4) { - // MiniCPM-o 2.6 - num_query = 64; - } else if (ctx->model.hparams.minicpmv_version == 5) { - // MiniCPM-V 4.0 - num_query = 64; - } - + // Use actual config value if available, otherwise fall back to hardcoded values + int num_query = ctx->model.hparams.minicpmv_query_num; ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); @@ -967,14 +963,14 @@ struct clip_graph { GGML_ASSERT(scale_factor > 0); cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // flatten to 2D - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_2d(ctx0, cur, n_embd * scale_factor * scale_factor, cur->ne[1] * cur->ne[2]); } @@ -1060,14 +1056,14 @@ struct clip_graph { n_patches_y, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, n_patches_x / scale_factor, n_patches_y / scale_factor, bsz); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + //cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // flatten to 2D - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_2d(ctx0, cur, n_embd * scale_factor * scale_factor, n_patches / scale_factor / scale_factor); cb(cur, "pixel_shuffle", -1); @@ -1092,6 +1088,67 @@ struct clip_graph { return gf; } + ggml_cgraph * build_kimivl() { + // 2D input positions + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + ggml_tensor * learned_pos_embd = resize_position_embeddings(); + + // build ViT with 2D position embeddings + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + // first half is X axis and second half is Y axis + return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); + }; + + ggml_tensor * inp = build_inp(); + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_NORMAL, + hparams.ffn_op, + learned_pos_embd, + add_pos); + + cb(cur, "vit_out", -1); + + { + // patch_merger + const int scale_factor = model.hparams.proj_scale_factor; + cur = build_patch_merge_permute(cur, scale_factor); + + // projection norm + int proj_inp_dim = cur->ne[0]; + cur = ggml_view_2d(ctx0, cur, + n_embd, cur->ne[1] * scale_factor * scale_factor, + ggml_row_size(cur->type, n_embd), 0); + cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm + cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); + cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + cur = ggml_view_2d(ctx0, cur, + proj_inp_dim, cur->ne[1] / scale_factor / scale_factor, + ggml_row_size(cur->type, proj_inp_dim), 0); + cb(cur, "proj_inp_normed", -1); + + // projection mlp + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + cur = ggml_add(ctx0, cur, model.mm_1_b); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + cur = ggml_add(ctx0, cur, model.mm_2_b); + cb(cur, "proj_out", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; + } + // this graph is used by llava, granite and glm // due to having embedding_stack (used by granite), we cannot reuse build_vit ggml_cgraph * build_llava() { @@ -1300,8 +1357,8 @@ struct clip_graph { ggml_tensor * block_1 = nullptr; { // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24] - mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3)); - mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]); + mlp_3 = ggml_permute(ctx0, mlp_3, 1, 0, 2, 3); + mlp_3 = ggml_cont_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]); // stride = 1, padding = 1, bias is nullptr block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1); @@ -1406,9 +1463,9 @@ struct clip_graph { mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b); // mlp_2 ne = [2048, 576, 1, 1] // // AVG Pool Layer 2*2, strides = 2 - mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3)); + mlp_2 = ggml_permute(ctx0, mlp_2, 1, 0, 2, 3); // mlp_2 ne = [576, 2048, 1, 1] - mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); + mlp_2 = ggml_cont_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); // mlp_2 ne [24, 24, 2048, 1] mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); // weight ne = [3, 3, 2048, 1] @@ -1428,8 +1485,8 @@ struct clip_graph { // glm projector else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) { size_t gridsz = (size_t)sqrt(embeddings->ne[1]); - embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); - embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); + embeddings = ggml_permute(ctx0,embeddings,1,0,2,3); + embeddings = ggml_cont_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); @@ -1585,6 +1642,29 @@ private: } } + // siglip2 naflex + ggml_tensor * resize_position_embeddings() { + ggml_tensor * pos_embd = model.position_embeddings; + const int height = img.ny / patch_size; + const int width = img.nx / patch_size; + const uint32_t mode = GGML_SCALE_MODE_BILINEAR; + const int n_per_side = (int)std::sqrt(pos_embd->ne[1]); + + GGML_ASSERT(pos_embd); + + if (height == n_per_side && width == n_per_side) { + return pos_embd; + } + + pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side); // -> (n_embd, n_per_side, n_per_side) + pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_per_side, n_per_side, n_embd) + pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd) + pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); // -> (n_embd, width, height) + pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); // -> (n_embd, width * height) + + return pos_embd; + } + // build vision transformer (ViT) cgraph // this function should cover most of the models // if your model has specific features, you should probably duplicate this function @@ -1963,7 +2043,6 @@ private: ggml_row_size(cur->type, n_dim), ggml_row_size(cur->type, n_dim*n_head), n_dim/2 * ggml_element_size(cur)); - second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors second = ggml_rope_ext( ctx0, second, @@ -1980,6 +2059,39 @@ private: return cur; } + // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL) + // support dynamic resolution + ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor) { + GGML_ASSERT(scale_factor > 1); + + const int n_embd = cur->ne[0]; + int width = img.nx / patch_size; + int height = img.ny / patch_size; + + // pad width and height to factor + const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width; + const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height; + cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height); + if (pad_width || pad_height) { + cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0); + width += pad_width; + height += pad_height; + } + + // unshuffle h + cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + + // unshuffle w + cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + + cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); + cb(cur, "pixel_shuffle", -1); + + return cur; + } + }; static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) { @@ -1991,6 +2103,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 switch (ctx->proj_type()) { case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: + case PROJECTOR_TYPE_LFM2: { res = graph.build_siglip(); } break; @@ -2021,6 +2134,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_whisper_enc(); } break; + case PROJECTOR_TYPE_KIMIVL: + { + res = graph.build_kimivl(); + } break; default: { res = graph.build_llava(); @@ -2148,10 +2265,25 @@ struct clip_model_loader { if (is_vision) { get_u32(KEY_IMAGE_SIZE, hparams.image_size); + get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.preproc_image_size, false); get_u32(KEY_PATCH_SIZE, hparams.patch_size); get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy - + get_u32(KEY_MINICPMV_QUERY_NUM, hparams.minicpmv_query_num, false); + if (hparams.minicpmv_query_num == 0) { + // Fallback to hardcoded values for legacy models + if (hparams.minicpmv_version == 3) { + hparams.minicpmv_query_num = 64; + } else if (hparams.minicpmv_version == 4) { + hparams.minicpmv_query_num = 64; + } else if (hparams.minicpmv_version == 5) { + hparams.minicpmv_query_num = 64; + } else if (hparams.minicpmv_version == 6) { + hparams.minicpmv_query_num = 64; + } else { + hparams.minicpmv_query_num = 96; + } + } } else if (is_audio) { get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins); @@ -2243,6 +2375,7 @@ struct clip_model_loader { } } break; case PROJECTOR_TYPE_IDEFICS3: + case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_INTERNVL: { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); @@ -2256,6 +2389,12 @@ struct clip_model_loader { hparams.image_size = 1024; get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false); } break; + case PROJECTOR_TYPE_KIMIVL: + { + hparams.rope_theta = 10000.0f; + hparams.warmup_image_size = hparams.patch_size * 8; + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); + } break; case PROJECTOR_TYPE_GEMMA3: { // default value (used by all model sizes in gemma 3 family) @@ -2420,7 +2559,20 @@ struct clip_model_loader { // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check! - if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) { + bool is_ffn_swapped = ( + // only old models need this fix + model.proj_type == PROJECTOR_TYPE_MLP + || model.proj_type == PROJECTOR_TYPE_MLP_NORM + || model.proj_type == PROJECTOR_TYPE_LDP + || model.proj_type == PROJECTOR_TYPE_LDPV2 + || model.proj_type == PROJECTOR_TYPE_QWEN2VL + || model.proj_type == PROJECTOR_TYPE_QWEN25VL + || model.proj_type == PROJECTOR_TYPE_GLM_EDGE + || model.proj_type == PROJECTOR_TYPE_GEMMA3 + || model.proj_type == PROJECTOR_TYPE_IDEFICS3 + || model.proj_type == PROJECTOR_TYPE_MINICPMV + ) && layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd; + if (is_ffn_swapped) { // swap up and down weights ggml_tensor * tmp = layer.ff_up_w; layer.ff_up_w = layer.ff_down_w; @@ -2429,6 +2581,9 @@ struct clip_model_loader { tmp = layer.ff_up_b; layer.ff_up_b = layer.ff_down_b; layer.ff_down_b = tmp; + if (il == 0) { + LOG_WRN("%s: ffn up/down are swapped\n", __func__); + } } } @@ -2546,6 +2701,16 @@ struct clip_model_loader { { model.projection = get_tensor(TN_MM_PROJECTOR); } break; + case PROJECTOR_TYPE_LFM2: + case PROJECTOR_TYPE_KIMIVL: + { + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); + model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; case PROJECTOR_TYPE_PIXTRAL: { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); @@ -2944,7 +3109,7 @@ struct image_manipulation { dst.buf.resize(3 * target_width * target_height); float Cc; - float C[5]; + float C[5] = {}; float d0, d2, d3, a0, a1, a2, a3; int i, j, k, jj; int x, y; @@ -3428,10 +3593,51 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); return true; - } - else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE + } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) { + // The refined size has two steps: + // 1. Resize w/ aspect-ratio preserving such that the longer side is + // the preprocessor longest size + // 2. Resize w/out preserving aspect ratio such that both sides are + // multiples of image_size (always rounding up) + // + // CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737 + const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio( + original_size, params.image_size, params.preproc_image_size); + + llava_uhd::slice_instructions instructions; + instructions.overview_size = clip_image_size{params.image_size, params.image_size}; + instructions.refined_size = refined_size; + instructions.grid_size = clip_image_size{ + static_cast(std::ceil(static_cast(refined_size.width) / params.image_size)), + static_cast(std::ceil(static_cast(refined_size.height) / params.image_size)), + }; + for (int y = 0; y < refined_size.height; y += params.image_size) { + for (int x = 0; x < refined_size.width; x += params.image_size) { + instructions.slices.push_back(llava_uhd::slice_coordinates{ + /* x */x, + /* y */y, + /* size */clip_image_size{ + std::min(params.image_size, refined_size.width - x), + std::min(params.image_size, refined_size.height - y) + } + }); + } + } + auto imgs = llava_uhd::slice_image(img, instructions); + + // cast and normalize to f32 + for (size_t i = 0; i < imgs.size(); ++i) { + // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp"); + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } + + res_imgs->grid_x = instructions.grid_size.width; + res_imgs->grid_y = instructions.grid_size.height; + return true; + } else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3 - || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution ) { clip_image_u8 resized_image; @@ -3467,6 +3673,45 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->grid_y = inst.grid_size.height; return true; + } else if ( ctx->proj_type() == PROJECTOR_TYPE_LFM2 + || ctx->proj_type() == PROJECTOR_TYPE_KIMIVL + ) { + GGML_ASSERT(params.proj_scale_factor); + + // smart resize + const int width = img->nx; + const int height = img->ny; + const int total_factor = params.patch_size * params.proj_scale_factor; + constexpr int min_image_tokens = 64; + constexpr int max_image_tokens = 1024; + const float min_pixels = min_image_tokens * total_factor * total_factor; + const float max_pixels = max_image_tokens * total_factor * total_factor; + + auto round_by_factor = [f = total_factor](float x) { return static_cast(std::nearbyintf(x / static_cast(f))) * f; }; + auto ceil_by_factor = [f = total_factor](float x) { return static_cast(std::ceil(x / static_cast(f))) * f; }; + auto floor_by_factor = [f = total_factor](float x) { return static_cast(std::floor(x / static_cast(f))) * f; }; + + int h_bar = std::max(total_factor, round_by_factor(height)); + int w_bar = std::max(total_factor, round_by_factor(width)); + + if (h_bar * w_bar > max_pixels) { + const auto beta = std::sqrt((height * width) / max_pixels); + h_bar = std::max(total_factor, floor_by_factor(height / beta)); + w_bar = std::max(total_factor, floor_by_factor(width / beta)); + } else if (h_bar * w_bar < min_pixels) { + const auto beta = std::sqrt(min_pixels / (height * width)); + h_bar = ceil_by_factor(height * beta); + w_bar = ceil_by_factor(width * beta); + } + + const std::array pad_color = {122, 116, 104}; + + clip_image_u8 resized_img; + image_manipulation::resize_and_pad_image(*img, resized_img, clip_image_size{w_bar, h_bar}, pad_color); + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + return true; } // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) @@ -3506,10 +3751,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } return true; - + } else { + GGML_ABORT("Unknown image preprocessing type"); } - GGML_ASSERT(false && "Unknown image preprocessing type"); } ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) { @@ -3573,8 +3818,9 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->model.hparams; - // only for models using fixed size square images - int n_patches_sq = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); + // for models with fixed size image, the input image is already pre-processed and resized to square + int patch_size = params.patch_size; + int n_patches = (img->nx / patch_size) * (img->ny / patch_size); projector_type proj = ctx->proj_type(); @@ -3588,89 +3834,97 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_LDPV2: case PROJECTOR_TYPE_GLM_EDGE: { - n_patches_sq /= 4; + n_patches /= 4; if (ctx->model.mm_glm_tok_boi) { - n_patches_sq += 2; // for BOI and EOI token embeddings + n_patches += 2; // for BOI and EOI token embeddings } } break; case PROJECTOR_TYPE_MINICPMV: { - if (params.minicpmv_version == 2) { - // MiniCPM-V 2.5 - n_patches_sq = 96; - } else if (params.minicpmv_version == 3) { - // MiniCPM-V 2.6 - n_patches_sq = 64; - } else if (params.minicpmv_version == 4) { - // MiniCPM-o 2.6 - n_patches_sq = 64; - } else if (params.minicpmv_version == 5) { - // MiniCPM-V 4.0 - n_patches_sq = 64; + // Use actual config value if available, otherwise fall back to hardcoded values + if (params.minicpmv_query_num > 0) { + n_patches = params.minicpmv_query_num; } else { - GGML_ABORT("Unknown minicpmv version"); + // Fallback to hardcoded values for legacy models + if (params.minicpmv_version == 2) { + n_patches = 96; + } else if (params.minicpmv_version == 3) { + n_patches = 64; + } else if (params.minicpmv_version == 4) { + n_patches = 64; + } else if (params.minicpmv_version == 5) { + // MiniCPM-V 4.0 + n_patches = 64; + } else if (params.minicpmv_version == 6) { + // MiniCPM-V 4.5 + n_patches = 64; + } else { + GGML_ABORT("Unknown minicpmv version"); + } } } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: { - // dynamic size + // dynamic size (2 conv, so double patch size) int patch_size = params.patch_size * 2; int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); - n_patches_sq = x_patch * y_patch; + n_patches = x_patch * y_patch; } break; case PROJECTOR_TYPE_GEMMA3: - { - int n_per_side = params.image_size / params.patch_size; - int n_per_side_2d_pool = n_per_side / params.proj_scale_factor; - n_patches_sq = n_per_side_2d_pool * n_per_side_2d_pool; - } break; case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_LLAMA4: { - // both W and H are divided by proj_scale_factor - n_patches_sq /= (params.proj_scale_factor * params.proj_scale_factor); + // both X and Y are downscaled by the scale factor + int scale_factor = ctx->model.hparams.proj_scale_factor; + n_patches /= (scale_factor * scale_factor); + } break; + case PROJECTOR_TYPE_LFM2: + case PROJECTOR_TYPE_KIMIVL: + { + // dynamic size + int scale_factor = ctx->model.hparams.proj_scale_factor; + int out_patch_size = params.patch_size * scale_factor; + int x_patch = CLIP_ALIGN(img->nx, out_patch_size) / out_patch_size; + int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size; + n_patches = x_patch * y_patch; } break; case PROJECTOR_TYPE_PIXTRAL: { // dynamic size int n_merge = params.spatial_merge_size; - int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1); - int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1); - n_patches_sq = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row - } break; - case PROJECTOR_TYPE_LLAMA4: - { - int scale_factor = ctx->model.hparams.proj_scale_factor; - n_patches_sq /= (scale_factor * scale_factor); + int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1); + int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1); + n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row } break; case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: { - n_patches_sq = img->nx; + n_patches = img->nx; const int proj_stack_factor = ctx->model.hparams.proj_stack_factor; if (ctx->model.audio_has_stack_frames()) { GGML_ASSERT(proj_stack_factor > 0); - const int n_len = CLIP_ALIGN(n_patches_sq, proj_stack_factor); - n_patches_sq = n_len / proj_stack_factor; + const int n_len = CLIP_ALIGN(n_patches, proj_stack_factor); + n_patches = n_len / proj_stack_factor; } // whisper downscales input token by half after conv1d - n_patches_sq /= 2; + n_patches /= 2; if (ctx->model.audio_has_avgpool()) { // divide by 2 because of nn.AvgPool1d(2, stride=2) - n_patches_sq /= 2; + n_patches /= 2; } } break; default: GGML_ABORT("unsupported projector type"); } - return n_patches_sq; + return n_patches; } static std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector> & pos) { @@ -4019,6 +4273,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("positions", positions); } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_KIMIVL: { // set the 2D positions int n_patches_per_col = image_size_width / patch_size; @@ -4070,6 +4325,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_VOXTRAL: { // do nothing @@ -4141,7 +4397,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } int clip_n_mmproj_embd(const struct clip_ctx * ctx) { - const auto & hparams = ctx->model.hparams; switch (ctx->model.proj_type) { case PROJECTOR_TYPE_LDP: return ctx->model.mm_model_block_1_block_2_1_b->ne[0]; @@ -4153,20 +4408,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_MLP_NORM: return ctx->model.mm_3_b->ne[0]; case PROJECTOR_TYPE_MINICPMV: - if (hparams.minicpmv_version == 2) { - // MiniCPM-V 2.5 - return 4096; - } else if (hparams.minicpmv_version == 3) { - // MiniCPM-V 2.6 - return 3584; - } else if (hparams.minicpmv_version == 4) { - // MiniCPM-o 2.6 - return 3584; - } else if (hparams.minicpmv_version == 5) { - // MiniCPM-V 4.0 - return 2560; - } - GGML_ABORT("Unknown minicpmv version"); + return ctx->model.mm_model_proj->ne[0]; case PROJECTOR_TYPE_GLM_EDGE: return ctx->model.mm_model_mlp_3_w->ne[1]; case PROJECTOR_TYPE_QWEN2VL: @@ -4185,6 +4427,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_QWEN2A: return ctx->model.mm_fc_w->ne[1]; + case PROJECTOR_TYPE_LFM2: + case PROJECTOR_TYPE_KIMIVL: + return ctx->model.mm_2_w->ne[1]; default: GGML_ABORT("Unknown projector type"); } diff --git a/llama/llama.cpp/tools/mtmd/clip.h b/llama/llama.cpp/tools/mtmd/clip.h index 08f3efb7..3387cdbd 100644 --- a/llama/llama.cpp/tools/mtmd/clip.h +++ b/llama/llama.cpp/tools/mtmd/clip.h @@ -82,11 +82,6 @@ struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch */ void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img); -bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); - -/** interpret bytes as an image file with length bytes_length, and use the result to populate img */ -bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); - /** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */ bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs ); diff --git a/llama/llama.cpp/tools/mtmd/mtmd.cpp b/llama/llama.cpp/tools/mtmd/mtmd.cpp index 6f70f7f4..35a0d25e 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd.cpp +++ b/llama/llama.cpp/tools/mtmd/mtmd.cpp @@ -76,7 +76,7 @@ enum mtmd_slice_tmpl { MTMD_SLICE_TMPL_MINICPMV_2_5, MTMD_SLICE_TMPL_MINICPMV_2_6, MTMD_SLICE_TMPL_LLAMA4, - // TODO @ngxson : add support for idefics (SmolVLM) + MTMD_SLICE_TMPL_IDEFICS3, }; mtmd_input_text* mtmd_input_text_init(const char * text, bool add_special, bool parse_special) { @@ -124,19 +124,22 @@ struct mtmd_context { // for llava-uhd style models, we need special tokens in-between slices // minicpmv calls them "slices", llama 4 calls them "tiles" mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE; - llama_token tok_ov_img_start = LLAMA_TOKEN_NULL; // overview image - llama_token tok_ov_img_end = LLAMA_TOKEN_NULL; // overview image - llama_token tok_slices_start = LLAMA_TOKEN_NULL; // start of all slices - llama_token tok_slices_end = LLAMA_TOKEN_NULL; // end of all slices - llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice start - llama_token tok_sli_img_end = LLAMA_TOKEN_NULL; // single slice end - llama_token tok_sli_img_mid = LLAMA_TOKEN_NULL; // between 2 slices - llama_token tok_row_end = LLAMA_TOKEN_NULL; // end of row + std::vector tok_ov_img_start; // overview image + std::vector tok_ov_img_end; // overview image + std::vector tok_slices_start; // start of all slices + std::vector tok_slices_end; // end of all slices + std::vector tok_sli_img_start; // single slice start + std::vector tok_sli_img_end; // single slice end + std::vector tok_sli_img_mid; // between 2 slices + std::vector tok_row_end; // end of row bool tok_row_end_trail = false; bool ov_img_first = false; bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE + // string template for slice image delimiters with row/col (idefics3) + std::string sli_img_start_tmpl; + // for whisper, we pre-calculate the mel filter bank whisper_preprocessor::whisper_filters w_filters; @@ -207,25 +210,25 @@ struct mtmd_context { // minicpmv 2.5 format: // (overview) (slice) (slice) \n ... slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_5; - tok_ov_img_start = lookup_token(""); - tok_ov_img_end = lookup_token(""); - tok_slices_start = lookup_token(""); - tok_slices_end = lookup_token(""); + tok_ov_img_start = {lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_slices_start = {lookup_token("")}; + tok_slices_end = {lookup_token("")}; tok_sli_img_start = tok_ov_img_start; tok_sli_img_end = tok_ov_img_end; - tok_row_end = lookup_token("\n"); + tok_row_end = {lookup_token("\n")}; tok_row_end_trail = false; // no trailing end-of-row token ov_img_first = true; - } else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5) { + } else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5 || minicpmv_version == 6) { // minicpmv 2.6 format: // (overview) (slice) (slice) \n ... slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6; - tok_ov_img_start = lookup_token(""); - tok_ov_img_end = lookup_token(""); - tok_sli_img_start = lookup_token(""); - tok_sli_img_end = lookup_token(""); - tok_row_end = lookup_token("\n"); + tok_ov_img_start = {lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_sli_img_start = {lookup_token("")}; + tok_sli_img_end = {lookup_token("")}; + tok_row_end = {lookup_token("\n")}; tok_row_end_trail = false; // no trailing end-of-row token ov_img_first = true; @@ -240,9 +243,9 @@ struct mtmd_context { // <|image|> (overview) <-- overview image is last // <|image_end|> slice_tmpl = MTMD_SLICE_TMPL_LLAMA4; - tok_ov_img_start = lookup_token("<|image|>"); - tok_sli_img_mid = lookup_token("<|tile_x_separator|>"); - tok_row_end = lookup_token("<|tile_y_separator|>"); + tok_ov_img_start = {lookup_token("<|image|>")}; + tok_sli_img_mid = {lookup_token("<|tile_x_separator|>")}; + tok_row_end = {lookup_token("<|tile_y_separator|>")}; tok_row_end_trail = true; // add trailing end-of-row token ov_img_first = false; // overview image is last } @@ -255,8 +258,11 @@ struct mtmd_context { } else if (proj == PROJECTOR_TYPE_IDEFICS3) { // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215 - img_beg = ""; - img_end = ""; + slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3; + tok_ov_img_start = {lookup_token("\n\n"), lookup_token(""), lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_row_end = {lookup_token("\n")}; + sli_img_start_tmpl = ""; } else if (proj == PROJECTOR_TYPE_PIXTRAL) { // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md @@ -514,6 +520,7 @@ struct mtmd_tokenizer { ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6 || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 + || ctx->slice_tmpl == MTMD_SLICE_TMPL_IDEFICS3 ) { const int n_col = batch_f32.grid_x; const int n_row = batch_f32.grid_y; @@ -527,53 +534,45 @@ struct mtmd_tokenizer { // add overview image (first) if (ctx->ov_img_first) { - if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_start}); - } + add_text(ctx->tok_ov_img_start); cur.entries.emplace_back(std::move(ov_chunk)); - if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_end}); - } + add_text(ctx->tok_ov_img_end); } // add slices (or tiles) if (!chunks.empty()) { GGML_ASSERT((int)chunks.size() == n_row * n_col); - if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_slices_start}); - } + add_text(ctx->tok_slices_start); for (int y = 0; y < n_row; y++) { for (int x = 0; x < n_col; x++) { const bool is_last_in_row = (x == n_col - 1); - if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_start}); + if (!ctx->tok_sli_img_start.empty()) { + add_text(ctx->tok_sli_img_start); + } else if (!ctx->sli_img_start_tmpl.empty()) { + // If using a template to preceed a slice image + const size_t sz = std::snprintf(nullptr, 0, ctx->sli_img_start_tmpl.c_str(), y+1, x+1) + 1; + std::unique_ptr buf(new char[sz]); + std::snprintf(buf.get(), sz, ctx->sli_img_start_tmpl.c_str(), y+1, x+1); + add_text(std::string(buf.get(), buf.get() + sz - 1), true); } cur.entries.emplace_back(std::move(chunks[y * n_col + x])); - if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_end}); - } - if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_mid}); + add_text(ctx->tok_sli_img_end); + if (!is_last_in_row) { + add_text(ctx->tok_sli_img_mid); } } - if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_row_end}); + if ((y != n_row - 1 || ctx->tok_row_end_trail)) { + add_text(ctx->tok_row_end); } } - if (ctx->tok_slices_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_slices_end}); - } + add_text(ctx->tok_slices_end); } // add overview image (last) if (!ctx->ov_img_first) { - if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_start}); - } + add_text(ctx->tok_ov_img_start); cur.entries.emplace_back(std::move(ov_chunk)); - if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_end}); - } + add_text(ctx->tok_ov_img_end); } } else { @@ -790,7 +789,9 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); bool ok = false; - if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) || clip_is_glm(ctx_clip)) { + if (clip_is_llava(ctx_clip) + || clip_is_minicpmv(ctx_clip) + || clip_is_glm(ctx_clip)) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() const auto & entries = image_tokens->batch_f32.entries; for (size_t i = 0; i < entries.size(); i++) { diff --git a/llama/llama.cpp/vendor/miniaudio/miniaudio.h b/llama/llama.cpp/vendor/miniaudio/miniaudio.h index c74bebeb..2f5b9c4e 100644 --- a/llama/llama.cpp/vendor/miniaudio/miniaudio.h +++ b/llama/llama.cpp/vendor/miniaudio/miniaudio.h @@ -1,6 +1,6 @@ /* Audio playback and capture library. Choice of public domain or MIT-0. See license statements at the end of this file. -miniaudio - v0.11.22 - 2025-02-24 +miniaudio - v0.11.24 - TBD David Reid - mackron@gmail.com @@ -12,18 +12,10 @@ GitHub: https://github.com/mackron/miniaudio /* 1. Introduction =============== -To use miniaudio, include "miniaudio.h": - - ```c - #include "miniaudio.h" - ``` - -The implementation is contained in "miniaudio.c". Just compile this like any other source file. You -can include miniaudio.c if you want to compile your project as a single translation unit: - - ```c - #include "miniaudio.c" - ``` +To use miniaudio, just include "miniaudio.h" like any other header and add "miniaudio.c" to your +source tree. If you don't want to add it to your source tree you can compile and link to it like +any other library. Note that ABI compatibility is not guaranteed between versions, even with bug +fix releases, so take care if compiling as a shared object. miniaudio includes both low level and high level APIs. The low level API is good for those who want to do all of their mixing themselves and only require a light weight interface to the underlying @@ -303,7 +295,7 @@ The engine encapsulates both the resource manager and the node graph to create a use high level API. The resource manager and node graph APIs are covered in more later sections of this manual. -The code below shows how you can initialize an engine using it's default configuration. +The code below shows how you can initialize an engine using its default configuration. ```c ma_result result; @@ -391,7 +383,7 @@ Sounds are not started by default. Start a sound with `ma_sound_start()` and sto `ma_sound_stop()`. When a sound is stopped, it is not rewound to the start. Use `ma_sound_seek_to_pcm_frame(&sound, 0)` to seek back to the start of a sound. By default, starting and stopping sounds happens immediately, but sometimes it might be convenient to schedule the sound -the be started and/or stopped at a specific time. This can be done with the following functions: +to be started and/or stopped at a specific time. This can be done with the following functions: ```c ma_sound_set_start_time_in_pcm_frames() @@ -463,6 +455,11 @@ is at the end, use `ma_sound_at_end()`. Looping of a sound can be controlled wit miniaudio should work cleanly out of the box without the need to download or install any dependencies. See below for platform-specific details. +This library has been designed to be added directly to your source tree which is the preferred way +of using it, but you can compile it as a normal library if that's your preference. Be careful if +compiling as a shared object because miniaudio is not ABI compatible between any release, including +bug fix releases. It's recommended you link statically. + Note that GCC and Clang require `-msse2`, `-mavx2`, etc. for SIMD optimizations. If you get errors about undefined references to `__sync_val_compare_and_swap_8`, `__atomic_load_8`, @@ -532,7 +529,7 @@ you'll need to disable run-time linking with `MA_NO_RUNTIME_LINKING` and link wi The Emscripten build emits Web Audio JavaScript directly and should compile cleanly out of the box. You cannot use `-std=c*` compiler flags, nor `-ansi`. -You can enable the use of AudioWorkets by defining `MA_ENABLE_AUDIO_WORKLETS` and then compiling +You can enable the use of AudioWorklets by defining `MA_ENABLE_AUDIO_WORKLETS` and then compiling with the following options: -sAUDIO_WORKLET=1 -sWASM_WORKERS=1 -sASYNCIFY @@ -881,7 +878,7 @@ read data within a certain range of the underlying data. To do this you can use This is useful if you have a sound bank where many sounds are stored in the same file and you want the data source to only play one of those sub-sounds. Note that once the range is set, everything -that takes a position, such as cursors and loop points, should always be relatvie to the start of +that takes a position, such as cursors and loop points, should always be relative to the start of the range. When the range is set, any previously defined loop point will be reset. Custom loop points can also be used with data sources. By default, data sources will loop after @@ -889,7 +886,7 @@ they reach the end of the data source, but if you need to loop at a specific loc the following: ```c - result = ma_data_set_loop_point_in_pcm_frames(pDataSource, loopBegInFrames, loopEndInFrames); + result = ma_data_source_set_loop_point_in_pcm_frames(pDataSource, loopBegInFrames, loopEndInFrames); if (result != MA_SUCCESS) { return result; // Failed to set the loop point. } @@ -3750,7 +3747,7 @@ extern "C" { #define MA_VERSION_MAJOR 0 #define MA_VERSION_MINOR 11 -#define MA_VERSION_REVISION 22 +#define MA_VERSION_REVISION 24 #define MA_VERSION_STRING MA_XSTRINGIFY(MA_VERSION_MAJOR) "." MA_XSTRINGIFY(MA_VERSION_MINOR) "." MA_XSTRINGIFY(MA_VERSION_REVISION) #if defined(_MSC_VER) && !defined(__clang__) @@ -3857,6 +3854,8 @@ typedef ma_uint16 wchar_t; #define MA_SIZE_MAX 0xFFFFFFFF /* When SIZE_MAX is not defined by the standard library just default to the maximum 32-bit unsigned integer. */ #endif +#define MA_UINT64_MAX (((ma_uint64)0xFFFFFFFF << 32) | (ma_uint64)0xFFFFFFFF) /* Weird shifting syntax is for VC6 compatibility. */ + /* Platform/backend detection. */ #if defined(_WIN32) || defined(__COSMOPOLITAN__) @@ -3865,29 +3864,55 @@ typedef ma_uint16 wchar_t; #define MA_WIN32_UWP #elif defined(WINAPI_FAMILY) && (defined(WINAPI_FAMILY_GAMES) && WINAPI_FAMILY == WINAPI_FAMILY_GAMES) #define MA_WIN32_GDK + #elif defined(NXDK) + #define MA_WIN32_NXDK #else #define MA_WIN32_DESKTOP #endif + + /* The original Xbox. */ + #if defined(NXDK) /* <-- Add other Xbox compiler toolchains here, and then add a toolchain-specific define in case we need to discriminate between them later. */ + #define MA_XBOX + + #if defined(NXDK) + #define MA_XBOX_NXDK + #endif + #endif #endif -#if !defined(_WIN32) /* If it's not Win32, assume POSIX. */ +#if defined(__MSDOS__) || defined(MSDOS) || defined(_MSDOS) || defined(__DOS__) + #define MA_DOS + + /* No threading allowed on DOS. */ + #ifndef MA_NO_THREADING + #define MA_NO_THREADING + #endif + + /* No runtime linking allowed on DOS. */ + #ifndef MA_NO_RUNTIME_LINKING + #define MA_NO_RUNTIME_LINKING + #endif +#endif +#if !defined(MA_WIN32) && !defined(MA_DOS) /* If it's not Win32, assume POSIX. */ #define MA_POSIX - /* - Use the MA_NO_PTHREAD_IN_HEADER option at your own risk. This is intentionally undocumented. - You can use this to avoid including pthread.h in the header section. The downside is that it - results in some fixed sized structures being declared for the various types that are used in - miniaudio. The risk here is that these types might be too small for a given platform. This - risk is yours to take and no support will be offered if you enable this option. - */ - #ifndef MA_NO_PTHREAD_IN_HEADER - #include /* Unfortunate #include, but needed for pthread_t, pthread_mutex_t and pthread_cond_t types. */ - typedef pthread_t ma_pthread_t; - typedef pthread_mutex_t ma_pthread_mutex_t; - typedef pthread_cond_t ma_pthread_cond_t; - #else - typedef ma_uintptr ma_pthread_t; - typedef union ma_pthread_mutex_t { char __data[40]; ma_uint64 __alignment; } ma_pthread_mutex_t; - typedef union ma_pthread_cond_t { char __data[48]; ma_uint64 __alignment; } ma_pthread_cond_t; + #if !defined(MA_NO_THREADING) + /* + Use the MA_NO_PTHREAD_IN_HEADER option at your own risk. This is intentionally undocumented. + You can use this to avoid including pthread.h in the header section. The downside is that it + results in some fixed sized structures being declared for the various types that are used in + miniaudio. The risk here is that these types might be too small for a given platform. This + risk is yours to take and no support will be offered if you enable this option. + */ + #ifndef MA_NO_PTHREAD_IN_HEADER + #include /* Unfortunate #include, but needed for pthread_t, pthread_mutex_t and pthread_cond_t types. */ + typedef pthread_t ma_pthread_t; + typedef pthread_mutex_t ma_pthread_mutex_t; + typedef pthread_cond_t ma_pthread_cond_t; + #else + typedef ma_uintptr ma_pthread_t; + typedef union ma_pthread_mutex_t { char __data[40]; ma_uint64 __alignment; } ma_pthread_mutex_t; + typedef union ma_pthread_cond_t { char __data[48]; ma_uint64 __alignment; } ma_pthread_cond_t; + #endif #endif #if defined(__unix__) @@ -3914,8 +3939,11 @@ typedef ma_uint16 wchar_t; #if defined(__PROSPERO__) #define MA_PROSPERO #endif - #if defined(__NX__) - #define MA_NX + #if defined(__3DS__) + #define MA_3DS + #endif + #if defined(__SWITCH__) || defined(__NX__) + #define MA_SWITCH #endif #if defined(__BEOS__) || defined(__HAIKU__) #define MA_BEOS @@ -3925,12 +3953,13 @@ typedef ma_uint16 wchar_t; #endif #endif -#if defined(__has_c_attribute) - #if __has_c_attribute(fallthrough) - #define MA_FALLTHROUGH [[fallthrough]] - #endif +#if !defined(MA_FALLTHROUGH) && defined(__cplusplus) && __cplusplus >= 201703L + #define MA_FALLTHROUGH [[fallthrough]] #endif -#if !defined(MA_FALLTHROUGH) && defined(__has_attribute) && (defined(__clang__) || defined(__GNUC__)) +#if !defined(MA_FALLTHROUGH) && defined(__STDC_VERSION__) && __STDC_VERSION__ >= 202000L + #define MA_FALLTHROUGH [[fallthrough]] +#endif +#if !defined(MA_FALLTHROUGH) && defined(__has_attribute) #if __has_attribute(fallthrough) #define MA_FALLTHROUGH __attribute__((fallthrough)) #endif @@ -3967,7 +3996,7 @@ typedef ma_uint16 wchar_t; #define MA_NO_INLINE __attribute__((noinline)) #else #define MA_INLINE MA_GNUC_INLINE_HINT - #define MA_NO_INLINE __attribute__((noinline)) + #define MA_NO_INLINE #endif #elif defined(__WATCOMC__) #define MA_INLINE __inline @@ -4350,7 +4379,7 @@ typedef struct typedef struct { - ma_int32 state; + ma_uint32 state; } ma_lcg; @@ -6569,7 +6598,7 @@ This section contains the APIs for device playback and capture. Here is where yo ************************************************************************************************************************************************************/ #ifndef MA_NO_DEVICE_IO /* Some backends are only supported on certain platforms. */ -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) #define MA_SUPPORT_WASAPI #if defined(MA_WIN32_DESKTOP) /* DirectSound and WinMM backends are only supported on desktops. */ @@ -7426,6 +7455,7 @@ struct ma_context ma_proc snd_pcm_hw_params_set_rate_resample; ma_proc snd_pcm_hw_params_set_rate; ma_proc snd_pcm_hw_params_set_rate_near; + ma_proc snd_pcm_hw_params_set_rate_minmax; ma_proc snd_pcm_hw_params_set_buffer_size_near; ma_proc snd_pcm_hw_params_set_periods_near; ma_proc snd_pcm_hw_params_set_access; @@ -7986,6 +8016,7 @@ struct ma_device /*AAudioStream**/ ma_ptr pStreamPlayback; /*AAudioStream**/ ma_ptr pStreamCapture; ma_mutex rerouteLock; + ma_atomic_bool32 isTearingDown; ma_aaudio_usage usage; ma_aaudio_content_type contentType; ma_aaudio_input_preset inputPreset; @@ -9644,7 +9675,7 @@ Parameters ---------- pBackends (out, optional) A pointer to the buffer that will receive the enabled backends. Set to NULL to retrieve the backend count. Setting - the capacity of the buffer to `MA_BUFFER_COUNT` will guarantee it's large enough for all backends. + the capacity of the buffer to `MA_BACKEND_COUNT` will guarantee it's large enough for all backends. backendCap (in) The capacity of the `pBackends` buffer. @@ -11255,7 +11286,7 @@ typedef struct ma_log* pLog; /* When set to NULL, will use the context's log. */ ma_uint32 listenerCount; /* Must be between 1 and MA_ENGINE_MAX_LISTENERS. */ ma_uint32 channels; /* The number of channels to use when mixing and spatializing. When set to 0, will use the native channel count of the device. */ - ma_uint32 sampleRate; /* The sample rate. When set to 0 will use the native channel count of the device. */ + ma_uint32 sampleRate; /* The sample rate. When set to 0 will use the native sample rate of the device. */ ma_uint32 periodSizeInFrames; /* If set to something other than 0, updates will always be exactly this size. The underlying device may be a different size, but from the perspective of the mixer that won't matter.*/ ma_uint32 periodSizeInMilliseconds; /* Used if periodSizeInFrames is unset. */ ma_uint32 gainSmoothTimeInFrames; /* The number of frames to interpolate the gain of spatialized sounds across. If set to 0, will use gainSmoothTimeInMilliseconds. */ @@ -11419,11 +11450,11 @@ MA_API ma_bool32 ma_sound_is_looping(const ma_sound* pSound); MA_API ma_bool32 ma_sound_at_end(const ma_sound* pSound); MA_API ma_result ma_sound_seek_to_pcm_frame(ma_sound* pSound, ma_uint64 frameIndex); /* Just a wrapper around ma_data_source_seek_to_pcm_frame(). */ MA_API ma_result ma_sound_seek_to_second(ma_sound* pSound, float seekPointInSeconds); /* Abstraction to ma_sound_seek_to_pcm_frame() */ -MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap); -MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* pCursor); -MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* pLength); -MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor); -MA_API ma_result ma_sound_get_length_in_seconds(ma_sound* pSound, float* pLength); +MA_API ma_result ma_sound_get_data_format(const ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap); +MA_API ma_result ma_sound_get_cursor_in_pcm_frames(const ma_sound* pSound, ma_uint64* pCursor); +MA_API ma_result ma_sound_get_length_in_pcm_frames(const ma_sound* pSound, ma_uint64* pLength); +MA_API ma_result ma_sound_get_cursor_in_seconds(const ma_sound* pSound, float* pCursor); +MA_API ma_result ma_sound_get_length_in_seconds(const ma_sound* pSound, float* pLength); MA_API ma_result ma_sound_set_end_callback(ma_sound* pSound, ma_sound_end_proc callback, void* pUserData); MA_API ma_result ma_sound_group_init(ma_engine* pEngine, ma_uint32 flags, ma_sound_group* pParentGroup, ma_sound_group* pGroup); @@ -11544,17 +11575,23 @@ IMPLEMENTATION #endif #if !defined(MA_WIN32) -#include -#include /* select() (used for ma_sleep()). */ -#include + #if !defined(MA_NO_THREADING) + #include + #include /* For pthreads. */ + #endif + + #include /* select() (used for ma_sleep()). */ + #include /* For nanosleep() */ + #include #endif -#ifdef MA_NX -#include /* For nanosleep() */ +/* For fstat(), etc. */ +#if defined(MA_XBOX_NXDK) + #include /* Suggestion for NXDK: Add a sys/stat.h wrapper for compatibility. */ +#else + #include #endif -#include /* For fstat(), etc. */ - #ifdef MA_EMSCRIPTEN #include #endif @@ -11861,7 +11898,7 @@ static MA_INLINE ma_bool32 ma_has_neon(void) #endif #ifndef MA_RESTRICT - #if defined(__clang__) || defined(__GNUC__) || defined(_MSC_VER) + #if defined(__clang__) || defined(_MSC_VER) || (defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ >= 95))) #define MA_RESTRICT __restrict #else #define MA_RESTRICT @@ -11955,7 +11992,7 @@ static void ma_sleep__posix(ma_uint32 milliseconds) (void)milliseconds; MA_ASSERT(MA_FALSE); /* The Emscripten build should never sleep. */ #else - #if (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L) || defined(MA_NX) + #if (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L) || defined(MA_SWITCH) struct timespec ts; ts.tv_sec = milliseconds / 1000; ts.tv_nsec = milliseconds % 1000 * 1000000; @@ -11997,7 +12034,7 @@ static MA_INLINE void ma_yield(void) #endif #endif #else - __asm__ __volatile__ ("pause"); + __asm__ __volatile__ ("rep; nop"); #endif #elif (defined(__arm__) && defined(__ARM_ARCH) && __ARM_ARCH >= 7) || defined(_M_ARM64) || (defined(_M_ARM) && _M_ARM >= 7) || defined(__ARM_ARCH_6K__) || defined(__ARM_ARCH_6T2__) /* ARM */ @@ -12020,7 +12057,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) { unsigned int prevState; - #if defined(_MSC_VER) + #if defined(_MSC_VER) && !defined(MA_XBOX_NXDK) { /* Older versions of Visual Studio don't support the "safe" versions of _controlfp_s(). I don't @@ -12043,7 +12080,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) } #elif defined(MA_X86) || defined(MA_X64) { - #if defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ + #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ { prevState = _mm_getcsr(); _mm_setcsr(prevState | MA_MM_DENORMALS_ZERO_MASK | MA_MM_FLUSH_ZERO_MASK); @@ -12067,7 +12104,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) static MA_INLINE void ma_restore_denormals(unsigned int prevState) { - #if defined(_MSC_VER) + #if defined(_MSC_VER) && !defined(MA_XBOX_NXDK) { /* Older versions of Visual Studio do not support _controlfp_s(). See ma_disable_denormals(). */ #if _MSC_VER <= 1200 @@ -12083,7 +12120,7 @@ static MA_INLINE void ma_restore_denormals(unsigned int prevState) } #elif defined(MA_X86) || defined(MA_X64) { - #if defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ + #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ { _mm_setcsr(prevState); } @@ -12719,6 +12756,29 @@ MA_API MA_NO_INLINE int ma_strcmp(const char* str1, const char* str2) return ((unsigned char*)str1)[0] - ((unsigned char*)str2)[0]; } +MA_API MA_NO_INLINE int ma_wcscmp(const wchar_t* str1, const wchar_t* str2) +{ + if (str1 == str2) return 0; + + /* These checks differ from the standard implementation. It's not important, but I prefer it just for sanity. */ + if (str1 == NULL) return -1; + if (str2 == NULL) return 1; + + for (;;) { + if (str1[0] == L'\0') { + break; + } + if (str1[0] != str2[0]) { + break; + } + + str1 += 1; + str2 += 1; + } + + return ((unsigned short*)str1)[0] - ((unsigned short*)str2)[0]; +} + MA_API MA_NO_INLINE int ma_strappend(char* dst, size_t dstSize, const char* srcA, const char* srcB) { int result; @@ -12736,6 +12796,22 @@ MA_API MA_NO_INLINE int ma_strappend(char* dst, size_t dstSize, const char* srcA return result; } +MA_API MA_NO_INLINE size_t ma_wcslen(const wchar_t* str) +{ + const wchar_t* end; + + if (str == NULL) { + return 0; + } + + end = str; + while (end[0] != '\0') { + end += 1; + } + + return end - str; +} + MA_API MA_NO_INLINE char* ma_copy_string(const char* src, const ma_allocation_callbacks* pAllocationCallbacks) { size_t sz; @@ -12758,7 +12834,7 @@ MA_API MA_NO_INLINE char* ma_copy_string(const char* src, const ma_allocation_ca MA_API MA_NO_INLINE wchar_t* ma_copy_string_w(const wchar_t* src, const ma_allocation_callbacks* pAllocationCallbacks) { - size_t sz = wcslen(src)+1; + size_t sz = ma_wcslen(src)+1; wchar_t* dst = (wchar_t*)ma_malloc(sz * sizeof(*dst), pAllocationCallbacks); if (dst == NULL) { return NULL; @@ -13189,7 +13265,7 @@ MA_API ma_result ma_fopen(FILE** ppFile, const char* pFilePath, const char* pOpe return MA_INVALID_ARGS; } -#if defined(_MSC_VER) && _MSC_VER >= 1400 +#if (defined(_MSC_VER) && _MSC_VER >= 1400) && !defined(MA_XBOX_NXDK) err = fopen_s(ppFile, pFilePath, pOpenMode); if (err != 0) { return ma_result_from_errno(err); @@ -13231,7 +13307,7 @@ _wfopen() isn't always available in all compilation environments. This can be reviewed as compatibility issues arise. The preference is to use _wfopen_s() and _wfopen() as opposed to the wcsrtombs() fallback, so if you notice your compiler not detecting this properly I'm happy to look at adding support. */ -#if defined(_WIN32) +#if defined(_WIN32) && !defined(MA_XBOX_NXDK) #if defined(_MSC_VER) || defined(__MINGW64__) || (!defined(__STRICT_ANSI__) && !defined(_NO_EXT_KEYS)) #define MA_HAS_WFOPEN #endif @@ -13247,29 +13323,34 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ return MA_INVALID_ARGS; } -#if defined(MA_HAS_WFOPEN) + #if defined(MA_HAS_WFOPEN) { /* Use _wfopen() on Windows. */ - #if defined(_MSC_VER) && _MSC_VER >= 1400 - errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); - if (err != 0) { - return ma_result_from_errno(err); + #if defined(_MSC_VER) && _MSC_VER >= 1400 + { + errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); + if (err != 0) { + return ma_result_from_errno(err); + } } - #else - *ppFile = _wfopen(pFilePath, pOpenMode); - if (*ppFile == NULL) { - return ma_result_from_errno(errno); + #else + { + *ppFile = _wfopen(pFilePath, pOpenMode); + if (*ppFile == NULL) { + return ma_result_from_errno(errno); + } } - #endif + #endif + (void)pAllocationCallbacks; } -#else - /* - Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can - think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for - maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. - */ + #elif !defined(MA_XBOX_NXDK) && !defined(MA_DOS) /* If your compiler does not support wcsrtombs(), add it here. */ { + /* + Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can + think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for + maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. + */ mbstate_t mbs; size_t lenMB; const wchar_t* pFilePathTemp = pFilePath; @@ -13310,11 +13391,16 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ ma_free(pFilePathMB, pAllocationCallbacks); } + #else + { + /* Getting here means there is no way to open the file with a wide character string. */ + *ppFile = NULL; + } + #endif if (*ppFile == NULL) { return MA_ERROR; } -#endif return MA_SUCCESS; } @@ -13323,7 +13409,7 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ static MA_INLINE void ma_copy_memory_64(void* dst, const void* src, ma_uint64 sizeInBytes) { -#if 0xFFFFFFFFFFFFFFFF <= MA_SIZE_MAX +#if MA_SIZE_MAX > 0xFFFFFFFF MA_COPY_MEMORY(dst, src, (size_t)sizeInBytes); #else while (sizeInBytes > 0) { @@ -13343,7 +13429,7 @@ static MA_INLINE void ma_copy_memory_64(void* dst, const void* src, ma_uint64 si static MA_INLINE void ma_zero_memory_64(void* dst, ma_uint64 sizeInBytes) { -#if 0xFFFFFFFFFFFFFFFF <= MA_SIZE_MAX +#if MA_SIZE_MAX > 0xFFFFFFFF MA_ZERO_MEMORY(dst, (size_t)sizeInBytes); #else while (sizeInBytes > 0) { @@ -13472,6 +13558,18 @@ static ma_result ma_allocation_callbacks_init_copy(ma_allocation_callbacks* pDst Logging **************************************************************************************************************************************************************/ +#ifndef ma_va_copy + #if !defined(_MSC_VER) || _MSC_VER >= 1800 + #if (defined(__GNUC__) && __GNUC__ < 3) + #define ma_va_copy(dst, src) ((dst) = (src)) /* This is untested. Not sure if this is correct for old GCC. */ + #else + #define ma_va_copy(dst, src) va_copy((dst), (src)) + #endif + #else + #define ma_va_copy(dst, src) ((dst) = (src)) + #endif +#endif + MA_API const char* ma_log_level_to_string(ma_uint32 logLevel) { switch (logLevel) @@ -13712,9 +13810,15 @@ MA_API ma_result ma_log_postv(ma_log* pLog, ma_uint32 level, const char* pFormat int length; char pFormattedMessageStack[1024]; char* pFormattedMessageHeap = NULL; + va_list args2; /* First try formatting into our fixed sized stack allocated buffer. If this is too small we'll fallback to a heap allocation. */ - length = vsnprintf(pFormattedMessageStack, sizeof(pFormattedMessageStack), pFormat, args); + ma_va_copy(args2, args); + { + length = vsnprintf(pFormattedMessageStack, sizeof(pFormattedMessageStack), pFormat, args2); + } + va_end(args2); + if (length < 0) { return MA_INVALID_OPERATION; /* An error occurred when trying to convert the buffer. */ } @@ -13755,17 +13859,10 @@ MA_API ma_result ma_log_postv(ma_log* pLog, ma_uint32 level, const char* pFormat char* pFormattedMessage = NULL; va_list args2; - #if _MSC_VER >= 1800 + ma_va_copy(args2, args); { - va_copy(args2, args); + formattedLen = ma_vscprintf(&pLog->allocationCallbacks, pFormat, args2); } - #else - { - args2 = args; - } - #endif - - formattedLen = ma_vscprintf(&pLog->allocationCallbacks, pFormat, args2); va_end(args2); if (formattedLen <= 0) { @@ -13964,7 +14061,7 @@ miniaudio's purposes. #define MA_LCG_A 48271 #define MA_LCG_C 0 -static ma_lcg g_maLCG = {MA_DEFAULT_LCG_SEED}; /* Non-zero initial seed. Use ma_seed() to use an explicit seed. */ +static ma_lcg g_maLCG = {MA_DEFAULT_LCG_SEED}; /* Non-zero initial seed. Use ma_lcg_seed() to use an explicit seed. */ static MA_INLINE void ma_lcg_seed(ma_lcg* pLCG, ma_int32 seed) { @@ -14013,7 +14110,7 @@ static MA_INLINE ma_int32 ma_lcg_rand_range_s32(ma_lcg* pLCG, ma_int32 lo, ma_in } - +#if 0 /* Currently unused. */ static MA_INLINE void ma_seed(ma_int32 seed) { ma_lcg_seed(&g_maLCG, seed); @@ -14038,6 +14135,7 @@ static MA_INLINE float ma_rand_f32(void) { return ma_lcg_rand_f32(&g_maLCG); } +#endif static MA_INLINE float ma_rand_range_f32(float lo, float hi) { @@ -14097,6 +14195,7 @@ Atomics **************************************************************************************************************************************************************/ /* c89atomic.h begin */ #ifndef ma_atomic_h +#define ma_atomic_h #if defined(__cplusplus) extern "C" { #endif @@ -14108,11 +14207,63 @@ extern "C" { #endif #endif typedef int ma_atomic_memory_order; -#define MA_ATOMIC_HAS_8 -#define MA_ATOMIC_HAS_16 -#define MA_ATOMIC_HAS_32 -#define MA_ATOMIC_HAS_64 -#if (defined(_MSC_VER) ) || defined(__WATCOMC__) || defined(__DMC__) +#if !defined(MA_ATOMIC_MODERN_MSVC) && \ + !defined(MA_ATOMIC_LEGACY_MSVC) && \ + !defined(MA_ATOMIC_LEGACY_MSVC_ASM) && \ + !defined(MA_ATOMIC_MODERN_GCC) && \ + !defined(MA_ATOMIC_LEGACY_GCC) && \ + !defined(MA_ATOMIC_LEGACY_GCC_ASM) + #if defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__) || defined(__BORLANDC__) + #if (defined(_MSC_VER) && _MSC_VER > 1600) + #define MA_ATOMIC_MODERN_MSVC + #else + #if defined(MA_X64) + #define MA_ATOMIC_LEGACY_MSVC + #else + #define MA_ATOMIC_LEGACY_MSVC_ASM + #endif + #endif + #elif (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 7))) || defined(__clang__) + #define MA_ATOMIC_MODERN_GCC + #else + #if defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 1)) + #define MA_ATOMIC_LEGACY_GCC + #else + #define MA_ATOMIC_LEGACY_GCC_ASM + #endif + #endif +#endif +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) + #include + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + #define MA_ATOMIC_MSVC_ARM_INTRINSIC_NORETURN(dst, src, order, intrin, ma_atomicType, msvcType) \ + switch (order) \ + { \ + case ma_atomic_memory_order_relaxed: \ + { \ + intrin##_nf((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_consume: \ + case ma_atomic_memory_order_acquire: \ + { \ + intrin##_acq((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_release: \ + { \ + intrin##_rel((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_acq_rel: \ + case ma_atomic_memory_order_seq_cst: \ + default: \ + { \ + intrin((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + } #define MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, intrin, ma_atomicType, msvcType) \ ma_atomicType result; \ switch (order) \ @@ -14138,720 +14289,1501 @@ typedef int ma_atomic_memory_order; } break; \ } \ return result; - #define MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, expected, desired, order, intrin, ma_atomicType, msvcType) \ + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, 1, order, _InterlockedExchange, ma_atomic_flag, long); + } + #else + { + (void)order; + return (ma_atomic_flag)_InterlockedExchange((volatile long*)dst, (long)1); + } + #endif + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_NORETURN(dst, 0, order, _InterlockedExchange, ma_atomic_flag, long); + } + #else + { + (void)order; + _InterlockedExchange((volatile long*)dst, (long)0); + } + #endif + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + (void)order; + return (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, 0, 0); + } +#endif +#if defined(MA_ATOMIC_LEGACY_MSVC_ASM) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result = 0; + (void)order; + __asm { + mov ecx, dst + mov eax, 1 + xchg [ecx], eax + mov result, eax + } + return result; + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov dword ptr [esi], 0 + } + } else { + __asm { + mov esi, dst + mov eax, 0 + xchg [esi], eax + } + } + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, [esi] + mov result, eax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov eax, [esi] + lock add dword ptr [esp], 0 + mov result, eax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov eax, [esi] + mov result, eax + lock add dword ptr [esp], 0 + } + } + return result; + } +#endif +#if defined(MA_ATOMIC_MODERN_GCC) + #define ma_atomic_memory_order_relaxed __ATOMIC_RELAXED + #define ma_atomic_memory_order_consume __ATOMIC_CONSUME + #define ma_atomic_memory_order_acquire __ATOMIC_ACQUIRE + #define ma_atomic_memory_order_release __ATOMIC_RELEASE + #define ma_atomic_memory_order_acq_rel __ATOMIC_ACQ_REL + #define ma_atomic_memory_order_seq_cst __ATOMIC_SEQ_CST + typedef ma_uint32 ma_atomic_flag; + #define ma_atomic_flag_test_and_set_explicit(dst, order) __atomic_exchange_n(dst, 1, order) + #define ma_atomic_flag_clear_explicit(dst, order) __atomic_store_n(dst, 0, order) + #define ma_atomic_flag_load_explicit(dst, order) __atomic_load_n(dst, order) +#endif +#if defined(MA_ATOMIC_LEGACY_GCC) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, 1); + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order > ma_atomic_memory_order_release) { + __sync_synchronize(); + } + __sync_lock_release(dst); + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + (void)order; + return __sync_val_compare_and_swap((ma_atomic_flag*)dst, 0, 0); + } +#endif +#if defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + #if defined(MA_X86) + #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addl $0, (%%esp)" ::: "memory") + #elif defined(MA_X64) + #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addq $0, (%%rsp)" ::: "memory") + #else + #error Unsupported architecture. + #endif + #define MA_ATOMIC_XCHG_GCC_X86(instructionSizeSuffix, result, dst, src) \ + __asm__ __volatile__( \ + "xchg"instructionSizeSuffix" %0, %1" \ + : "=r"(result), \ + "=m"(*dst) \ + : "0"(src), \ + "m"(*dst) \ + : "memory" \ + ) + #define MA_ATOMIC_LOAD_RELAXED_GCC_X86(instructionSizeSuffix, result, dst) \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + ) + #define MA_ATOMIC_LOAD_RELEASE_GCC_X86(instructionSizeSuffix, result, dst) \ + ma_atomic_thread_fence(ma_atomic_memory_order_release); \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + : "memory" \ + ) + #define MA_ATOMIC_LOAD_SEQ_CST_GCC_X86(instructionSizeSuffix, result, dst) \ + ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst); \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + : "memory" \ + ); \ + ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst) + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result; + #if defined(MA_X86) || defined(MA_X64) + { + (void)order; + MA_ATOMIC_XCHG_GCC_X86("l", result, dst, 1); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__( + "movl $0, %0" + : "=m"(*dst) + ); + } else if (order == ma_atomic_memory_order_release) { + __asm__ __volatile__( + "movl $0, %0" + : "=m"(*dst) + : + : "memory" + ); + } else { + ma_atomic_flag tmp = 0; + __asm__ __volatile__( + "xchgl %0, %1" + : "=r"(tmp), + "=m"(*dst) + : "0"(tmp), + "m"(*dst) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_atomic_flag result; + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("l", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("l", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("l", result, dst); + } + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } +#endif +#define ma_atomic_flag_test_and_set(dst) ma_atomic_flag_test_and_set_explicit(dst, ma_atomic_memory_order_acquire) +#define ma_atomic_flag_clear(dst) ma_atomic_flag_clear_explicit(dst, ma_atomic_memory_order_release) +typedef ma_atomic_flag ma_atomic_spinlock; +static MA_INLINE void ma_atomic_spinlock_lock(volatile ma_atomic_spinlock* pSpinlock) +{ + for (;;) { + if (ma_atomic_flag_test_and_set_explicit(pSpinlock, ma_atomic_memory_order_acquire) == 0) { + break; + } + while (ma_atomic_flag_load_explicit(pSpinlock, ma_atomic_memory_order_relaxed) == 1) { + } + } +} +static MA_INLINE void ma_atomic_spinlock_unlock(volatile ma_atomic_spinlock* pSpinlock) +{ + ma_atomic_flag_clear_explicit(pSpinlock, ma_atomic_memory_order_release); +} +ma_atomic_spinlock ma_atomic_global_lock; +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC_ASM) || defined(MA_ATOMIC_LEGACY_GCC) || defined(MA_ATOMIC_LEGACY_GCC_ASM) + #if defined(MA_X64) || (defined(MA_X86) && ((defined(__GNUC__) && defined(__i486__)) || (defined(_M_IX86) && _M_IX86 >= 400))) + #if defined(MA_ATOMIC_LEGACY_MSVC) && defined(MA_X64) + #else + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #endif + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_X64) || (defined(MA_X86) && ((defined(__GNUC__) && defined(__i586__)) || (defined(_M_IX86) && _M_IX86 >= 500))) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #else + #endif + #else + #endif + #if defined(MA_ARM32) || defined(MA_ARM64) + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_ARM64) || defined(__ARM_ARCH_7A__) || defined(__ARM_ARCH_7R__) || defined(__ARM_ARCH_6K__) || defined(__ARM_ARCH_6Z__) || defined(__ARM_ARCH_6ZK__) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #endif + #endif + #if defined(MA_ATOMIC_PPC32) || defined(MA_ATOMIC_PPC64) + #if (defined(__GNUC__) && (__GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 7))) && !defined(__clang__) + #else + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #endif + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_ATOMIC_PPC64) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #endif + #endif + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_8(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_16(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_32(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_64(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + return 1; + #else + return 0; + #endif + } +#endif +#define MA_ATOMIC_COMPARE_AND_SWAP_LOCK(sizeInBits, dst, expected, replacement) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + if (result == expected) { \ + *dst = replacement; \ + } \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_LOAD_EXPLICIT_LOCK(sizeInBits, ptr, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *ptr; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_STORE_EXPLICIT_LOCK(sizeInBits, dst, src, order) \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + *dst = src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock) +#define MA_ATOMIC_STORE_EXPLICIT_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, src) != oldValue); \ + (void)order +#define MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + *dst = src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, src) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_ADD_LOCK(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + *dst += src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_FETCH_ADD_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = oldValue + src; \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_AND_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue & src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_OR_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue | src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_XOR_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue ^ src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) + #define MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, expected, replacement, order, intrin, ma_atomicType, msvcType) \ ma_atomicType result; \ switch (order) \ { \ case ma_atomic_memory_order_relaxed: \ { \ - result = (ma_atomicType)intrin##_nf((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_nf((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_consume: \ case ma_atomic_memory_order_acquire: \ { \ - result = (ma_atomicType)intrin##_acq((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_acq((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_release: \ { \ - result = (ma_atomicType)intrin##_rel((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_rel((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_acq_rel: \ case ma_atomic_memory_order_seq_cst: \ default: \ { \ - result = (ma_atomicType)intrin((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ } \ return result; - #define ma_atomic_memory_order_relaxed 0 - #define ma_atomic_memory_order_consume 1 - #define ma_atomic_memory_order_acquire 2 - #define ma_atomic_memory_order_release 3 - #define ma_atomic_memory_order_acq_rel 4 - #define ma_atomic_memory_order_seq_cst 5 - #if _MSC_VER < 1600 && defined(MA_X86) - #define MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY - #endif - #if _MSC_VER < 1600 - #undef MA_ATOMIC_HAS_8 - #undef MA_ATOMIC_HAS_16 - #endif - #if !defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #include - #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) - { - ma_uint8 result = 0; - __asm { - mov ecx, dst - mov al, expected - mov dl, desired - lock cmpxchg [ecx], dl - mov result, al - } - return result; - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) - { - ma_uint16 result = 0; - __asm { - mov ecx, dst - mov ax, expected - mov dx, desired - lock cmpxchg [ecx], dx - mov result, ax - } - return result; - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) - { - ma_uint32 result = 0; - __asm { - mov ecx, dst - mov eax, expected - mov edx, desired - lock cmpxchg [ecx], edx - mov result, eax - } - return result; - } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) - { - ma_uint32 resultEAX = 0; - ma_uint32 resultEDX = 0; - __asm { - mov esi, dst - mov eax, dword ptr expected - mov edx, dword ptr expected + 4 - mov ebx, dword ptr desired - mov ecx, dword ptr desired + 4 - lock cmpxchg8b qword ptr [esi] - mov resultEAX, eax - mov resultEDX, edx - } - return ((ma_uint64)resultEDX << 32) | resultEAX; - } - #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + #define ma_atomic_compare_and_swap_8( dst, expected, replacement) (ma_uint8 )_InterlockedCompareExchange8((volatile char*)dst, (char)replacement, (char)expected) #else - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_compare_and_swap_8( dst, expected, desired) (ma_uint8 )_InterlockedCompareExchange8((volatile char*)dst, (char)desired, (char)expected) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_compare_and_swap_16(dst, expected, desired) (ma_uint16)_InterlockedCompareExchange16((volatile short*)dst, (short)desired, (short)expected) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_compare_and_swap_32(dst, expected, desired) (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, (long)desired, (long)expected) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_compare_and_swap_64(dst, expected, desired) (ma_uint64)_InterlockedCompareExchange64((volatile ma_int64*)dst, (ma_int64)desired, (ma_int64)expected) - #endif + static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - ma_uint8 result = 0; - (void)order; - __asm { - mov ecx, dst - mov al, src - lock xchg [ecx], al - mov result, al - } - return result; - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - ma_uint16 result = 0; - (void)order; - __asm { - mov ecx, dst - mov ax, src - lock xchg [ecx], ax - mov result, ax - } - return result; - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - ma_uint32 result = 0; - (void)order; - __asm { - mov ecx, dst - mov eax, src - lock xchg [ecx], eax - mov result, eax - } - return result; - } - #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + #define ma_atomic_compare_and_swap_16(dst, expected, replacement) (ma_uint16)_InterlockedCompareExchange16((volatile short*)dst, (short)replacement, (short)expected) #else - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { + static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + #define ma_atomic_compare_and_swap_32(dst, expected, replacement) (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, (long)replacement, (long)expected) + #else + static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + #define ma_atomic_compare_and_swap_64(dst, expected, replacement) (ma_uint64)_InterlockedCompareExchange64((volatile ma_int64*)dst, (ma_int64)replacement, (ma_int64)expected) + #else + static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange8, ma_uint8, char); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange8, ma_uint8, char); + } #else + { + (void)order; + return ma_atomic_compare_and_swap_8((volatile ma_uint8*)ptr, 0, 0); + } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, ptr, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange16, ma_uint16, short); + } + #else + { + (void)order; + return ma_atomic_compare_and_swap_16((volatile ma_uint16*)ptr, 0, 0); + } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, ptr, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange, ma_uint32, long); + } + #else + { + (void)order; + return ma_atomic_compare_and_swap_32((volatile ma_uint32*)ptr, 0, 0); + } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, ptr, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange64, ma_uint64, long long); + } + #else + { + (void)order; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)ptr, 0, 0); + } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, ptr, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange8, ma_uint8, char); + } + #else + { (void)order; return (ma_uint8)_InterlockedExchange8((volatile char*)dst, (char)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange16, ma_uint16, short); + } #else + { (void)order; return (ma_uint16)_InterlockedExchange16((volatile short*)dst, (short)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange, ma_uint32, long); + } #else + { (void)order; return (ma_uint32)_InterlockedExchange((volatile long*)dst, (long)src); - #endif } - #endif - #if defined(MA_ATOMIC_HAS_64) && defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange64, ma_uint64, long long); - #else - (void)order; - return (ma_uint64)_InterlockedExchange64((volatile long long*)dst, (long long)src); #endif - } - #else - #endif - #endif - #if defined(MA_ATOMIC_HAS_64) && !defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - ma_uint64 oldValue; - do { - oldValue = *dst; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; } - #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - ma_uint8 result = 0; - (void)order; - __asm { - mov ecx, dst - mov al, src - lock xadd [ecx], al - mov result, al - } - return result; - } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + #if defined(MA_32BIT) { - ma_uint16 result = 0; - (void)order; - __asm { - mov ecx, dst - mov ax, src - lock xadd [ecx], ax - mov result, ax - } - return result; + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - ma_uint32 result = 0; - (void)order; - __asm { - mov ecx, dst - mov eax, src - lock xadd [ecx], eax - mov result, eax - } - return result; - } - #endif - #else - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd8, ma_uint8, char); #else + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange64, ma_uint64, long long); + } + #else + { + (void)order; + return (ma_uint64)_InterlockedExchange64((volatile long long*)dst, (long long)src); + } + #endif + } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd8, ma_uint8, char); + } + #else + { (void)order; return (ma_uint8)_InterlockedExchangeAdd8((volatile char*)dst, (char)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd16, ma_uint16, short); + } #else + { (void)order; return (ma_uint16)_InterlockedExchangeAdd16((volatile short*)dst, (short)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd, ma_uint32, long); + } #else + { (void)order; return (ma_uint32)_InterlockedExchangeAdd((volatile long*)dst, (long)src); - #endif } - #endif - #if defined(MA_ATOMIC_HAS_64) && defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd64, ma_uint64, long long); - #else - (void)order; - return (ma_uint64)_InterlockedExchangeAdd64((volatile long long*)dst, (long long)src); #endif - } + } #else - #endif - #endif - #if defined(MA_ATOMIC_HAS_64) && !defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue + src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + #if defined(MA_32BIT) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); + } + #else + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd64, ma_uint64, long long); + } + #else + { + (void)order; + return (ma_uint64)_InterlockedExchangeAdd64((volatile long long*)dst, (long long)src); + } + #endif + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_8(dst, (ma_uint8)(-(ma_int8)src), order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_16(dst, (ma_uint16)(-(ma_int16)src), order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_32(dst, (ma_uint32)(-(ma_int32)src), order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_64(dst, (ma_uint64)(-(ma_int64)src), order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd8, ma_uint8, char); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd16, ma_uint16, short); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd, ma_uint32, long); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd64, ma_uint64, long long); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr8, ma_uint8, char); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr16, ma_uint16, short); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr, ma_uint32, long); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr64, ma_uint64, long long); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor8, ma_uint8, char); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor16, ma_uint16, short); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor, ma_uint32, long); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor64, ma_uint64, long long); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + #endif + } + #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) + #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) + #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) + #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) + #if defined(MA_X64) + #define ma_atomic_thread_fence(order) __faststorefence(), (void)order + #elif defined(MA_ARM64) + #define ma_atomic_thread_fence(order) __dmb(_ARM64_BARRIER_ISH), (void)order + #else + static MA_INLINE void ma_atomic_thread_fence(ma_atomic_memory_order order) + { + volatile ma_uint32 barrier = 0; + ma_atomic_fetch_add_explicit_32(&barrier, 0, order); } #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - static MA_INLINE void __stdcall ma_atomic_thread_fence(ma_atomic_memory_order order) + #define ma_atomic_signal_fence(order) _ReadWriteBarrier(), (void)order +#endif +#if defined(MA_ATOMIC_LEGACY_MSVC_ASM) + static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) { + ma_uint8 result = 0; + __asm { + mov ecx, dst + mov al, expected + mov dl, replacement + lock cmpxchg [ecx], dl + mov result, al + } + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + __asm { + mov ecx, dst + mov ax, expected + mov dx, replacement + lock cmpxchg [ecx], dx + mov result, ax + } + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; + __asm { + mov ecx, dst + mov eax, expected + mov edx, replacement + lock cmpxchg [ecx], edx + mov result, eax + } + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + ma_uint32 resultEAX = 0; + ma_uint32 resultEDX = 0; + __asm { + mov esi, dst + mov eax, dword ptr expected + mov edx, dword ptr expected + 4 + mov ebx, dword ptr replacement + mov ecx, dword ptr replacement + 4 + lock cmpxchg8b qword ptr [esi] + mov resultEAX, eax + mov resultEDX, edx + } + return ((ma_uint64)resultEDX << 32) | resultEAX; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov al, [esi] + mov result, al + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov al, [esi] + lock add dword ptr [esp], 0 + mov result, al + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov al, [esi] + mov result, al + lock add dword ptr [esp], 0 + } + } + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, dst, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov ax, [esi] + mov result, ax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov ax, [esi] + lock add dword ptr [esp], 0 + mov result, ax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov ax, [esi] + mov result, ax + lock add dword ptr [esp], 0 + } + } + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, dst, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, [esi] + mov result, eax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov eax, [esi] + lock add dword ptr [esp], 0 + mov result, eax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov eax, [esi] + mov result, eax + lock add dword ptr [esp], 0 + } + } + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, dst, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* dst, ma_atomic_memory_order order) + { + (void)order; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, 0, 0); + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov al, src + mov [esi], al + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + __asm { + mov esi, dst + mov al, src + xchg [esi], al + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov ax, src + mov [esi], ax + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + __asm { + mov esi, dst + mov ax, src + xchg [esi], ax + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, src + mov [esi], eax + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + __asm { + mov esi, dst + mov eax, src + xchg [esi], eax + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif + } + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + MA_ATOMIC_STORE_EXPLICIT_CAS(64, dst, src, order); + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; (void)order; __asm { - lock add [esp], 0 + mov ecx, dst + mov al, src + lock xchg [ecx], al + mov result, al } + return result; } - #else - #if defined(MA_X64) - #define ma_atomic_thread_fence(order) __faststorefence(), (void)order - #elif defined(MA_ARM64) - #define ma_atomic_thread_fence(order) __dmb(_ARM64_BARRIER_ISH), (void)order #else - static MA_INLINE void ma_atomic_thread_fence(ma_atomic_memory_order order) - { - volatile ma_uint32 barrier = 0; - ma_atomic_fetch_add_explicit_32(&barrier, 0, order); + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + (void)order; + __asm { + mov ecx, dst + mov ax, src + lock xchg [ecx], ax + mov result, ax } - #endif - #endif - #define ma_atomic_compiler_fence() ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst) - #define ma_atomic_signal_fence(order) ma_atomic_thread_fence(order) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange8, ma_uint8, char); + return result; + } #else - (void)order; - return ma_atomic_compare_and_swap_8((volatile ma_uint8*)ptr, 0, 0); - #endif - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange16, ma_uint16, short); + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; + (void)order; + __asm { + mov ecx, dst + mov eax, src + xchg [ecx], eax + mov result, eax + } + return result; + } #else - (void)order; - return ma_atomic_compare_and_swap_16((volatile ma_uint16*)ptr, 0, 0); - #endif - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange, ma_uint32, long); + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); + } #else - (void)order; - return ma_atomic_compare_and_swap_32((volatile ma_uint32*)ptr, 0, 0); - #endif - } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange64, ma_uint64, long long); + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; + (void)order; + __asm { + mov ecx, dst + mov al, src + lock xadd [ecx], al + mov result, al + } + return result; + } #else - (void)order; - return ma_atomic_compare_and_swap_64((volatile ma_uint64*)ptr, 0, 0); + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } #endif - } - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue - src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); + ma_uint16 result = 0; (void)order; - return oldValue; + __asm { + mov ecx, dst + mov ax, src + lock xadd [ecx], ax + mov result, ax + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue - src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd8, ma_uint8, char); #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue & src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd16, ma_uint16, short); - #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue & src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd, ma_uint32, long); - #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); + ma_uint32 result = 0; (void)order; - return oldValue; - #endif + __asm { + mov ecx, dst + mov eax, src + lock xadd [ecx], eax + mov result, eax + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd64, ma_uint64, long long); - #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor8, ma_uint8, char); - #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor16, ma_uint16, short); - #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor, ma_uint32, long); - #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); + ma_uint8 result = 0; (void)order; - return oldValue; - #endif + __asm { + mov ecx, dst + mov al, src + neg al + lock xadd [ecx], al + mov result, al + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor64, ma_uint64, long long); - #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, (ma_uint8)(-(ma_int8)src), order); } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr8, ma_uint8, char); - #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue | src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); + ma_uint16 result = 0; (void)order; - return oldValue; - #endif + __asm { + mov ecx, dst + mov ax, src + neg ax + lock xadd [ecx], ax + mov result, ax + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr16, ma_uint16, short); - #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue | src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, (ma_uint16)(-(ma_int16)src), order); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr, ma_uint32, long); - #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); + ma_uint32 result = 0; (void)order; - return oldValue; - #endif + __asm { + mov ecx, dst + mov eax, src + neg eax + lock xadd [ecx], eax + mov result, eax + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr64, ma_uint64, long long); - #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, (ma_uint32)(-(ma_int32)src), order); } - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_test_and_set_explicit_8( dst, order) ma_atomic_exchange_explicit_8 (dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_test_and_set_explicit_16(dst, order) ma_atomic_exchange_explicit_16(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_test_and_set_explicit_32(dst, order) ma_atomic_exchange_explicit_32(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_test_and_set_explicit_64(dst, order) ma_atomic_exchange_explicit_64(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_clear_explicit_8( dst, order) ma_atomic_store_explicit_8 (dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_clear_explicit_16(dst, order) ma_atomic_store_explicit_16(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_clear_explicit_32(dst, order) ma_atomic_store_explicit_32(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_clear_explicit_64(dst, order) ma_atomic_store_explicit_64(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_8(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_8(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) - #else - typedef ma_uint32 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_32(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_32(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_32(ptr, order) - #endif -#elif defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 7))) + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, (ma_uint64)(-(ma_int64)src), order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + static MA_INLINE void __stdcall ma_atomic_thread_fence(ma_atomic_memory_order order) + { + (void)order; + __asm { + lock add dword ptr [esp], 0 + } + } + #define ma_atomic_signal_fence(order) __asm {}; (void)order +#endif +#if defined(MA_ATOMIC_MODERN_GCC) #define MA_ATOMIC_HAS_NATIVE_COMPARE_EXCHANGE - #define MA_ATOMIC_HAS_NATIVE_IS_LOCK_FREE - #define ma_atomic_memory_order_relaxed __ATOMIC_RELAXED - #define ma_atomic_memory_order_consume __ATOMIC_CONSUME - #define ma_atomic_memory_order_acquire __ATOMIC_ACQUIRE - #define ma_atomic_memory_order_release __ATOMIC_RELEASE - #define ma_atomic_memory_order_acq_rel __ATOMIC_ACQ_REL - #define ma_atomic_memory_order_seq_cst __ATOMIC_SEQ_CST - #define ma_atomic_compiler_fence() __asm__ __volatile__("":::"memory") #define ma_atomic_thread_fence(order) __atomic_thread_fence(order) #define ma_atomic_signal_fence(order) __atomic_signal_fence(order) #define ma_atomic_is_lock_free_8(ptr) __atomic_is_lock_free(1, ptr) #define ma_atomic_is_lock_free_16(ptr) __atomic_is_lock_free(2, ptr) #define ma_atomic_is_lock_free_32(ptr) __atomic_is_lock_free(4, ptr) #define ma_atomic_is_lock_free_64(ptr) __atomic_is_lock_free(8, ptr) - #define ma_atomic_test_and_set_explicit_8( dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_16(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_32(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_64(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_clear_explicit_8( dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_16(dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_32(dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_64(dst, order) __atomic_store_n(dst, 0, order) #define ma_atomic_store_explicit_8( dst, src, order) __atomic_store_n(dst, src, order) #define ma_atomic_store_explicit_16(dst, src, order) __atomic_store_n(dst, src, order) #define ma_atomic_store_explicit_32(dst, src, order) __atomic_store_n(dst, src, order) @@ -14864,14 +15796,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_explicit_16(dst, src, order) __atomic_exchange_n(dst, src, order) #define ma_atomic_exchange_explicit_32(dst, src, order) __atomic_exchange_n(dst, src, order) #define ma_atomic_exchange_explicit_64(dst, src, order) __atomic_exchange_n(dst, src, order) - #define ma_atomic_compare_exchange_strong_explicit_8( dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_8( dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) #define ma_atomic_fetch_add_explicit_8( dst, src, order) __atomic_fetch_add(dst, src, order) #define ma_atomic_fetch_add_explicit_16(dst, src, order) __atomic_fetch_add(dst, src, order) #define ma_atomic_fetch_add_explicit_32(dst, src, order) __atomic_fetch_add(dst, src, order) @@ -14892,19 +15824,19 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_explicit_16(dst, src, order) __atomic_fetch_and(dst, src, order) #define ma_atomic_fetch_and_explicit_32(dst, src, order) __atomic_fetch_and(dst, src, order) #define ma_atomic_fetch_and_explicit_64(dst, src, order) __atomic_fetch_and(dst, src, order) - static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } - static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } - static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } #if defined(__clang__) @@ -14913,636 +15845,1134 @@ typedef int ma_atomic_memory_order; #pragma clang diagnostic ignored "-Watomic-alignment" #endif #endif - static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } #if defined(__clang__) #pragma clang diagnostic pop #endif - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(dst, order) (ma_bool32)__atomic_test_and_set(dst, order) - #define ma_atomic_flag_clear_explicit(dst, order) __atomic_clear(dst, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) -#else - #define ma_atomic_memory_order_relaxed 1 - #define ma_atomic_memory_order_consume 2 - #define ma_atomic_memory_order_acquire 3 - #define ma_atomic_memory_order_release 4 - #define ma_atomic_memory_order_acq_rel 5 - #define ma_atomic_memory_order_seq_cst 6 - #define ma_atomic_compiler_fence() __asm__ __volatile__("":::"memory") - #if defined(__GNUC__) +#endif +#if defined(MA_ATOMIC_LEGACY_GCC) || defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define ma_atomic_signal_fence(order) __asm__ __volatile__("":::"memory") + #if defined(MA_ATOMIC_LEGACY_GCC) #define ma_atomic_thread_fence(order) __sync_synchronize(), (void)order - static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) { - if (order > ma_atomic_memory_order_acquire) { - __sync_synchronize(); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + return __sync_val_compare_and_swap(dst, expected, replacement); } - return __sync_lock_test_and_set(dst, src); + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) { - ma_uint16 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) { - ma_uint32 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) { - ma_uint64 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint8 ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return ma_atomic_compare_and_swap_8((ma_uint8*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, ptr, order); + } + #endif } - static MA_INLINE ma_uint16 ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return ma_atomic_compare_and_swap_16((ma_uint16*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, ptr, order); + } + #endif } - static MA_INLINE ma_uint32 ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return ma_atomic_compare_and_swap_32((ma_uint32*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, ptr, order); + } + #endif } - static MA_INLINE ma_uint64 ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); - } - static MA_INLINE ma_uint8 ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_sub(dst, src); - } - static MA_INLINE ma_uint16 ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_sub(dst, src); - } - static MA_INLINE ma_uint32 ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_sub(dst, src); - } - static MA_INLINE ma_uint64 ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_sub(dst, src); - } - static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_or(dst, src); - } - static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_or(dst, src); - } - static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_or(dst, src); - } - static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_or(dst, src); - } - static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_xor(dst, src); - } - static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_xor(dst, src); - } - static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_xor(dst, src); - } - static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_xor(dst, src); - } - static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_and(dst, src); - } - static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_and(dst, src); - } - static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_and(dst, src); - } - static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_and(dst, src); - } - #define ma_atomic_compare_and_swap_8( dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_16(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_32(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_64(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #else - #if defined(MA_X86) - #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addl $0, (%%esp)" ::: "memory", "cc") - #elif defined(MA_X64) - #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addq $0, (%%rsp)" ::: "memory", "cc") - #else - #error Unsupported architecture. Please submit a feature request. - #endif - static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) - { - ma_uint8 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; - } - static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) - { - ma_uint16 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; - } - static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) - { - ma_uint32 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; - } - static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) - { - volatile ma_uint64 result; - #if defined(MA_X86) - ma_uint32 resultEAX; - ma_uint32 resultEDX; - __asm__ __volatile__("push %%ebx; xchg %5, %%ebx; lock; cmpxchg8b %0; pop %%ebx" : "+m"(*dst), "=a"(resultEAX), "=d"(resultEDX) : "a"(expected & 0xFFFFFFFF), "d"(expected >> 32), "r"(desired & 0xFFFFFFFF), "c"(desired >> 32) : "cc"); - result = ((ma_uint64)resultEDX << 32) | resultEAX; - #elif defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return ma_atomic_compare_and_swap_64((ma_uint64*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, ptr, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 result = 0; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 result = 0; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 result; - (void)order; - #if defined(MA_X86) - do { - result = *dst; - } while (ma_atomic_compare_and_swap_64(dst, result, src) != result); - #elif defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif } + #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) + #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) + #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) + #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) static MA_INLINE ma_uint8 ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - #if defined(MA_X86) - ma_uint64 oldValue; - ma_uint64 newValue; - (void)order; - do { - oldValue = *dst; - newValue = oldValue + src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - return oldValue; - #elif defined(MA_X64) - ma_uint64 result; - (void)order; - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - return result; - #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue - src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, (ma_uint8)(-(ma_int8)src), order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue - src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, (ma_uint16)(-(ma_int16)src), order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, (ma_uint32)(-(ma_int32)src), order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, (ma_uint64)(-(ma_int64)src), order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue & src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue & src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue | src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue | src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + #endif } + static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + #endif + } + #elif defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define MA_ATOMIC_CMPXCHG_GCC_X86(instructionSizeSuffix, result, dst, expected, replacement) \ + __asm__ __volatile__( \ + "lock; cmpxchg"instructionSizeSuffix" %2, %1" \ + : "=a"(result), \ + "=m"(*dst) \ + : "r"(replacement), \ + "0"(expected), \ + "m"(*dst) \ + : "cc", "memory") + #define MA_ATOMIC_XADD_GCC_X86(instructionSizeSuffix, result, dst, src) \ + __asm__ __volatile__( \ + "lock; xadd"instructionSizeSuffix" %0, %1" \ + : "=a"(result), \ + "=m"(*dst) \ + : "0"(src), \ + "m"(*dst) \ + : "cc", "memory") + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("b", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("w", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("l", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + #if defined(MA_X86) + { + ma_uint32 resultEAX; + ma_uint32 resultEDX; + __asm__ __volatile__( + "pushl %%ebx\n" + "movl %4, %%ebx\n" + "lock cmpxchg8b (%%edi)\n" + "popl %%ebx\n" + : "=a"(resultEAX), + "=d"(resultEDX) + : "a"((ma_uint32)(expected & 0xFFFFFFFF)), + "d"((ma_uint32)(expected >> 32)), + "r"((ma_uint32)(replacement & 0xFFFFFFFF)), + "c"((ma_uint32)(replacement >> 32)), + "D"(dst) + : "memory", "cc"); + result = ((ma_uint64)resultEDX << 32) | resultEAX; + } + #elif defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("q", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("b", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("b", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("b", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, dst, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("w", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("w", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("w", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, dst, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("l", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("l", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("l", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, dst, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + #if defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("q", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("q", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("q", result, dst); + } + } + #elif defined(MA_X86) + { + (void)order; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, 0, 0); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, dst, order); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("b", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("w", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("l", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + (void)order; + #if defined(MA_X86) + { + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); + } + #elif defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("q", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE void ma_atomic_store_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movb %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgb %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE void ma_atomic_store_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movw %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgw %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + static MA_INLINE void ma_atomic_store_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movl %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgl %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE void ma_atomic_store_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movq %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgq %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_CAS(64, dst, src, order); + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint8 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("b", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint16 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("w", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint32 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("l", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); + } + #elif defined(MA_X64) + { + ma_uint64 result; + MA_ATOMIC_XADD_GCC_X86("q", result, dst, src); + (void)order; + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_8(dst, (ma_uint8)(-(ma_int8)src), order); + } + static MA_INLINE ma_uint16 ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_16(dst, (ma_uint16)(-(ma_int16)src), order); + } + static MA_INLINE ma_uint32 ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_32(dst, (ma_uint32)(-(ma_int32)src), order); + } + static MA_INLINE ma_uint64 ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_64(dst, (ma_uint64)(-(ma_int64)src), order); + } + static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + #else + #error Unsupported compiler. #endif - #define ma_atomic_signal_fence(order) ma_atomic_thread_fence(order) - static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_8((ma_uint8*)ptr, 0, 0); - } - static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_16((ma_uint16*)ptr, 0, 0); - } - static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_32((ma_uint32*)ptr, 0, 0); - } - static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_64((ma_uint64*)ptr, 0, 0); - } - #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) - #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) - #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) - #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) - #define ma_atomic_test_and_set_explicit_8( dst, order) ma_atomic_exchange_explicit_8 (dst, 1, order) - #define ma_atomic_test_and_set_explicit_16(dst, order) ma_atomic_exchange_explicit_16(dst, 1, order) - #define ma_atomic_test_and_set_explicit_32(dst, order) ma_atomic_exchange_explicit_32(dst, 1, order) - #define ma_atomic_test_and_set_explicit_64(dst, order) ma_atomic_exchange_explicit_64(dst, 1, order) - #define ma_atomic_clear_explicit_8( dst, order) ma_atomic_store_explicit_8 (dst, 0, order) - #define ma_atomic_clear_explicit_16(dst, order) ma_atomic_store_explicit_16(dst, 0, order) - #define ma_atomic_clear_explicit_32(dst, order) ma_atomic_store_explicit_32(dst, 0, order) - #define ma_atomic_clear_explicit_64(dst, order) ma_atomic_store_explicit_64(dst, 0, order) - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_8(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_8(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) #endif #if !defined(MA_ATOMIC_HAS_NATIVE_COMPARE_EXCHANGE) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_8(volatile ma_uint8* dst, ma_uint8* expected, ma_uint8 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint8 expectedValue; - ma_uint8 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_8(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_8(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_8(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_16(volatile ma_uint16* dst, ma_uint16* expected, ma_uint16 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint16 expectedValue; - ma_uint16 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_16(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_16(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_16(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_32(volatile ma_uint32* dst, ma_uint32* expected, ma_uint32 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint32 expectedValue; - ma_uint32 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_32(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_32(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_32(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_64(volatile ma_uint64* dst, volatile ma_uint64* expected, ma_uint64 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint64 expectedValue; - ma_uint64 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_64(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_64(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_64(expected, result, failureOrder); - return 0; - } - } - #endif - #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8 (dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, successOrder, failureOrder) -#endif -#if !defined(MA_ATOMIC_HAS_NATIVE_IS_LOCK_FREE) - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_8(volatile void* ptr) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_8(volatile ma_uint8* dst, ma_uint8* expected, ma_uint8 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - (void)ptr; - return 1; - } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_16(volatile void* ptr) - { - (void)ptr; - return 1; - } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_32(volatile void* ptr) - { - (void)ptr; - return 1; - } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_64(volatile void* ptr) - { - (void)ptr; - #if defined(MA_64BIT) - return 1; - #else - #if defined(MA_X86) || defined(MA_X64) + ma_uint8 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_8(dst, *expected, replacement); + if (result == *expected) { return 1; - #else + } else { + *expected = result; return 0; - #endif - #endif + } } + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_16(volatile ma_uint16* dst, ma_uint16* expected, ma_uint16 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + { + ma_uint16 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_16(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } + } + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_32(volatile ma_uint32* dst, ma_uint32* expected, ma_uint32 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + { + ma_uint32 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_32(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } + } + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_64(volatile ma_uint64* dst, volatile ma_uint64* expected, ma_uint64 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + { + ma_uint64 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_64(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } + } + #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8 (dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, successOrder, failureOrder) #endif #if defined(MA_64BIT) static MA_INLINE ma_bool32 ma_atomic_is_lock_free_ptr(volatile void** ptr) @@ -15561,17 +16991,17 @@ typedef int ma_atomic_memory_order; { return (void*)ma_atomic_exchange_explicit_64((volatile ma_uint64*)dst, (ma_uint64)src, order); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder); } - static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* desired) + static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* replacement) { - return (void*)ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, (ma_uint64)expected, (ma_uint64)desired); + return (void*)ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, (ma_uint64)expected, (ma_uint64)replacement); } #elif defined(MA_32BIT) static MA_INLINE ma_bool32 ma_atomic_is_lock_free_ptr(volatile void** ptr) @@ -15590,36 +17020,26 @@ typedef int ma_atomic_memory_order; { return (void*)ma_atomic_exchange_explicit_32((volatile ma_uint32*)dst, (ma_uint32)src, order); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder); } - static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* desired) + static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* replacement) { - return (void*)ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, (ma_uint32)expected, (ma_uint32)desired); + return (void*)ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, (ma_uint32)expected, (ma_uint32)replacement); } #else #error Unsupported architecture. #endif -#define ma_atomic_flag_test_and_set(ptr) ma_atomic_flag_test_and_set_explicit(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_flag_clear(ptr) ma_atomic_flag_clear_explicit(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_store_ptr(dst, src) ma_atomic_store_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_load_ptr(ptr) ma_atomic_load_explicit_ptr((volatile void**)ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_exchange_ptr(dst, src) ma_atomic_exchange_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_ptr(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_ptr((volatile void**)dst, (void**)expected, (void*)desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_ptr(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_ptr((volatile void**)dst, (void**)expected, (void*)desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_8( ptr) ma_atomic_test_and_set_explicit_8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_16(ptr) ma_atomic_test_and_set_explicit_16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_32(ptr) ma_atomic_test_and_set_explicit_32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_64(ptr) ma_atomic_test_and_set_explicit_64(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_8( ptr) ma_atomic_clear_explicit_8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_16(ptr) ma_atomic_clear_explicit_16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_32(ptr) ma_atomic_clear_explicit_32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_64(ptr) ma_atomic_clear_explicit_64(ptr, ma_atomic_memory_order_seq_cst) +#define ma_atomic_store_ptr(dst, src) ma_atomic_store_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) +#define ma_atomic_load_ptr(ptr) ma_atomic_load_explicit_ptr((volatile void**)ptr, ma_atomic_memory_order_seq_cst) +#define ma_atomic_exchange_ptr(dst, src) ma_atomic_exchange_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_ptr(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_ptr((volatile void**)dst, (void**)expected, (void*)replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_ptr(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_ptr((volatile void**)dst, (void**)expected, (void*)replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_8( dst, src) ma_atomic_store_explicit_8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_16(dst, src) ma_atomic_store_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_32(dst, src) ma_atomic_store_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15632,14 +17052,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_16(dst, src) ma_atomic_exchange_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_32(dst, src) ma_atomic_exchange_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_64(dst, src) ma_atomic_exchange_explicit_64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_8( dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_16(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_8( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_16( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_32( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_64( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_8( dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_16(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_8( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_16( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_32( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_64( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_8( dst, src) ma_atomic_fetch_add_explicit_8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_16(dst, src) ma_atomic_fetch_add_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_32(dst, src) ma_atomic_fetch_add_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15660,14 +17080,6 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_16(dst, src) ma_atomic_fetch_and_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_32(dst, src) ma_atomic_fetch_and_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_64(dst, src) ma_atomic_fetch_and_explicit_64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_explicit_i8( ptr, order) (ma_int8 )ma_atomic_test_and_set_explicit_8( (ma_uint8* )ptr, order) -#define ma_atomic_test_and_set_explicit_i16(ptr, order) (ma_int16)ma_atomic_test_and_set_explicit_16((ma_uint16*)ptr, order) -#define ma_atomic_test_and_set_explicit_i32(ptr, order) (ma_int32)ma_atomic_test_and_set_explicit_32((ma_uint32*)ptr, order) -#define ma_atomic_test_and_set_explicit_i64(ptr, order) (ma_int64)ma_atomic_test_and_set_explicit_64((ma_uint64*)ptr, order) -#define ma_atomic_clear_explicit_i8( ptr, order) ma_atomic_clear_explicit_8( (ma_uint8* )ptr, order) -#define ma_atomic_clear_explicit_i16(ptr, order) ma_atomic_clear_explicit_16((ma_uint16*)ptr, order) -#define ma_atomic_clear_explicit_i32(ptr, order) ma_atomic_clear_explicit_32((ma_uint32*)ptr, order) -#define ma_atomic_clear_explicit_i64(ptr, order) ma_atomic_clear_explicit_64((ma_uint64*)ptr, order) #define ma_atomic_store_explicit_i8( dst, src, order) ma_atomic_store_explicit_8( (ma_uint8* )dst, (ma_uint8 )src, order) #define ma_atomic_store_explicit_i16(dst, src, order) ma_atomic_store_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_store_explicit_i32(dst, src, order) ma_atomic_store_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) @@ -15680,14 +17092,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_explicit_i16(dst, src, order) (ma_int16)ma_atomic_exchange_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_exchange_explicit_i32(dst, src, order) (ma_int32)ma_atomic_exchange_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) #define ma_atomic_exchange_explicit_i64(dst, src, order) (ma_int64)ma_atomic_exchange_explicit_64((ma_uint64*)dst, (ma_uint64)src, order) -#define ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder) #define ma_atomic_fetch_add_explicit_i8( dst, src, order) (ma_int8 )ma_atomic_fetch_add_explicit_8( (ma_uint8* )dst, (ma_uint8 )src, order) #define ma_atomic_fetch_add_explicit_i16(dst, src, order) (ma_int16)ma_atomic_fetch_add_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_fetch_add_explicit_i32(dst, src, order) (ma_int32)ma_atomic_fetch_add_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) @@ -15708,14 +17120,6 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_explicit_i16(dst, src, order) (ma_int16)ma_atomic_fetch_and_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_fetch_and_explicit_i32(dst, src, order) (ma_int32)ma_atomic_fetch_and_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) #define ma_atomic_fetch_and_explicit_i64(dst, src, order) (ma_int64)ma_atomic_fetch_and_explicit_64((ma_uint64*)dst, (ma_uint64)src, order) -#define ma_atomic_test_and_set_i8( ptr) ma_atomic_test_and_set_explicit_i8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i16(ptr) ma_atomic_test_and_set_explicit_i16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i32(ptr) ma_atomic_test_and_set_explicit_i32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i64(ptr) ma_atomic_test_and_set_explicit_i64(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i8( ptr) ma_atomic_clear_explicit_i8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i16(ptr) ma_atomic_clear_explicit_i16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i32(ptr) ma_atomic_clear_explicit_i32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i64(ptr) ma_atomic_clear_explicit_i64(ptr, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i8( dst, src) ma_atomic_store_explicit_i8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i16(dst, src) ma_atomic_store_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i32(dst, src) ma_atomic_store_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15728,14 +17132,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_i16(dst, src) ma_atomic_exchange_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_i32(dst, src) ma_atomic_exchange_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_i64(dst, src) ma_atomic_exchange_explicit_i64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i8( dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i16(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i8( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i16(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i32(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i64(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i8( dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i16(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i8( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i16(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i32(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i64(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i8( dst, src) ma_atomic_fetch_add_explicit_i8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i16(dst, src) ma_atomic_fetch_add_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i32(dst, src) ma_atomic_fetch_add_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15812,28 +17216,28 @@ static MA_INLINE double ma_atomic_exchange_explicit_f64(volatile double* dst, do r.i = ma_atomic_exchange_explicit_64((volatile ma_uint64*)dst, x.i, order); return r.f; } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f32(volatile float* dst, float* expected, float desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f32(volatile float* dst, float* expected, float replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if32 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f64(volatile double* dst, double* expected, double desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f64(volatile double* dst, double* expected, double replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if64 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f32(volatile float* dst, float* expected, float desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f32(volatile float* dst, float* expected, float replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if32 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f64(volatile double* dst, double* expected, double desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f64(volatile double* dst, double* expected, double replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if64 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, d.i, successOrder, failureOrder); } static MA_INLINE float ma_atomic_fetch_add_explicit_f32(volatile float* dst, float src, ma_atomic_memory_order order) @@ -15924,10 +17328,10 @@ static MA_INLINE double ma_atomic_fetch_and_explicit_f64(volatile double* dst, d #define ma_atomic_load_f64(ptr) (double)ma_atomic_load_explicit_f64(ptr, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_f32(dst, src) (float )ma_atomic_exchange_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_f64(dst, src) (double)ma_atomic_exchange_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_f32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_f32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_f64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_f64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_f32(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_f32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_f64(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_f64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_f32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_f32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_f64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_f64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_f32(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_f32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_f64(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_f64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_f32(dst, src) ma_atomic_fetch_add_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_f64(dst, src) ma_atomic_fetch_add_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_sub_f32(dst, src) ma_atomic_fetch_sub_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15938,39 +17342,24 @@ static MA_INLINE double ma_atomic_fetch_and_explicit_f64(volatile double* dst, d #define ma_atomic_fetch_xor_f64(dst, src) ma_atomic_fetch_xor_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_f32(dst, src) ma_atomic_fetch_and_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_f64(dst, src) ma_atomic_fetch_and_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) -static MA_INLINE float ma_atomic_compare_and_swap_f32(volatile float* dst, float expected, float desired) +static MA_INLINE float ma_atomic_compare_and_swap_f32(volatile float* dst, float expected, float replacement) { ma_atomic_if32 r; ma_atomic_if32 e, d; e.f = expected; - d.f = desired; + d.f = replacement; r.i = ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, e.i, d.i); return r.f; } -static MA_INLINE double ma_atomic_compare_and_swap_f64(volatile double* dst, double expected, double desired) +static MA_INLINE double ma_atomic_compare_and_swap_f64(volatile double* dst, double expected, double replacement) { ma_atomic_if64 r; ma_atomic_if64 e, d; e.f = expected; - d.f = desired; + d.f = replacement; r.i = ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, e.i, d.i); return r.f; } -typedef ma_atomic_flag ma_atomic_spinlock; -static MA_INLINE void ma_atomic_spinlock_lock(volatile ma_atomic_spinlock* pSpinlock) -{ - for (;;) { - if (ma_atomic_flag_test_and_set_explicit(pSpinlock, ma_atomic_memory_order_acquire) == 0) { - break; - } - while (ma_atomic_flag_load_explicit(pSpinlock, ma_atomic_memory_order_relaxed) == 1) { - } - } -} -static MA_INLINE void ma_atomic_spinlock_unlock(volatile ma_atomic_spinlock* pSpinlock) -{ - ma_atomic_flag_clear_explicit(pSpinlock, ma_atomic_memory_order_release); -} #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) #pragma GCC diagnostic pop #endif @@ -16176,7 +17565,7 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority int result; pthread_attr_t* pAttr = NULL; -#if !defined(__EMSCRIPTEN__) && !defined(__3DS__) +#if !defined(MA_EMSCRIPTEN) && !defined(MA_3DS) && !defined(MA_SWITCH) /* Try setting the thread priority. It's not critical if anything fails here. */ pthread_attr_t attr; if (pthread_attr_init(&attr) == 0) { @@ -16208,9 +17597,18 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority } #endif - if (stackSize > 0) { - pthread_attr_setstacksize(&attr, stackSize); + #if defined(_POSIX_THREAD_ATTR_STACKSIZE) && _POSIX_THREAD_ATTR_STACKSIZE >= 0 + { + if (stackSize > 0) { + pthread_attr_setstacksize(&attr, stackSize); + } } + #else + { + (void)stackSize; /* Suppress unused parameter warning. */ + } + #endif + if (scheduler != -1) { int priorityMin = sched_get_priority_min(scheduler); @@ -16267,6 +17665,21 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority } if (result != 0) { + /* + There have been reports that attempting to create a realtime thread can sometimes fail. In this case, + fall back to a normal priority thread. + + I'm including a compile-time option here to disable this functionality for those who have a hard + requirement on realtime threads and would rather an explicit failure. + */ + #ifndef MA_NO_PTHREAD_REALTIME_PRIORITY_FALLBACK + { + if(result == EPERM && priority == ma_thread_priority_realtime) { + return ma_thread_create__posix(pThread, ma_thread_priority_normal, stackSize, entryProc, pData); + } + } + #endif + return ma_result_from_errno(result); } @@ -16538,7 +17951,7 @@ static ma_result ma_event_signal__win32(ma_event* pEvent) static ma_result ma_semaphore_init__win32(int initialValue, ma_semaphore* pSemaphore) { - *pSemaphore = CreateSemaphoreW(NULL, (LONG)initialValue, LONG_MAX, NULL); + *pSemaphore = CreateSemaphore(NULL, (LONG)initialValue, LONG_MAX, NULL); if (*pSemaphore == NULL) { return ma_result_from_GetLastError(GetLastError()); } @@ -17432,10 +18845,12 @@ static MA_INLINE ma_uint16 ma_job_extract_slot(ma_uint64 toc) return (ma_uint16)(toc & 0x0000FFFF); } +#if 0 /* Currently unused, but might make use of this later. */ static MA_INLINE ma_uint16 ma_job_extract_code(ma_uint64 toc) { return (ma_uint16)((toc & 0xFFFF0000) >> 16); } +#endif static MA_INLINE ma_uint64 ma_job_toc_to_allocation(ma_uint64 toc) { @@ -17900,6 +19315,13 @@ MA_API ma_result ma_job_queue_next(ma_job_queue* pQueue, ma_job* pJob) Dynamic Linking *******************************************************************************/ +/* Disable run-time linking on certain backends and platforms. */ +#ifndef MA_NO_RUNTIME_LINKING + #if defined(MA_EMSCRIPTEN) || defined(MA_ORBIS) || defined(MA_PROSPERO) || defined(MA_SWITCH) || defined(MA_DOS) + #define MA_NO_RUNTIME_LINKING + #endif +#endif + #ifdef MA_POSIX /* No need for dlfcn.h if we're not using runtime linking. */ #ifndef MA_NO_RUNTIME_LINKING @@ -17909,104 +19331,124 @@ Dynamic Linking MA_API ma_handle ma_dlopen(ma_log* pLog, const char* filename) { -#ifndef MA_NO_RUNTIME_LINKING - ma_handle handle; + #ifndef MA_NO_RUNTIME_LINKING + { + ma_handle handle; - ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading library: %s\n", filename); + ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading library: %s\n", filename); - #ifdef MA_WIN32 - /* From MSDN: Desktop applications cannot use LoadPackagedLibrary; if a desktop application calls this function it fails with APPMODEL_ERROR_NO_PACKAGE.*/ - #if !defined(MA_WIN32_UWP) || !(defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) - handle = (ma_handle)LoadLibraryA(filename); + #ifdef MA_WIN32 + /* From MSDN: Desktop applications cannot use LoadPackagedLibrary; if a desktop application calls this function it fails with APPMODEL_ERROR_NO_PACKAGE.*/ + #if !defined(MA_WIN32_UWP) || !(defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) + handle = (ma_handle)LoadLibraryA(filename); + #else + /* *sigh* It appears there is no ANSI version of LoadPackagedLibrary()... */ + WCHAR filenameW[4096]; + if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, sizeof(filenameW)) == 0) { + handle = NULL; + } else { + handle = (ma_handle)LoadPackagedLibrary(filenameW, 0); + } + #endif #else - /* *sigh* It appears there is no ANSI version of LoadPackagedLibrary()... */ - WCHAR filenameW[4096]; - if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, sizeof(filenameW)) == 0) { - handle = NULL; - } else { - handle = (ma_handle)LoadPackagedLibrary(filenameW, 0); - } + handle = (ma_handle)dlopen(filename, RTLD_NOW); #endif - #else - handle = (ma_handle)dlopen(filename, RTLD_NOW); - #endif - /* - I'm not considering failure to load a library an error nor a warning because seamlessly falling through to a lower-priority - backend is a deliberate design choice. Instead I'm logging it as an informational message. - */ - if (handle == NULL) { - ma_log_postf(pLog, MA_LOG_LEVEL_INFO, "Failed to load library: %s\n", filename); + /* + I'm not considering failure to load a library an error nor a warning because seamlessly falling through to a lower-priority + backend is a deliberate design choice. Instead I'm logging it as an informational message. + */ + if (handle == NULL) { + ma_log_postf(pLog, MA_LOG_LEVEL_INFO, "Failed to load library: %s\n", filename); + } + + return handle; } - - return handle; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)filename; - return NULL; -#endif + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)filename; + return NULL; + } + #endif } MA_API void ma_dlclose(ma_log* pLog, ma_handle handle) { -#ifndef MA_NO_RUNTIME_LINKING - #ifdef MA_WIN32 - FreeLibrary((HMODULE)handle); - #else - /* Hack for Android bug (see https://github.com/android/ndk/issues/360). Calling dlclose() pre-API 28 may segfault. */ - #if !defined(MA_ANDROID) || (defined(__ANDROID_API__) && __ANDROID_API__ >= 28) + #ifndef MA_NO_RUNTIME_LINKING + { + #ifdef MA_WIN32 { - dlclose((void*)handle); + FreeLibrary((HMODULE)handle); } #else { - (void)handle; + /* Hack for Android bug (see https://github.com/android/ndk/issues/360). Calling dlclose() pre-API 28 may segfault. */ + #if !defined(MA_ANDROID) || (defined(__ANDROID_API__) && __ANDROID_API__ >= 28) + { + dlclose((void*)handle); + } + #else + { + (void)handle; + } + #endif } #endif - #endif - (void)pLog; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)handle; -#endif + (void)pLog; + } + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)handle; + } + #endif } MA_API ma_proc ma_dlsym(ma_log* pLog, ma_handle handle, const char* symbol) { -#ifndef MA_NO_RUNTIME_LINKING - ma_proc proc; + #ifndef MA_NO_RUNTIME_LINKING + { + ma_proc proc; - ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading symbol: %s\n", symbol); + ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading symbol: %s\n", symbol); -#ifdef _WIN32 - proc = (ma_proc)GetProcAddress((HMODULE)handle, symbol); -#else -#if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wpedantic" -#endif - proc = (ma_proc)dlsym((void*)handle, symbol); -#if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) - #pragma GCC diagnostic pop -#endif -#endif + #ifdef _WIN32 + { + proc = (ma_proc)GetProcAddress((HMODULE)handle, symbol); + } + #else + { + #if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wpedantic" + #endif + proc = (ma_proc)dlsym((void*)handle, symbol); + #if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) + #pragma GCC diagnostic pop + #endif + } + #endif - if (proc == NULL) { - ma_log_postf(pLog, MA_LOG_LEVEL_WARNING, "Failed to load symbol: %s\n", symbol); + if (proc == NULL) { + ma_log_postf(pLog, MA_LOG_LEVEL_WARNING, "Failed to load symbol: %s\n", symbol); + } + + (void)pLog; /* It's possible for pContext to be unused. */ + return proc; } - - (void)pLog; /* It's possible for pContext to be unused. */ - return proc; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)handle; - (void)symbol; - return NULL; -#endif + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)handle; + (void)symbol; + return NULL; + } + #endif } @@ -18020,13 +19462,6 @@ DEVICE I/O ************************************************************************************************************************************************************* ************************************************************************************************************************************************************/ -/* Disable run-time linking on certain backends and platforms. */ -#ifndef MA_NO_RUNTIME_LINKING - #if defined(MA_EMSCRIPTEN) || defined(MA_ORBIS) || defined(MA_PROSPERO) - #define MA_NO_RUNTIME_LINKING - #endif -#endif - #ifdef MA_APPLE #include #endif @@ -18039,12 +19474,6 @@ DEVICE I/O #ifdef MA_POSIX #include - #include - - /* No need for dlfcn.h if we're not using runtime linking. */ - #ifndef MA_NO_RUNTIME_LINKING - #include - #endif #endif /* This must be set to at least 26. */ @@ -18299,7 +19728,7 @@ MA_API ma_bool32 ma_is_loopback_supported(ma_backend backend) -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) /* WASAPI error codes. */ #define MA_AUDCLNT_E_NOT_INITIALIZED ((HRESULT)0x88890001) #define MA_AUDCLNT_E_ALREADY_INITIALIZED ((HRESULT)0x88890002) @@ -18514,6 +19943,11 @@ typedef LONG (WINAPI * MA_PFN_RegCloseKey)(HKEY hKey); typedef LONG (WINAPI * MA_PFN_RegQueryValueExA)(HKEY hKey, const char* lpValueName, DWORD* lpReserved, DWORD* lpType, BYTE* lpData, DWORD* lpcbData); #endif /* MA_WIN32_DESKTOP */ +static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_PCM = {0x00000001, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; +static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_IEEE_FLOAT = {0x00000003, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; +/*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_ALAW = {0x00000006, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ +/*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_MULAW = {0x00000007, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ + MA_API size_t ma_strlen_WCHAR(const WCHAR* str) { size_t len = 0; @@ -18577,7 +20011,7 @@ Timing *******************************************************************************/ #if defined(MA_WIN32) && !defined(MA_POSIX) static LARGE_INTEGER g_ma_TimerFrequency; /* <-- Initialized to zero since it's static. */ - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { LARGE_INTEGER counter; @@ -18589,7 +20023,7 @@ Timing pTimer->counter = counter.QuadPart; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { LARGE_INTEGER counter; if (!QueryPerformanceCounter(&counter)) { @@ -18600,7 +20034,7 @@ Timing } #elif defined(MA_APPLE) && (MAC_OS_X_VERSION_MIN_REQUIRED < 101200) static ma_uint64 g_ma_TimerFrequency = 0; - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { mach_timebase_info_data_t baseTime; mach_timebase_info(&baseTime); @@ -18609,7 +20043,7 @@ Timing pTimer->counter = mach_absolute_time(); } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter = mach_absolute_time(); ma_uint64 oldTimeCounter = pTimer->counter; @@ -18634,7 +20068,7 @@ Timing #define MA_CLOCK_ID CLOCK_REALTIME #endif - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { struct timespec newTime; clock_gettime(MA_CLOCK_ID, &newTime); @@ -18642,7 +20076,7 @@ Timing pTimer->counter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter; ma_uint64 oldTimeCounter; @@ -18656,7 +20090,7 @@ Timing return (newTimeCounter - oldTimeCounter) / 1000000000.0; } #else - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { struct timeval newTime; gettimeofday(&newTime, NULL); @@ -18664,7 +20098,7 @@ Timing pTimer->counter = (newTime.tv_sec * 1000000) + newTime.tv_usec; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter; ma_uint64 oldTimeCounter; @@ -19248,14 +20682,6 @@ static MA_INLINE void ma_device__set_state(ma_device* pDevice, ma_device_state n } -#if defined(MA_WIN32) - static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_PCM = {0x00000001, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; - static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_IEEE_FLOAT = {0x00000003, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; - /*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_ALAW = {0x00000006, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ - /*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_MULAW = {0x00000007, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ -#endif - - MA_API ma_uint32 ma_get_format_priority_index(ma_format format) /* Lower = better. */ { @@ -19967,7 +21393,7 @@ static ma_result ma_context_init__null(ma_context* pContext, const ma_context_co WIN32 COMMON *******************************************************************************/ -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) #if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) #define ma_CoInitializeEx(pContext, pvReserved, dwCoInit) ((pContext->win32.CoInitializeEx) ? ((MA_PFN_CoInitializeEx)pContext->win32.CoInitializeEx)(pvReserved, dwCoInit) : ((MA_PFN_CoInitialize)pContext->win32.CoInitialize)(pvReserved)) #define ma_CoUninitialize(pContext) ((MA_PFN_CoUninitialize)pContext->win32.CoUninitialize)() @@ -19982,7 +21408,7 @@ WIN32 COMMON #define ma_PropVariantClear(pContext, pvar) PropVariantClear(pvar) #endif -#if !defined(MAXULONG_PTR) && !defined(__WATCOMC__) +#if !defined(MAXULONG_PTR) && !defined(__WATCOMC__) && !defined(MA_XBOX_NXDK) typedef size_t DWORD_PTR; #endif @@ -20409,11 +21835,21 @@ typedef enum MA_AudioCategory_Other = 0 /* <-- miniaudio is only caring about Other. */ } MA_AUDIO_STREAM_CATEGORY; +typedef enum +{ + MA_AUDCLNT_STREAMOPTIONS_NONE, + MA_AUDCLNT_STREAMOPTIONS_RAW, + MA_AUDCLNT_STREAMOPTIONS_MATCH_FORMAT, + MA_AUDCLNT_STREAMOPTIONS_AMBISONICS, + MA_AUDCLNT_STREAMOPTIONS_POST_VOLUME_LOOPBACK +} MA_AUDCLNT_STREAMOPTIONS; + typedef struct { ma_uint32 cbSize; BOOL bIsOffload; MA_AUDIO_STREAM_CATEGORY eCategory; + MA_AUDCLNT_STREAMOPTIONS Options; } ma_AudioClientProperties; /* IUnknown */ @@ -21588,6 +23024,7 @@ static ma_result ma_context_get_MMDevice__wasapi(ma_context* pContext, ma_device { ma_IMMDeviceEnumerator* pDeviceEnumerator; HRESULT hr; + HRESULT CoInitializeResult; MA_ASSERT(pContext != NULL); MA_ASSERT(ppMMDevice != NULL); @@ -21601,12 +23038,17 @@ static ma_result ma_context_get_MMDevice__wasapi(ma_context* pContext, ma_device The community has reported that this seems to fix the crash. There are future plans to move all WASAPI operation over to a single thread to make everything safer, but in the meantime while we wait for that to come online I'm happy enough to use this hack instead. + + CoUninitialize should only be called if we successfully initialized. S_OK and S_FALSE both mean that we need to + call CoUninitialize since the internal ref count was increased. RPC_E_CHANGED_MODE means that CoInitializeEx was + called with a different COINIT value, and we don't call CoUninitialize in that case. Other errors are possible, + so we check for S_OK and S_FALSE specifically. */ - ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); + CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); { hr = ma_CoCreateInstance(pContext, &MA_CLSID_MMDeviceEnumerator, NULL, CLSCTX_ALL, &MA_IID_IMMDeviceEnumerator, (void**)&pDeviceEnumerator); - } - ma_CoUninitialize(pContext); + } + if (CoInitializeResult == S_OK || CoInitializeResult == S_FALSE) { ma_CoUninitialize(pContext); } if (FAILED(hr)) { /* <-- This is checking the call above to ma_CoCreateInstance(). */ ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[WASAPI] Failed to create IMMDeviceEnumerator.\n"); @@ -21950,7 +23392,7 @@ static ma_result ma_context_get_IAudioClient__wasapi(ma_context* pContext, ma_de pActivationParams = &activationParams; /* When requesting a specific device ID we need to use a special device ID. */ - MA_COPY_MEMORY(virtualDeviceID.wasapi, MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK, (wcslen(MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK) + 1) * sizeof(wchar_t)); /* +1 for the null terminator. */ + MA_COPY_MEMORY(virtualDeviceID.wasapi, MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK, (ma_wcslen(MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK) + 1) * sizeof(wchar_t)); /* +1 for the null terminator. */ pDeviceID = &virtualDeviceID; } else { pActivationParams = NULL; /* No activation parameters required. */ @@ -26679,6 +28121,9 @@ typedef snd_pcm_channel_area_t ma_snd_pcm_channel_area_t; typedef snd_pcm_chmap_t ma_snd_pcm_chmap_t; typedef snd_pcm_state_t ma_snd_pcm_state_t; +/* snd_pcm_state_t */ +#define MA_SND_PCM_STATE_XRUN SND_PCM_STATE_XRUN + /* snd_pcm_stream_t */ #define MA_SND_PCM_STREAM_PLAYBACK SND_PCM_STREAM_PLAYBACK #define MA_SND_PCM_STREAM_CAPTURE SND_PCM_STREAM_CAPTURE @@ -26874,6 +28319,7 @@ typedef int (* ma_snd_pcm_hw_params_set_channels_minmax_proc) ( typedef int (* ma_snd_pcm_hw_params_set_rate_resample_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int val); typedef int (* ma_snd_pcm_hw_params_set_rate_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int val, int dir); typedef int (* ma_snd_pcm_hw_params_set_rate_near_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *val, int *dir); +typedef int (* ma_snd_pcm_hw_params_set_rate_minmax_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *min, int *mindir, unsigned int *max, int *maxdir); typedef int (* ma_snd_pcm_hw_params_set_buffer_size_near_proc)(ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, ma_snd_pcm_uframes_t *val); typedef int (* ma_snd_pcm_hw_params_set_periods_near_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *val, int *dir); typedef int (* ma_snd_pcm_hw_params_set_access_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, ma_snd_pcm_access_t _access); @@ -28640,8 +30086,9 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co ma_snd_pcm_hw_params_get_format_mask_proc _snd_pcm_hw_params_get_format_mask = snd_pcm_hw_params_get_format_mask; ma_snd_pcm_hw_params_set_channels_proc _snd_pcm_hw_params_set_channels = snd_pcm_hw_params_set_channels; ma_snd_pcm_hw_params_set_channels_near_proc _snd_pcm_hw_params_set_channels_near = snd_pcm_hw_params_set_channels_near; + ma_snd_pcm_hw_params_set_channels_minmax_proc _snd_pcm_hw_params_set_channels_minmax = snd_pcm_hw_params_set_channels_minmax; ma_snd_pcm_hw_params_set_rate_resample_proc _snd_pcm_hw_params_set_rate_resample = snd_pcm_hw_params_set_rate_resample; - ma_snd_pcm_hw_params_set_rate_near _snd_pcm_hw_params_set_rate = snd_pcm_hw_params_set_rate; + ma_snd_pcm_hw_params_set_rate_proc _snd_pcm_hw_params_set_rate = snd_pcm_hw_params_set_rate; ma_snd_pcm_hw_params_set_rate_near_proc _snd_pcm_hw_params_set_rate_near = snd_pcm_hw_params_set_rate_near; ma_snd_pcm_hw_params_set_rate_minmax_proc _snd_pcm_hw_params_set_rate_minmax = snd_pcm_hw_params_set_rate_minmax; ma_snd_pcm_hw_params_set_buffer_size_near_proc _snd_pcm_hw_params_set_buffer_size_near = snd_pcm_hw_params_set_buffer_size_near; @@ -28693,9 +30140,9 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co ma_snd_pcm_info_proc _snd_pcm_info = snd_pcm_info; ma_snd_pcm_info_sizeof_proc _snd_pcm_info_sizeof = snd_pcm_info_sizeof; ma_snd_pcm_info_get_name_proc _snd_pcm_info_get_name = snd_pcm_info_get_name; - ma_snd_pcm_poll_descriptors _snd_pcm_poll_descriptors = snd_pcm_poll_descriptors; - ma_snd_pcm_poll_descriptors_count _snd_pcm_poll_descriptors_count = snd_pcm_poll_descriptors_count; - ma_snd_pcm_poll_descriptors_revents _snd_pcm_poll_descriptors_revents = snd_pcm_poll_descriptors_revents; + ma_snd_pcm_poll_descriptors_proc _snd_pcm_poll_descriptors = snd_pcm_poll_descriptors; + ma_snd_pcm_poll_descriptors_count_proc _snd_pcm_poll_descriptors_count = snd_pcm_poll_descriptors_count; + ma_snd_pcm_poll_descriptors_revents_proc _snd_pcm_poll_descriptors_revents = snd_pcm_poll_descriptors_revents; ma_snd_config_update_free_global_proc _snd_config_update_free_global = snd_config_update_free_global; pContext->alsa.snd_pcm_open = (ma_proc)_snd_pcm_open; @@ -28711,6 +30158,7 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co pContext->alsa.snd_pcm_hw_params_set_rate_resample = (ma_proc)_snd_pcm_hw_params_set_rate_resample; pContext->alsa.snd_pcm_hw_params_set_rate = (ma_proc)_snd_pcm_hw_params_set_rate; pContext->alsa.snd_pcm_hw_params_set_rate_near = (ma_proc)_snd_pcm_hw_params_set_rate_near; + pContext->alsa.snd_pcm_hw_params_set_rate_minmax = (ma_proc)_snd_pcm_hw_params_set_rate_minmax; pContext->alsa.snd_pcm_hw_params_set_buffer_size_near = (ma_proc)_snd_pcm_hw_params_set_buffer_size_near; pContext->alsa.snd_pcm_hw_params_set_periods_near = (ma_proc)_snd_pcm_hw_params_set_periods_near; pContext->alsa.snd_pcm_hw_params_set_access = (ma_proc)_snd_pcm_hw_params_set_access; @@ -29436,7 +30884,7 @@ typedef void (* ma_pa_threaded_mainloop_unlock_proc) ( typedef void (* ma_pa_threaded_mainloop_wait_proc) (ma_pa_threaded_mainloop* m); typedef void (* ma_pa_threaded_mainloop_signal_proc) (ma_pa_threaded_mainloop* m, int wait_for_accept); typedef void (* ma_pa_threaded_mainloop_accept_proc) (ma_pa_threaded_mainloop* m); -typedef int (* ma_pa_threaded_mainloop_get_retval_proc) (ma_pa_threaded_mainloop* m); +typedef int (* ma_pa_threaded_mainloop_get_retval_proc) (const ma_pa_threaded_mainloop* m); typedef ma_pa_mainloop_api* (* ma_pa_threaded_mainloop_get_api_proc) (ma_pa_threaded_mainloop* m); typedef int (* ma_pa_threaded_mainloop_in_thread_proc) (ma_pa_threaded_mainloop* m); typedef void (* ma_pa_threaded_mainloop_set_name_proc) (ma_pa_threaded_mainloop* m, const char* name); @@ -29445,13 +30893,13 @@ typedef void (* ma_pa_context_unref_proc) ( typedef int (* ma_pa_context_connect_proc) (ma_pa_context* c, const char* server, ma_pa_context_flags_t flags, const ma_pa_spawn_api* api); typedef void (* ma_pa_context_disconnect_proc) (ma_pa_context* c); typedef void (* ma_pa_context_set_state_callback_proc) (ma_pa_context* c, ma_pa_context_notify_cb_t cb, void* userdata); -typedef ma_pa_context_state_t (* ma_pa_context_get_state_proc) (ma_pa_context* c); +typedef ma_pa_context_state_t (* ma_pa_context_get_state_proc) (const ma_pa_context* c); typedef ma_pa_operation* (* ma_pa_context_get_sink_info_list_proc) (ma_pa_context* c, ma_pa_sink_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_source_info_list_proc) (ma_pa_context* c, ma_pa_source_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_sink_info_by_name_proc) (ma_pa_context* c, const char* name, ma_pa_sink_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_source_info_by_name_proc)(ma_pa_context* c, const char* name, ma_pa_source_info_cb_t cb, void* userdata); typedef void (* ma_pa_operation_unref_proc) (ma_pa_operation* o); -typedef ma_pa_operation_state_t (* ma_pa_operation_get_state_proc) (ma_pa_operation* o); +typedef ma_pa_operation_state_t (* ma_pa_operation_get_state_proc) (const ma_pa_operation* o); typedef ma_pa_channel_map* (* ma_pa_channel_map_init_extend_proc) (ma_pa_channel_map* m, unsigned channels, ma_pa_channel_map_def_t def); typedef int (* ma_pa_channel_map_valid_proc) (const ma_pa_channel_map* m); typedef int (* ma_pa_channel_map_compatible_proc) (const ma_pa_channel_map* m, const ma_pa_sample_spec* ss); @@ -29460,12 +30908,12 @@ typedef void (* ma_pa_stream_unref_proc) ( typedef int (* ma_pa_stream_connect_playback_proc) (ma_pa_stream* s, const char* dev, const ma_pa_buffer_attr* attr, ma_pa_stream_flags_t flags, const ma_pa_cvolume* volume, ma_pa_stream* sync_stream); typedef int (* ma_pa_stream_connect_record_proc) (ma_pa_stream* s, const char* dev, const ma_pa_buffer_attr* attr, ma_pa_stream_flags_t flags); typedef int (* ma_pa_stream_disconnect_proc) (ma_pa_stream* s); -typedef ma_pa_stream_state_t (* ma_pa_stream_get_state_proc) (ma_pa_stream* s); +typedef ma_pa_stream_state_t (* ma_pa_stream_get_state_proc) (const ma_pa_stream* s); typedef const ma_pa_sample_spec* (* ma_pa_stream_get_sample_spec_proc) (ma_pa_stream* s); typedef const ma_pa_channel_map* (* ma_pa_stream_get_channel_map_proc) (ma_pa_stream* s); typedef const ma_pa_buffer_attr* (* ma_pa_stream_get_buffer_attr_proc) (ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_set_buffer_attr_proc) (ma_pa_stream* s, const ma_pa_buffer_attr* attr, ma_pa_stream_success_cb_t cb, void* userdata); -typedef const char* (* ma_pa_stream_get_device_name_proc) (ma_pa_stream* s); +typedef const char* (* ma_pa_stream_get_device_name_proc) (const ma_pa_stream* s); typedef void (* ma_pa_stream_set_write_callback_proc) (ma_pa_stream* s, ma_pa_stream_request_cb_t cb, void* userdata); typedef void (* ma_pa_stream_set_read_callback_proc) (ma_pa_stream* s, ma_pa_stream_request_cb_t cb, void* userdata); typedef void (* ma_pa_stream_set_suspended_callback_proc) (ma_pa_stream* s, ma_pa_stream_notify_cb_t cb, void* userdata); @@ -29473,15 +30921,15 @@ typedef void (* ma_pa_stream_set_moved_callback_proc) ( typedef int (* ma_pa_stream_is_suspended_proc) (const ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_flush_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_stream_drain_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); -typedef int (* ma_pa_stream_is_corked_proc) (ma_pa_stream* s); +typedef int (* ma_pa_stream_is_corked_proc) (const ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_cork_proc) (ma_pa_stream* s, int b, ma_pa_stream_success_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_stream_trigger_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); typedef int (* ma_pa_stream_begin_write_proc) (ma_pa_stream* s, void** data, size_t* nbytes); typedef int (* ma_pa_stream_write_proc) (ma_pa_stream* s, const void* data, size_t nbytes, ma_pa_free_cb_t free_cb, int64_t offset, ma_pa_seek_mode_t seek); typedef int (* ma_pa_stream_peek_proc) (ma_pa_stream* s, const void** data, size_t* nbytes); typedef int (* ma_pa_stream_drop_proc) (ma_pa_stream* s); -typedef size_t (* ma_pa_stream_writable_size_proc) (ma_pa_stream* s); -typedef size_t (* ma_pa_stream_readable_size_proc) (ma_pa_stream* s); +typedef size_t (* ma_pa_stream_writable_size_proc) (const ma_pa_stream* s); +typedef size_t (* ma_pa_stream_readable_size_proc) (const ma_pa_stream* s); typedef struct { @@ -29777,7 +31225,7 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext, } /* Now we need to connect to the context. Everything is asynchronous so we need to wait for it to connect before returning. */ - result = ma_result_from_pulse(((ma_pa_context_connect_proc)pContext->pulse.pa_context_connect)((ma_pa_context*)pPulseContext, pServerName, (tryAutoSpawn) ? 0 : MA_PA_CONTEXT_NOAUTOSPAWN, NULL)); + result = ma_result_from_pulse(((ma_pa_context_connect_proc)pContext->pulse.pa_context_connect)((ma_pa_context*)pPulseContext, pServerName, (tryAutoSpawn) ? MA_PA_CONTEXT_NOFLAGS : MA_PA_CONTEXT_NOAUTOSPAWN, NULL)); if (result != MA_SUCCESS) { ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio context."); ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop)); @@ -30510,7 +31958,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi const ma_pa_buffer_attr* pActualAttr = NULL; const ma_pa_channel_map* pActualChannelMap = NULL; ma_uint32 iChannel; - ma_pa_stream_flags_t streamFlags; + int streamFlags; MA_ASSERT(pDevice != NULL); MA_ZERO_OBJECT(&pDevice->pulse); @@ -30568,8 +32016,13 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi ss.channels = pDescriptorCapture->channels; } + /* PulseAudio has a maximum channel count of 32. We'll get a crash if this is exceeded. */ + if (ss.channels > 32) { + ss.channels = 32; + } + /* Use a default channel map. */ - ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, pConfig->pulse.channelMap); + ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, (ma_pa_channel_map_def_t)pConfig->pulse.channelMap); /* Use the requested sample rate if one was specified. */ if (pDescriptorCapture->sampleRate != 0) { @@ -30626,7 +32079,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi streamFlags |= MA_PA_STREAM_DONT_MOVE; } - error = ((ma_pa_stream_connect_record_proc)pDevice->pContext->pulse.pa_stream_connect_record)((ma_pa_stream*)pDevice->pulse.pStreamCapture, devCapture, &attr, streamFlags); + error = ((ma_pa_stream_connect_record_proc)pDevice->pContext->pulse.pa_stream_connect_record)((ma_pa_stream*)pDevice->pulse.pStreamCapture, devCapture, &attr, (ma_pa_stream_flags_t)streamFlags); if (error != MA_PA_OK) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio capture stream."); result = ma_result_from_pulse(error); @@ -30720,8 +32173,13 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi ss.channels = pDescriptorPlayback->channels; } + /* PulseAudio has a maximum channel count of 32. We'll get a crash if this is exceeded. */ + if (ss.channels > 32) { + ss.channels = 32; + } + /* Use a default channel map. */ - ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, pConfig->pulse.channelMap); + ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, (ma_pa_channel_map_def_t)pConfig->pulse.channelMap); /* Use the requested sample rate if one was specified. */ @@ -30783,7 +32241,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi streamFlags |= MA_PA_STREAM_DONT_MOVE; } - error = ((ma_pa_stream_connect_playback_proc)pDevice->pContext->pulse.pa_stream_connect_playback)((ma_pa_stream*)pDevice->pulse.pStreamPlayback, devPlayback, &attr, streamFlags, NULL, NULL); + error = ((ma_pa_stream_connect_playback_proc)pDevice->pContext->pulse.pa_stream_connect_playback)((ma_pa_stream*)pDevice->pulse.pStreamPlayback, devPlayback, &attr, (ma_pa_stream_flags_t)streamFlags, NULL, NULL); if (error != MA_PA_OK) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio playback stream."); result = ma_result_from_pulse(error); @@ -31338,6 +32796,7 @@ typedef JackProcessCallback ma_JackProcessCallback; typedef JackBufferSizeCallback ma_JackBufferSizeCallback; typedef JackShutdownCallback ma_JackShutdownCallback; #define MA_JACK_DEFAULT_AUDIO_TYPE JACK_DEFAULT_AUDIO_TYPE +#define ma_JackNullOption JackNullOption #define ma_JackNoStartServer JackNoStartServer #define ma_JackPortIsInput JackPortIsInput #define ma_JackPortIsOutput JackPortIsOutput @@ -31352,6 +32811,7 @@ typedef int (* ma_JackProcessCallback) (ma_jack_nframes_t nframes, void* arg) typedef int (* ma_JackBufferSizeCallback)(ma_jack_nframes_t nframes, void* arg); typedef void (* ma_JackShutdownCallback) (void* arg); #define MA_JACK_DEFAULT_AUDIO_TYPE "32 bit float mono audio" +#define ma_JackNullOption 0 #define ma_JackNoStartServer 1 #define ma_JackPortIsInput 1 #define ma_JackPortIsOutput 2 @@ -31392,7 +32852,7 @@ static ma_result ma_context_open_client__jack(ma_context* pContext, ma_jack_clie maxClientNameSize = ((ma_jack_client_name_size_proc)pContext->jack.jack_client_name_size)(); /* Includes null terminator. */ ma_strncpy_s(clientName, ma_min(sizeof(clientName), maxClientNameSize), (pContext->jack.pClientName != NULL) ? pContext->jack.pClientName : "miniaudio", (size_t)-1); - pClient = ((ma_jack_client_open_proc)pContext->jack.jack_client_open)(clientName, (pContext->jack.tryStartServer) ? 0 : ma_JackNoStartServer, &status, NULL); + pClient = ((ma_jack_client_open_proc)pContext->jack.jack_client_open)(clientName, (pContext->jack.tryStartServer) ? ma_JackNullOption : ma_JackNoStartServer, &status, NULL); if (pClient == NULL) { return MA_FAILED_TO_OPEN_BACKEND_DEVICE; } @@ -36994,7 +38454,7 @@ OSS Backend #define MA_OSS_DEFAULT_DEVICE_NAME "/dev/dsp" -static int ma_open_temp_device__oss() +static int ma_open_temp_device__oss(void) { /* The OSS sample code uses "/dev/mixer" as the device for getting system properties so I'm going to do the same. */ int fd = open("/dev/mixer", O_RDONLY, 0); @@ -37834,25 +39294,30 @@ static void ma_stream_error_callback__aaudio(ma_AAudioStream* pStream, void* pUs (void)error; ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] ERROR CALLBACK: error=%d, AAudioStream_getState()=%d\n", error, ((MA_PFN_AAudioStream_getState)pDevice->pContext->aaudio.AAudioStream_getState)(pStream)); + /* When we get an error, we'll assume that the stream is in an erroneous state and needs to be restarted. From the documentation, we cannot do this from the error callback. Therefore we are going to use an event thread for the AAudio backend to do this cleanly and safely. */ - job = ma_job_init(MA_JOB_TYPE_DEVICE_AAUDIO_REROUTE); - job.data.device.aaudio.reroute.pDevice = pDevice; - - if (pStream == pDevice->aaudio.pStreamCapture) { - job.data.device.aaudio.reroute.deviceType = ma_device_type_capture; + if (ma_atomic_bool32_get(&pDevice->aaudio.isTearingDown)) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Tearing down device.\n"); } else { - job.data.device.aaudio.reroute.deviceType = ma_device_type_playback; - } - - result = ma_device_job_thread_post(&pDevice->pContext->aaudio.jobThread, &job); - if (result != MA_SUCCESS) { - ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Failed to post job for rerouting.\n"); - return; + job = ma_job_init(MA_JOB_TYPE_DEVICE_AAUDIO_REROUTE); + job.data.device.aaudio.reroute.pDevice = pDevice; + + if (pStream == pDevice->aaudio.pStreamCapture) { + job.data.device.aaudio.reroute.deviceType = ma_device_type_capture; + } else { + job.data.device.aaudio.reroute.deviceType = ma_device_type_playback; + } + + result = ma_device_job_thread_post(&pDevice->pContext->aaudio.jobThread, &job); + if (result != MA_SUCCESS) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Failed to post job for rerouting.\n"); + return; + } } } @@ -38169,7 +39634,7 @@ static ma_result ma_close_streams__aaudio(ma_device* pDevice) { MA_ASSERT(pDevice != NULL); - /* When re-routing, streams may have been closed and never re-opened. Hence the extra checks below. */ + /* When rerouting, streams may have been closed and never re-opened. Hence the extra checks below. */ if (pDevice->type == ma_device_type_capture || pDevice->type == ma_device_type_duplex) { ma_close_stream__aaudio(pDevice->pContext, (ma_AAudioStream*)pDevice->aaudio.pStreamCapture); pDevice->aaudio.pStreamCapture = NULL; @@ -38186,6 +39651,12 @@ static ma_result ma_device_uninit__aaudio(ma_device* pDevice) { MA_ASSERT(pDevice != NULL); + /* + Note: Closing the streams may cause a timeout error, which would then trigger rerouting in our error callback. + We must not schedule a reroute when device is getting destroyed. + */ + ma_atomic_bool32_set(&pDevice->aaudio.isTearingDown, MA_TRUE); + /* Wait for any rerouting to finish before attempting to close the streams. */ ma_mutex_lock(&pDevice->aaudio.rerouteLock); { @@ -38193,7 +39664,7 @@ static ma_result ma_device_uninit__aaudio(ma_device* pDevice) } ma_mutex_unlock(&pDevice->aaudio.rerouteLock); - /* Destroy re-routing lock. */ + /* Destroy rerouting lock. */ ma_mutex_uninit(&pDevice->aaudio.rerouteLock); return MA_SUCCESS; @@ -38429,17 +39900,22 @@ static ma_result ma_device_stop__aaudio(ma_device* pDevice) static ma_result ma_device_reinit__aaudio(ma_device* pDevice, ma_device_type deviceType) { + const ma_int32 maxAttempts = 4; /* Reasonable retry limit. */ + ma_result result; - int32_t retries = 0; + ma_int32 iAttempt; MA_ASSERT(pDevice != NULL); - /* - TODO: Stop retrying if main thread is about to uninit device. - */ - ma_mutex_lock(&pDevice->aaudio.rerouteLock); - { -error_disconnected: + /* We got disconnected! Retry a few times, until we find a connected device! */ + iAttempt = 0; + while (iAttempt++ < maxAttempts) { + /* Device tearing down? No need to reroute! */ + if (ma_atomic_bool32_get(&pDevice->aaudio.isTearingDown)) { + result = MA_SUCCESS; /* Caller should continue as normal. */ + break; + } + /* The first thing to do is close the streams. */ ma_close_streams__aaudio(pDevice); @@ -38495,14 +39971,16 @@ error_disconnected: result = ma_device_init_streams__aaudio(pDevice, &deviceConfig, &descriptorPlayback, &descriptorCapture); if (result != MA_SUCCESS) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_WARNING, "[AAudio] Failed to create stream after route change."); - goto done; + /* Reroute failed! */ + break; } result = ma_device_post_init(pDevice, deviceType, &descriptorPlayback, &descriptorCapture); if (result != MA_SUCCESS) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_WARNING, "[AAudio] Failed to initialize device after route change."); ma_close_streams__aaudio(pDevice); - goto done; + /* Reroute failed! */ + break; } /* We'll only ever do this in response to a reroute. */ @@ -38513,26 +39991,23 @@ error_disconnected: if (pDevice->aaudio.noAutoStartAfterReroute == MA_FALSE) { result = ma_device_start__aaudio(pDevice); if (result != MA_SUCCESS) { - /* We got disconnected! Retry a few times, until we find a connected device! */ - retries += 1; - if (retries <= 3) { - ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, retrying(%d)", retries); - goto error_disconnected; + if (iAttempt < maxAttempts) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, retrying(%d)", iAttempt); + } else { + ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, giving up."); } - ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change."); - goto done; } } else { - ma_device_stop(pDevice); /* Do a full device stop so we set internal state correctly. */ + ma_device_stop(pDevice); /* Do a full device stop so we set internal state correctly. */ } } - - result = MA_SUCCESS; - } -done: - /* Re-routing done */ - ma_mutex_unlock(&pDevice->aaudio.rerouteLock); + if (result == MA_SUCCESS) { + /* Reroute successful! */ + break; + } + } + return result; } @@ -38698,7 +40173,7 @@ static ma_result ma_context_init__aaudio(ma_context* pContext, const ma_context_ static ma_result ma_job_process__device__aaudio_reroute(ma_job* pJob) { - ma_result result; + ma_result result = MA_SUCCESS; ma_device* pDevice; MA_ASSERT(pJob != NULL); @@ -38706,19 +40181,22 @@ static ma_result ma_job_process__device__aaudio_reroute(ma_job* pJob) pDevice = (ma_device*)pJob->data.device.aaudio.reroute.pDevice; MA_ASSERT(pDevice != NULL); - /* Here is where we need to reroute the device. To do this we need to uninitialize the stream and reinitialize it. */ - result = ma_device_reinit__aaudio(pDevice, (ma_device_type)pJob->data.device.aaudio.reroute.deviceType); - if (result != MA_SUCCESS) { - /* - Getting here means we failed to reroute the device. The best thing I can think of here is to - just stop the device. - */ - ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[AAudio] Stopping device due to reroute failure."); - ma_device_stop(pDevice); - return result; + ma_mutex_lock(&pDevice->aaudio.rerouteLock); + { + /* Here is where we need to reroute the device. To do this we need to uninitialize the stream and reinitialize it. */ + result = ma_device_reinit__aaudio(pDevice, (ma_device_type)pJob->data.device.aaudio.reroute.deviceType); + if (result != MA_SUCCESS) { + /* + Getting here means we failed to reroute the device. The best thing I can think of here is to + just stop the device. + */ + ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[AAudio] Stopping device due to reroute failure."); + ma_device_stop(pDevice); + } } + ma_mutex_unlock(&pDevice->aaudio.rerouteLock); - return MA_SUCCESS; + return result; } #else /* Getting here means there is no AAudio backend so we need a no-op job implementation. */ @@ -40782,8 +42260,8 @@ static ma_result ma_context_uninit__webaudio(ma_context* pContext) /* Remove the global miniaudio object from window if there are no more references to it. */ EM_ASM({ if (typeof(window.miniaudio) !== 'undefined') { - miniaudio.unlock_event_types.map(function(event_type) { - document.removeEventListener(event_type, miniaudio.unlock, true); + window.miniaudio.unlock_event_types.map(function(event_type) { + document.removeEventListener(event_type, window.miniaudio.unlock, true); }); window.miniaudio.referenceCount -= 1; @@ -41236,13 +42714,13 @@ MA_API ma_result ma_device_post_init(ma_device* pDevice, ma_device_type deviceTy static ma_thread_result MA_THREADCALL ma_worker_thread(void* pData) { ma_device* pDevice = (ma_device*)pData; -#ifdef MA_WIN32 +#if defined(MA_WIN32) && !defined(MA_XBOX) HRESULT CoInitializeResult; #endif MA_ASSERT(pDevice != NULL); -#ifdef MA_WIN32 +#if defined(MA_WIN32) && !defined(MA_XBOX) CoInitializeResult = ma_CoInitializeEx(pDevice->pContext, NULL, MA_COINIT_VALUE); #endif @@ -41333,8 +42811,8 @@ static ma_thread_result MA_THREADCALL ma_worker_thread(void* pData) ma_event_signal(&pDevice->stopEvent); } -#ifdef MA_WIN32 - if (CoInitializeResult == S_OK) { +#if defined(MA_WIN32) && !defined(MA_XBOX) + if (CoInitializeResult == S_OK || CoInitializeResult == S_FALSE) { ma_CoUninitialize(pDevice->pContext); } #endif @@ -41358,67 +42836,92 @@ static ma_bool32 ma_device__is_initialized(ma_device* pDevice) static ma_result ma_context_uninit_backend_apis__win32(ma_context* pContext) { /* For some reason UWP complains when CoUninitialize() is called. I'm just not going to call it on UWP. */ -#if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) - if (pContext->win32.CoInitializeResult == S_OK) { - ma_CoUninitialize(pContext); + #if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) + { + /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + #if !defined(MA_XBOX) + { + if (pContext->win32.CoInitializeResult == S_OK || pContext->win32.CoInitializeResult == S_FALSE) { + ma_CoUninitialize(pContext); /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + } + } + #endif + + #if defined(MA_WIN32_DESKTOP) + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hUser32DLL); + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL); + #endif + + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hOle32DLL); + } + #else + { + (void)pContext; } - - #if defined(MA_WIN32_DESKTOP) - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hUser32DLL); - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL); #endif - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hOle32DLL); -#else - (void)pContext; -#endif - return MA_SUCCESS; } static ma_result ma_context_init_backend_apis__win32(ma_context* pContext) { -#if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) - #if defined(MA_WIN32_DESKTOP) - /* User32.dll */ - pContext->win32.hUser32DLL = ma_dlopen(ma_context_get_log(pContext), "user32.dll"); - if (pContext->win32.hUser32DLL == NULL) { + /* + TODO: Reassess all of this stuff and move everything to the relevant backends. For example, I think + GetForegroundWindow() and GetDesktopWindow() are only used by the DirectSound backend. + */ + #if (defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK)) && !defined(MA_XBOX) + { + #if defined(MA_WIN32_DESKTOP) + { + /* User32.dll */ + pContext->win32.hUser32DLL = ma_dlopen(ma_context_get_log(pContext), "user32.dll"); + if (pContext->win32.hUser32DLL == NULL) { + return MA_FAILED_TO_INIT_BACKEND; + } + + pContext->win32.GetForegroundWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetForegroundWindow"); + pContext->win32.GetDesktopWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetDesktopWindow"); + + + /* Advapi32.dll */ + pContext->win32.hAdvapi32DLL = ma_dlopen(ma_context_get_log(pContext), "advapi32.dll"); + if (pContext->win32.hAdvapi32DLL == NULL) { + return MA_FAILED_TO_INIT_BACKEND; + } + + pContext->win32.RegOpenKeyExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegOpenKeyExA"); + pContext->win32.RegCloseKey = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegCloseKey"); + pContext->win32.RegQueryValueExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegQueryValueExA"); + } + #endif + + /* Ole32.dll */ + pContext->win32.hOle32DLL = ma_dlopen(ma_context_get_log(pContext), "ole32.dll"); + if (pContext->win32.hOle32DLL == NULL) { return MA_FAILED_TO_INIT_BACKEND; } - pContext->win32.GetForegroundWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetForegroundWindow"); - pContext->win32.GetDesktopWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetDesktopWindow"); - - - /* Advapi32.dll */ - pContext->win32.hAdvapi32DLL = ma_dlopen(ma_context_get_log(pContext), "advapi32.dll"); - if (pContext->win32.hAdvapi32DLL == NULL) { - return MA_FAILED_TO_INIT_BACKEND; - } - - pContext->win32.RegOpenKeyExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegOpenKeyExA"); - pContext->win32.RegCloseKey = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegCloseKey"); - pContext->win32.RegQueryValueExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegQueryValueExA"); + pContext->win32.CoInitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitialize"); + pContext->win32.CoInitializeEx = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitializeEx"); + pContext->win32.CoUninitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoUninitialize"); + pContext->win32.CoCreateInstance = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoCreateInstance"); + pContext->win32.CoTaskMemFree = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoTaskMemFree"); + pContext->win32.PropVariantClear = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "PropVariantClear"); + pContext->win32.StringFromGUID2 = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "StringFromGUID2"); + } + #else + { + (void)pContext; /* Unused. */ + } #endif - /* Ole32.dll */ - pContext->win32.hOle32DLL = ma_dlopen(ma_context_get_log(pContext), "ole32.dll"); - if (pContext->win32.hOle32DLL == NULL) { - return MA_FAILED_TO_INIT_BACKEND; + /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + #if !defined(MA_XBOX) + { + pContext->win32.CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); } + #endif - pContext->win32.CoInitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitialize"); - pContext->win32.CoInitializeEx = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitializeEx"); - pContext->win32.CoUninitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoUninitialize"); - pContext->win32.CoCreateInstance = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoCreateInstance"); - pContext->win32.CoTaskMemFree = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoTaskMemFree"); - pContext->win32.PropVariantClear = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "PropVariantClear"); - pContext->win32.StringFromGUID2 = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "StringFromGUID2"); -#else - (void)pContext; /* Unused. */ -#endif - - pContext->win32.CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); return MA_SUCCESS; } #else @@ -44016,7 +45519,7 @@ static MA_INLINE void ma_pcm_s16_to_s32__reference(void* dst, const void* src, m ma_uint64 i; for (i = 0; i < count; i += 1) { - dst_s32[i] = src_s16[i] << 16; + dst_s32[i] = (ma_int32)src_s16[i] << 16; } (void)ditherMode; @@ -56408,8 +57911,12 @@ MA_API size_t ma_channel_map_to_string(const ma_channel* pChannelMap, ma_uint32 } /* Null terminate. Don't increment the length here. */ - if (pBufferOut != NULL && bufferCap > len + 1) { - pBufferOut[len] = '\0'; + if (pBufferOut != NULL) { + if (bufferCap > len) { + pBufferOut[len] = '\0'; + } else if (bufferCap > 0) { + pBufferOut[bufferCap - 1] = '\0'; + } } return len; @@ -56620,7 +58127,7 @@ MA_API ma_result ma_rb_init_ex(size_t subbufferSizeInBytes, size_t subbufferCoun Here is where we allocate our own buffer. We always want to align this to MA_SIMD_ALIGNMENT for future SIMD optimization opportunity. To do this we need to make sure the stride is a multiple of MA_SIMD_ALIGNMENT. */ - pRB->subbufferStrideInBytes = (pRB->subbufferSizeInBytes + (MA_SIMD_ALIGNMENT-1)) & ~MA_SIMD_ALIGNMENT; + pRB->subbufferStrideInBytes = ma_align(pRB->subbufferSizeInBytes, MA_SIMD_ALIGNMENT); bufferSizeInBytes = (size_t)pRB->subbufferCount*pRB->subbufferStrideInBytes; pRB->pBuffer = ma_aligned_malloc(bufferSizeInBytes, MA_SIMD_ALIGNMENT, &pRB->allocationCallbacks); @@ -59515,7 +61022,7 @@ MA_API ma_result ma_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo } -#if !defined(MA_USE_WIN32_FILEIO) && (defined(MA_WIN32) && defined(MA_WIN32_DESKTOP) && !defined(MA_NO_WIN32_FILEIO) && !defined(MA_POSIX)) +#if !defined(MA_USE_WIN32_FILEIO) && (defined(MA_WIN32) && (defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_NXDK)) && !defined(MA_NO_WIN32_FILEIO) && !defined(MA_POSIX)) #define MA_USE_WIN32_FILEIO #endif @@ -59592,25 +61099,34 @@ static ma_result ma_default_vfs_open__win32(ma_vfs* pVFS, const char* pFilePath, static ma_result ma_default_vfs_open_w__win32(ma_vfs* pVFS, const wchar_t* pFilePath, ma_uint32 openMode, ma_vfs_file* pFile) { - HANDLE hFile; - DWORD dwDesiredAccess; - DWORD dwShareMode; - DWORD dwCreationDisposition; + #if !defined(MA_XBOX_NXDK) + { + HANDLE hFile; + DWORD dwDesiredAccess; + DWORD dwShareMode; + DWORD dwCreationDisposition; - (void)pVFS; + (void)pVFS; - /* Load some Win32 symbols dynamically so we can dynamically check for the existence of SetFilePointerEx. */ - ma_win32_fileio_init(); + /* Load some Win32 symbols dynamically so we can dynamically check for the existence of SetFilePointerEx. */ + ma_win32_fileio_init(); - ma_default_vfs__get_open_settings_win32(openMode, &dwDesiredAccess, &dwShareMode, &dwCreationDisposition); + ma_default_vfs__get_open_settings_win32(openMode, &dwDesiredAccess, &dwShareMode, &dwCreationDisposition); - hFile = CreateFileW(pFilePath, dwDesiredAccess, dwShareMode, NULL, dwCreationDisposition, FILE_ATTRIBUTE_NORMAL, NULL); - if (hFile == INVALID_HANDLE_VALUE) { - return ma_result_from_GetLastError(GetLastError()); + hFile = CreateFileW(pFilePath, dwDesiredAccess, dwShareMode, NULL, dwCreationDisposition, FILE_ATTRIBUTE_NORMAL, NULL); + if (hFile == INVALID_HANDLE_VALUE) { + return ma_result_from_GetLastError(GetLastError()); + } + + *pFile = hFile; + return MA_SUCCESS; } - - *pFile = hFile; - return MA_SUCCESS; + #else + { + /* No CreateFileW() available. */ + return MA_NOT_IMPLEMENTED; + } + #endif } static ma_result ma_default_vfs_close__win32(ma_vfs* pVFS, ma_vfs_file file) @@ -59781,19 +61297,28 @@ static ma_result ma_default_vfs_tell__win32(ma_vfs* pVFS, ma_vfs_file file, ma_i static ma_result ma_default_vfs_info__win32(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo) { - BY_HANDLE_FILE_INFORMATION fi; - BOOL result; - (void)pVFS; - result = GetFileInformationByHandle((HANDLE)file, &fi); - if (result == 0) { - return ma_result_from_GetLastError(GetLastError()); + #if !defined(MA_XBOX_NXDK) + { + BY_HANDLE_FILE_INFORMATION fi; + BOOL result; + + result = GetFileInformationByHandle((HANDLE)file, &fi); + if (result == 0) { + return ma_result_from_GetLastError(GetLastError()); + } + + pInfo->sizeInBytes = ((ma_uint64)fi.nFileSizeHigh << 32) | ((ma_uint64)fi.nFileSizeLow); + + return MA_SUCCESS; } - - pInfo->sizeInBytes = ((ma_uint64)fi.nFileSizeHigh << 32) | ((ma_uint64)fi.nFileSizeLow); - - return MA_SUCCESS; + #else + { + /* GetFileInformationByHandle() is unavailable. */ + return MA_NOT_IMPLEMENTED; + } + #endif } #else static ma_result ma_default_vfs_open__stdio(ma_vfs* pVFS, const char* pFilePath, ma_uint32 openMode, ma_vfs_file* pFile) @@ -60131,6 +61656,8 @@ static ma_result ma_default_vfs_tell(ma_vfs* pVFS, ma_vfs_file file, ma_int64* p static ma_result ma_default_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo) { + ma_result result; + if (pInfo == NULL) { return MA_INVALID_ARGS; } @@ -60142,10 +61669,43 @@ static ma_result ma_default_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_inf } #if defined(MA_USE_WIN32_FILEIO) - return ma_default_vfs_info__win32(pVFS, file, pInfo); + result = ma_default_vfs_info__win32(pVFS, file, pInfo); #else - return ma_default_vfs_info__stdio(pVFS, file, pInfo); + result = ma_default_vfs_info__stdio(pVFS, file, pInfo); #endif + + if (result == MA_NOT_IMPLEMENTED) { + /* Not implemented. Fall back to seek/tell/seek. */ + ma_result result; + ma_int64 cursor; + ma_int64 sizeInBytes; + + result = ma_default_vfs_tell(pVFS, file, &cursor); + if (result != MA_SUCCESS) { + return result; + } + + result = ma_default_vfs_seek(pVFS, file, 0, ma_seek_origin_end); + if (result != MA_SUCCESS) { + return result; + } + + result = ma_default_vfs_tell(pVFS, file, &sizeInBytes); + if (result != MA_SUCCESS) { + return result; + } + + pInfo->sizeInBytes = sizeInBytes; + + result = ma_default_vfs_seek(pVFS, file, cursor, ma_seek_origin_start); + if (result != MA_SUCCESS) { + return result; + } + + MA_ASSERT(result == MA_SUCCESS); + } + + return result; } @@ -60333,8 +61893,8 @@ extern "C" { #define MA_DR_WAV_STRINGIFY(x) #x #define MA_DR_WAV_XSTRINGIFY(x) MA_DR_WAV_STRINGIFY(x) #define MA_DR_WAV_VERSION_MAJOR 0 -#define MA_DR_WAV_VERSION_MINOR 13 -#define MA_DR_WAV_VERSION_REVISION 18 +#define MA_DR_WAV_VERSION_MINOR 14 +#define MA_DR_WAV_VERSION_REVISION 1 #define MA_DR_WAV_VERSION_STRING MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MAJOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MINOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_REVISION) #include #define MA_DR_WAVE_FORMAT_PCM 0x1 @@ -60350,8 +61910,9 @@ MA_API void ma_dr_wav_version(ma_uint32* pMajor, ma_uint32* pMinor, ma_uint32* p MA_API const char* ma_dr_wav_version_string(void); typedef enum { - ma_dr_wav_seek_origin_start, - ma_dr_wav_seek_origin_current + MA_DR_WAV_SEEK_SET, + MA_DR_WAV_SEEK_CUR, + MA_DR_WAV_SEEK_END } ma_dr_wav_seek_origin; typedef enum { @@ -60388,6 +61949,7 @@ MA_API ma_uint16 ma_dr_wav_fmt_get_format(const ma_dr_wav_fmt* pFMT); typedef size_t (* ma_dr_wav_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef size_t (* ma_dr_wav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite); typedef ma_bool32 (* ma_dr_wav_seek_proc)(void* pUserData, int offset, ma_dr_wav_seek_origin origin); +typedef ma_bool32 (* ma_dr_wav_tell_proc)(void* pUserData, ma_int64* pCursor); typedef ma_uint64 (* ma_dr_wav_chunk_proc)(void* pChunkUserData, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pReadSeekUserData, const ma_dr_wav_chunk_header* pChunkHeader, ma_dr_wav_container container, const ma_dr_wav_fmt* pFMT); typedef struct { @@ -60432,6 +61994,11 @@ typedef enum ma_dr_wav_metadata_type_list_info_genre = 1 << 15, ma_dr_wav_metadata_type_list_info_album = 1 << 16, ma_dr_wav_metadata_type_list_info_tracknumber = 1 << 17, + ma_dr_wav_metadata_type_list_info_location = 1 << 18, + ma_dr_wav_metadata_type_list_info_organization = 1 << 19, + ma_dr_wav_metadata_type_list_info_keywords = 1 << 20, + ma_dr_wav_metadata_type_list_info_medium = 1 << 21, + ma_dr_wav_metadata_type_list_info_description = 1 << 22, ma_dr_wav_metadata_type_list_all_info_strings = ma_dr_wav_metadata_type_list_info_software | ma_dr_wav_metadata_type_list_info_copyright | ma_dr_wav_metadata_type_list_info_title @@ -60440,7 +62007,12 @@ typedef enum | ma_dr_wav_metadata_type_list_info_date | ma_dr_wav_metadata_type_list_info_genre | ma_dr_wav_metadata_type_list_info_album - | ma_dr_wav_metadata_type_list_info_tracknumber, + | ma_dr_wav_metadata_type_list_info_tracknumber + | ma_dr_wav_metadata_type_list_info_location + | ma_dr_wav_metadata_type_list_info_organization + | ma_dr_wav_metadata_type_list_info_keywords + | ma_dr_wav_metadata_type_list_info_medium + | ma_dr_wav_metadata_type_list_info_description, ma_dr_wav_metadata_type_list_all_adtl = ma_dr_wav_metadata_type_list_label | ma_dr_wav_metadata_type_list_note | ma_dr_wav_metadata_type_list_labelled_cue_region, @@ -60457,8 +62029,8 @@ typedef struct { ma_uint32 cuePointId; ma_uint32 type; - ma_uint32 firstSampleByteOffset; - ma_uint32 lastSampleByteOffset; + ma_uint32 firstSampleOffset; + ma_uint32 lastSampleOffset; ma_uint32 sampleFraction; ma_uint32 playCount; } ma_dr_wav_smpl_loop; @@ -60493,7 +62065,7 @@ typedef struct ma_uint8 dataChunkId[4]; ma_uint32 chunkStart; ma_uint32 blockStart; - ma_uint32 sampleByteOffset; + ma_uint32 sampleOffset; } ma_dr_wav_cue_point; typedef struct { @@ -60595,6 +62167,7 @@ typedef struct ma_dr_wav_read_proc onRead; ma_dr_wav_write_proc onWrite; ma_dr_wav_seek_proc onSeek; + ma_dr_wav_tell_proc onTell; void* pUserData; ma_allocation_callbacks allocationCallbacks; ma_dr_wav_container container; @@ -60637,9 +62210,9 @@ typedef struct ma_bool8 isUnsigned; } aiff; } ma_dr_wav; -MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, ma_dr_wav_chunk_proc onChunk, void* pReadSeekTellUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_dr_wav_write_proc onWrite, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write_sequential(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_uint64 totalSampleCount, ma_dr_wav_write_proc onWrite, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write_sequential_pcm_frames(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_uint64 totalPCMFrameCount, ma_dr_wav_write_proc onWrite, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); @@ -60711,9 +62284,9 @@ MA_API ma_bool32 ma_dr_wav_init_memory_write(ma_dr_wav* pWav, void** ppData, siz MA_API ma_bool32 ma_dr_wav_init_memory_write_sequential(ma_dr_wav* pWav, void** ppData, size_t* pDataSize, const ma_dr_wav_data_format* pFormat, ma_uint64 totalSampleCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_memory_write_sequential_pcm_frames(ma_dr_wav* pWav, void** ppData, size_t* pDataSize, const ma_dr_wav_data_format* pFormat, ma_uint64 totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_WAV_NO_CONVERSION_API -MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_WAV_NO_STDIO MA_API ma_int16* ma_dr_wav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); MA_API float* ma_dr_wav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); @@ -60753,8 +62326,8 @@ extern "C" { #define MA_DR_FLAC_STRINGIFY(x) #x #define MA_DR_FLAC_XSTRINGIFY(x) MA_DR_FLAC_STRINGIFY(x) #define MA_DR_FLAC_VERSION_MAJOR 0 -#define MA_DR_FLAC_VERSION_MINOR 12 -#define MA_DR_FLAC_VERSION_REVISION 43 +#define MA_DR_FLAC_VERSION_MINOR 13 +#define MA_DR_FLAC_VERSION_REVISION 2 #define MA_DR_FLAC_VERSION_STRING MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MAJOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MINOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_REVISION) #include #if defined(_MSC_VER) && _MSC_VER >= 1700 @@ -60817,8 +62390,9 @@ typedef enum } ma_dr_flac_container; typedef enum { - ma_dr_flac_seek_origin_start, - ma_dr_flac_seek_origin_current + MA_DR_FLAC_SEEK_SET, + MA_DR_FLAC_SEEK_CUR, + MA_DR_FLAC_SEEK_END } ma_dr_flac_seek_origin; typedef struct { @@ -60841,8 +62415,9 @@ typedef struct typedef struct { ma_uint32 type; - const void* pRawData; ma_uint32 rawDataSize; + ma_uint64 rawDataOffset; + const void* pRawData; union { ma_dr_flac_streaminfo streaminfo; @@ -60888,12 +62463,14 @@ typedef struct ma_uint32 colorDepth; ma_uint32 indexColorCount; ma_uint32 pictureDataSize; + ma_uint64 pictureDataOffset; const ma_uint8* pPictureData; } picture; } data; } ma_dr_flac_metadata; typedef size_t (* ma_dr_flac_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef ma_bool32 (* ma_dr_flac_seek_proc)(void* pUserData, int offset, ma_dr_flac_seek_origin origin); +typedef ma_bool32 (* ma_dr_flac_tell_proc)(void* pUserData, ma_int64* pCursor); typedef void (* ma_dr_flac_meta_proc)(void* pUserData, ma_dr_flac_metadata* pMetadata); typedef struct { @@ -60905,6 +62482,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; void* pUserData; size_t unalignedByteCount; ma_dr_flac_cache_t unalignedCache; @@ -60964,10 +62542,10 @@ typedef struct ma_dr_flac_bs bs; ma_uint8 pExtraData[1]; } ma_dr_flac; -MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API void ma_dr_flac_close(ma_dr_flac* pFlac); MA_API ma_uint64 ma_dr_flac_read_pcm_frames_s32(ma_dr_flac* pFlac, ma_uint64 framesToRead, ma_int32* pBufferOut); MA_API ma_uint64 ma_dr_flac_read_pcm_frames_s16(ma_dr_flac* pFlac, ma_uint64 framesToRead, ma_int16* pBufferOut); @@ -60981,9 +62559,9 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata_w(const wchar_t* pFileName #endif MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_t dataSize, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_FLAC_NO_STDIO MA_API ma_int32* ma_dr_flac_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_int16* ma_dr_flac_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); @@ -61031,6 +62609,12 @@ MA_API ma_bool32 ma_dr_flac_next_cuesheet_track(ma_dr_flac_cuesheet_track_iterat #endif /* MA_NO_FLAC */ #if !defined(MA_NO_MP3) && !defined(MA_NO_DECODING) +#ifndef MA_DR_MP3_NO_SIMD + #if (defined(MA_NO_NEON) && defined(MA_ARM)) || (defined(MA_NO_SSE2) && (defined(MA_X86) || defined(MA_X64))) + #define MA_DR_MP3_NO_SIMD + #endif +#endif + /* dr_mp3_h begin */ #ifndef ma_dr_mp3_h #define ma_dr_mp3_h @@ -61040,31 +62624,57 @@ extern "C" { #define MA_DR_MP3_STRINGIFY(x) #x #define MA_DR_MP3_XSTRINGIFY(x) MA_DR_MP3_STRINGIFY(x) #define MA_DR_MP3_VERSION_MAJOR 0 -#define MA_DR_MP3_VERSION_MINOR 6 -#define MA_DR_MP3_VERSION_REVISION 40 +#define MA_DR_MP3_VERSION_MINOR 7 +#define MA_DR_MP3_VERSION_REVISION 2 #define MA_DR_MP3_VERSION_STRING MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MAJOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MINOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_REVISION) #include #define MA_DR_MP3_MAX_PCM_FRAMES_PER_MP3_FRAME 1152 #define MA_DR_MP3_MAX_SAMPLES_PER_FRAME (MA_DR_MP3_MAX_PCM_FRAMES_PER_MP3_FRAME*2) MA_API void ma_dr_mp3_version(ma_uint32* pMajor, ma_uint32* pMinor, ma_uint32* pRevision); MA_API const char* ma_dr_mp3_version_string(void); +#define MA_DR_MP3_MAX_BITRESERVOIR_BYTES 511 +#define MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE 2304 +#define MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE typedef struct { - int frame_bytes, channels, hz, layer, bitrate_kbps; + int frame_bytes, channels, sample_rate, layer, bitrate_kbps; } ma_dr_mp3dec_frame_info; typedef struct +{ + const ma_uint8 *buf; + int pos, limit; +} ma_dr_mp3_bs; +typedef struct +{ + const ma_uint8 *sfbtab; + ma_uint16 part_23_length, big_values, scalefac_compress; + ma_uint8 global_gain, block_type, mixed_block_flag, n_long_sfb, n_short_sfb; + ma_uint8 table_select[3], region_count[3], subblock_gain[3]; + ma_uint8 preflag, scalefac_scale, count1_table, scfsi; +} ma_dr_mp3_L3_gr_info; +typedef struct +{ + ma_dr_mp3_bs bs; + ma_uint8 maindata[MA_DR_MP3_MAX_BITRESERVOIR_BYTES + MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES]; + ma_dr_mp3_L3_gr_info gr_info[4]; + float grbuf[2][576], scf[40], syn[18 + 15][2*32]; + ma_uint8 ist_pos[2][39]; +} ma_dr_mp3dec_scratch; +typedef struct { float mdct_overlap[2][9*32], qmf_state[15*2*32]; int reserv, free_format_bytes; ma_uint8 header[4], reserv_buf[511]; + ma_dr_mp3dec_scratch scratch; } ma_dr_mp3dec; MA_API void ma_dr_mp3dec_init(ma_dr_mp3dec *dec); MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int mp3_bytes, void *pcm, ma_dr_mp3dec_frame_info *info); MA_API void ma_dr_mp3dec_f32_to_s16(const float *in, ma_int16 *out, size_t num_samples); typedef enum { - ma_dr_mp3_seek_origin_start, - ma_dr_mp3_seek_origin_current + MA_DR_MP3_SEEK_SET, + MA_DR_MP3_SEEK_CUR, + MA_DR_MP3_SEEK_END } ma_dr_mp3_seek_origin; typedef struct { @@ -61073,8 +62683,24 @@ typedef struct ma_uint16 mp3FramesToDiscard; ma_uint16 pcmFramesToDiscard; } ma_dr_mp3_seek_point; +typedef enum +{ + MA_DR_MP3_METADATA_TYPE_ID3V1, + MA_DR_MP3_METADATA_TYPE_ID3V2, + MA_DR_MP3_METADATA_TYPE_APE, + MA_DR_MP3_METADATA_TYPE_XING, + MA_DR_MP3_METADATA_TYPE_VBRI +} ma_dr_mp3_metadata_type; +typedef struct +{ + ma_dr_mp3_metadata_type type; + const void* pRawData; + size_t rawDataSize; +} ma_dr_mp3_metadata; typedef size_t (* ma_dr_mp3_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef ma_bool32 (* ma_dr_mp3_seek_proc)(void* pUserData, int offset, ma_dr_mp3_seek_origin origin); +typedef ma_bool32 (* ma_dr_mp3_tell_proc)(void* pUserData, ma_int64* pCursor); +typedef void (* ma_dr_mp3_meta_proc)(void* pUserData, const ma_dr_mp3_metadata* pMetadata); typedef struct { ma_uint32 channels; @@ -61087,7 +62713,9 @@ typedef struct ma_uint32 sampleRate; ma_dr_mp3_read_proc onRead; ma_dr_mp3_seek_proc onSeek; + ma_dr_mp3_meta_proc onMeta; void* pUserData; + void* pUserDataMeta; ma_allocation_callbacks allocationCallbacks; ma_uint32 mp3FrameChannels; ma_uint32 mp3FrameSampleRate; @@ -61096,13 +62724,20 @@ typedef struct ma_uint8 pcmFrames[sizeof(float)*MA_DR_MP3_MAX_SAMPLES_PER_FRAME]; ma_uint64 currentPCMFrame; ma_uint64 streamCursor; + ma_uint64 streamLength; + ma_uint64 streamStartOffset; ma_dr_mp3_seek_point* pSeekPoints; ma_uint32 seekPointCount; + ma_uint32 delayInPCMFrames; + ma_uint32 paddingInPCMFrames; + ma_uint64 totalPCMFrameCount; + ma_bool32 isVBR; + ma_bool32 isCBR; size_t dataSize; size_t dataCapacity; size_t dataConsumed; ma_uint8* pData; - ma_bool32 atEnd : 1; + ma_bool32 atEnd; struct { const ma_uint8* pData; @@ -61110,9 +62745,12 @@ typedef struct size_t currentReadPos; } memory; } ma_dr_mp3; -MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init_memory_with_metadata(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_MP3_NO_STDIO +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata(ma_dr_mp3* pMP3, const char* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks); #endif @@ -61125,8 +62763,8 @@ MA_API ma_uint64 ma_dr_mp3_get_mp3_frame_count(ma_dr_mp3* pMP3); MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint64* pMP3FrameCount, ma_uint64* pPCMFrameCount); MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSeekPointCount, ma_dr_mp3_seek_point* pSeekPoints); MA_API ma_bool32 ma_dr_mp3_bind_seek_table(ma_dr_mp3* pMP3, ma_uint32 seekPointCount, ma_dr_mp3_seek_point* pSeekPoints); -MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API float* ma_dr_mp3_open_memory_and_read_pcm_frames_f32(const void* pData, size_t dataSize, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_int16* ma_dr_mp3_open_memory_and_read_pcm_frames_s16(const void* pData, size_t dataSize, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_MP3_NO_STDIO @@ -61679,8 +63317,10 @@ static ma_bool32 ma_wav_dr_callback__seek(void* pUserData, int offset, ma_dr_wav MA_ASSERT(pWav != NULL); maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_wav_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_WAV_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_WAV_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; } result = pWav->onSeek(pWav->pReadSeekTellUserData, offset, maSeekOrigin); @@ -61690,6 +63330,26 @@ static ma_bool32 ma_wav_dr_callback__seek(void* pUserData, int offset, ma_dr_wav return MA_TRUE; } + +static ma_bool32 ma_wav_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_wav* pWav = (ma_wav*)pUserData; + ma_result result; + + MA_ASSERT(pWav != NULL); + MA_ASSERT(pCursor != NULL); + + if (pWav->onTell == NULL) { + return MA_FALSE; /* Not implemented. */ + } + + result = pWav->onTell(pWav->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; /* Failed to tell. */ + } + + return MA_TRUE; +} #endif static ma_result ma_wav_init_internal(const ma_decoding_backend_config* pConfig, ma_wav* pWav) @@ -61784,7 +63444,7 @@ MA_API ma_result ma_wav_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_p { ma_bool32 wavResult; - wavResult = ma_dr_wav_init(&pWav->dr, ma_wav_dr_callback__read, ma_wav_dr_callback__seek, pWav, pAllocationCallbacks); + wavResult = ma_dr_wav_init(&pWav->dr, ma_wav_dr_callback__read, ma_wav_dr_callback__seek, ma_wav_dr_callback__tell, pWav, pAllocationCallbacks); if (wavResult != MA_TRUE) { return MA_INVALID_FILE; } @@ -62363,8 +64023,10 @@ static ma_bool32 ma_flac_dr_callback__seek(void* pUserData, int offset, ma_dr_fl MA_ASSERT(pFlac != NULL); maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_flac_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_FLAC_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_FLAC_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; } result = pFlac->onSeek(pFlac->pReadSeekTellUserData, offset, maSeekOrigin); @@ -62374,6 +64036,26 @@ static ma_bool32 ma_flac_dr_callback__seek(void* pUserData, int offset, ma_dr_fl return MA_TRUE; } + +static ma_bool32 ma_flac_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_flac* pFlac = (ma_flac*)pUserData; + ma_result result; + + MA_ASSERT(pFlac != NULL); + MA_ASSERT(pCursor != NULL); + + if (pFlac->onTell == NULL) { + return MA_FALSE; /* Not implemented. */ + } + + result = pFlac->onTell(pFlac->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; /* Failed to tell. */ + } + + return MA_TRUE; +} #endif static ma_result ma_flac_init_internal(const ma_decoding_backend_config* pConfig, ma_flac* pFlac) @@ -62425,7 +64107,7 @@ MA_API ma_result ma_flac_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_ #if !defined(MA_NO_FLAC) { - pFlac->dr = ma_dr_flac_open(ma_flac_dr_callback__read, ma_flac_dr_callback__seek, pFlac, pAllocationCallbacks); + pFlac->dr = ma_dr_flac_open(ma_flac_dr_callback__read, ma_flac_dr_callback__seek, ma_flac_dr_callback__tell, pFlac, pAllocationCallbacks); if (pFlac->dr == NULL) { return MA_INVALID_FILE; } @@ -62986,9 +64668,12 @@ static ma_bool32 ma_mp3_dr_callback__seek(void* pUserData, int offset, ma_dr_mp3 MA_ASSERT(pMP3 != NULL); - maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_mp3_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_MP3_SEEK_SET) { + maSeekOrigin = ma_seek_origin_start; + } else if (origin == MA_DR_MP3_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; + } else { + maSeekOrigin = ma_seek_origin_current; } result = pMP3->onSeek(pMP3->pReadSeekTellUserData, offset, maSeekOrigin); @@ -62998,6 +64683,21 @@ static ma_bool32 ma_mp3_dr_callback__seek(void* pUserData, int offset, ma_dr_mp3 return MA_TRUE; } + +static ma_bool32 ma_mp3_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_mp3* pMP3 = (ma_mp3*)pUserData; + ma_result result; + + MA_ASSERT(pMP3 != NULL); + + result = pMP3->onTell(pMP3->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; + } + + return MA_TRUE; +} #endif static ma_result ma_mp3_init_internal(const ma_decoding_backend_config* pConfig, ma_mp3* pMP3) @@ -63098,7 +64798,7 @@ MA_API ma_result ma_mp3_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_p { ma_bool32 mp3Result; - mp3Result = ma_dr_mp3_init(&pMP3->dr, ma_mp3_dr_callback__read, ma_mp3_dr_callback__seek, pMP3, pAllocationCallbacks); + mp3Result = ma_dr_mp3_init(&pMP3->dr, ma_mp3_dr_callback__read, ma_mp3_dr_callback__seek, ma_mp3_dr_callback__tell, NULL, pMP3, pAllocationCallbacks); if (mp3Result != MA_TRUE) { return MA_INVALID_FILE; } @@ -64997,14 +66697,16 @@ static ma_bool32 ma_path_extension_equal_w(const wchar_t* path, const wchar_t* e ext1 = extension; ext2 = ma_path_extension_w(path); -#if defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__) - return _wcsicmp(ext1, ext2) == 0; -#else - /* - I'm not aware of a wide character version of strcasecmp(). I'm therefore converting the extensions to multibyte strings and comparing those. This - isn't the most efficient way to do it, but it should work OK. - */ + #if (defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__)) && !defined(MA_XBOX_NXDK) { + return _wcsicmp(ext1, ext2) == 0; + } + #elif !defined(MA_XBOX_NXDK) && !defined(MA_DOS) + { + /* + I'm not aware of a wide character version of strcasecmp(). I'm therefore converting the extensions to multibyte strings and comparing those. This + isn't the most efficient way to do it, but it should work OK. + */ char ext1MB[4096]; char ext2MB[4096]; const wchar_t* pext1 = ext1; @@ -65024,7 +66726,13 @@ static ma_bool32 ma_path_extension_equal_w(const wchar_t* path, const wchar_t* e return strcasecmp(ext1MB, ext2MB) == 0; } -#endif + #else + { + /* Getting here means we don't have a way to do a case-sensitive comparison for wide strings. Fall back to a simple case-sensitive comparison. */ + /* TODO: Implement our own wchar_t-to-char conversion routine and then use the char* version for comparing. */ + return ma_wcscmp(ext1, ext2) == 0; + } + #endif } #endif /* MA_HAS_PATH_API */ @@ -66119,10 +67827,18 @@ static ma_bool32 ma_encoder__internal_on_seek_wav(void* pUserData, int offset, m { ma_encoder* pEncoder = (ma_encoder*)pUserData; ma_result result; + ma_seek_origin maSeekOrigin; MA_ASSERT(pEncoder != NULL); - result = pEncoder->onSeek(pEncoder, offset, (origin == ma_dr_wav_seek_origin_start) ? ma_seek_origin_start : ma_seek_origin_current); + maSeekOrigin = ma_seek_origin_start; + if (origin == MA_DR_WAV_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_WAV_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; + } + + result = pEncoder->onSeek(pEncoder, offset, maSeekOrigin); if (result != MA_SUCCESS) { return MA_FALSE; } else { @@ -67644,7 +69360,7 @@ static MA_INLINE ma_uint32 ma_hash_getblock(const ma_uint32* blocks, int i) ma_uint32 block; /* Try silencing a sanitization warning about unaligned access by doing a memcpy() instead of assignment. */ - MA_COPY_MEMORY(&block, ma_offset_ptr(blocks, i * sizeof(block)), sizeof(block)); + MA_COPY_MEMORY(&block, ma_offset_ptr(blocks, i * (int) sizeof(block)), sizeof(block)); if (ma_is_little_endian()) { return block; @@ -67720,7 +69436,7 @@ static ma_uint32 ma_hash_string_32(const char* str) static ma_uint32 ma_hash_string_w_32(const wchar_t* str) { - return ma_hash_32(str, (int)wcslen(str) * sizeof(*str), MA_DEFAULT_HASH_SEED); + return ma_hash_32(str, (int)ma_wcslen(str) * sizeof(*str), MA_DEFAULT_HASH_SEED); } @@ -67880,6 +69596,7 @@ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_ return ma_resource_manager_data_buffer_node_find_min(pDataBufferNode->pChildHi); } +#if 0 /* Currently unused, but might make use of this later. */ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_buffer_node_find_inorder_predecessor(ma_resource_manager_data_buffer_node* pDataBufferNode) { MA_ASSERT(pDataBufferNode != NULL); @@ -67887,6 +69604,7 @@ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_ return ma_resource_manager_data_buffer_node_find_max(pDataBufferNode->pChildLo); } +#endif static ma_result ma_resource_manager_data_buffer_node_remove(ma_resource_manager* pResourceManager, ma_resource_manager_data_buffer_node* pDataBufferNode) { @@ -69009,16 +70727,19 @@ static ma_result ma_resource_manager_data_buffer_node_acquire_critical_section(m /* Failed to post job. Probably ran out of memory. */ ma_log_postf(ma_resource_manager_get_log(pResourceManager), MA_LOG_LEVEL_ERROR, "Failed to post MA_JOB_TYPE_RESOURCE_MANAGER_LOAD_DATA_BUFFER_NODE job. %s.\n", ma_result_description(result)); - /* - Fences were acquired before posting the job, but since the job was not able to - be posted, we need to make sure we release them so nothing gets stuck waiting. - */ - if (pInitFence != NULL) { ma_fence_release(pInitFence); } - if (pDoneFence != NULL) { ma_fence_release(pDoneFence); } - if ((flags & MA_RESOURCE_MANAGER_DATA_SOURCE_FLAG_WAIT_INIT) != 0) { ma_resource_manager_inline_notification_uninit(pInitNotification); } else { + /* + Fences were acquired before posting the job, but since the job was not able to + be posted, we need to make sure we release them so nothing gets stuck waiting. + + In the WAIT_INIT case, these will have already been released in ma_job_process() + so we should only release fences in this branch. + */ + if (pInitFence != NULL) { ma_fence_release(pInitFence); } + if (pDoneFence != NULL) { ma_fence_release(pDoneFence); } + /* These will have been freed by the job thread, but with WAIT_INIT they will already have happened since the job has already been handled. */ ma_free(pFilePathCopy, &pResourceManager->config.allocationCallbacks); ma_free(pFilePathWCopy, &pResourceManager->config.allocationCallbacks); @@ -76674,7 +78395,7 @@ static ma_result ma_sound_init_from_data_source_internal(ma_engine* pEngine, con } if (pConfig->loopPointBegInPCMFrames != 0 || pConfig->loopPointEndInPCMFrames != ~((ma_uint64)0)) { - ma_data_source_set_range_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->loopPointBegInPCMFrames, pConfig->loopPointEndInPCMFrames); + ma_data_source_set_loop_point_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->loopPointBegInPCMFrames, pConfig->loopPointEndInPCMFrames); } ma_sound_set_looping(pSound, pConfig->isLooping || ((pConfig->flags & MA_SOUND_FLAG_LOOPING) != 0)); @@ -76736,6 +78457,7 @@ MA_API ma_result ma_sound_init_from_file_internal(ma_engine* pEngine, const ma_s result = ma_resource_manager_data_source_init_ex(pEngine->pResourceManager, &resourceManagerDataSourceConfig, pSound->pResourceManagerDataSource); if (result != MA_SUCCESS) { + ma_free(pSound->pResourceManagerDataSource, &pEngine->allocationCallbacks); goto done; } @@ -77541,7 +79263,12 @@ MA_API ma_uint64 ma_sound_get_time_in_pcm_frames(const ma_sound* pSound) MA_API ma_uint64 ma_sound_get_time_in_milliseconds(const ma_sound* pSound) { - return ma_sound_get_time_in_pcm_frames(pSound) * 1000 / ma_engine_get_sample_rate(ma_sound_get_engine(pSound)); + ma_uint32 sampleRate = ma_engine_get_sample_rate(ma_sound_get_engine(pSound)); + if (sampleRate == 0) { + return 0; /* Prevent a division by zero. */ + } + + return ma_sound_get_time_in_pcm_frames(pSound) * 1000 / sampleRate; } MA_API void ma_sound_set_looping(ma_sound* pSound, ma_bool32 isLooping) @@ -77625,7 +79352,7 @@ MA_API ma_result ma_sound_seek_to_second(ma_sound* pSound, float seekPointInSeco return ma_sound_seek_to_pcm_frame(pSound, frameIndex); } -MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap) +MA_API ma_result ma_sound_get_data_format(const ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -77658,7 +79385,7 @@ MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, } } -MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* pCursor) +MA_API ma_result ma_sound_get_cursor_in_pcm_frames(const ma_sound* pSound, ma_uint64* pCursor) { ma_uint64 seekTarget; @@ -77680,7 +79407,7 @@ MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* } } -MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* pLength) +MA_API ma_result ma_sound_get_length_in_pcm_frames(const ma_sound* pSound, ma_uint64* pLength) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -77694,7 +79421,7 @@ MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* return ma_data_source_get_length_in_pcm_frames(pSound->pDataSource, pLength); } -MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor) +MA_API ma_result ma_sound_get_cursor_in_seconds(const ma_sound* pSound, float* pCursor) { ma_result result; ma_uint64 cursorInPCMFrames; @@ -77720,7 +79447,7 @@ MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor return MA_SUCCESS; } -MA_API ma_result ma_sound_get_length_in_seconds(ma_sound* pSound, float* pLength) +MA_API ma_result ma_sound_get_length_in_seconds(const ma_sound* pSound, float* pLength) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -78539,12 +80266,12 @@ MA_PRIVATE ma_bool32 ma_dr_wav__seek_forward(ma_dr_wav_seek_proc onSeek, ma_uint ma_uint64 bytesRemainingToSeek = offset; while (bytesRemainingToSeek > 0) { if (bytesRemainingToSeek > 0x7FFFFFFF) { - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } bytesRemainingToSeek -= 0x7FFFFFFF; } else { - if (!onSeek(pUserData, (int)bytesRemainingToSeek, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, (int)bytesRemainingToSeek, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } bytesRemainingToSeek = 0; @@ -78555,17 +80282,17 @@ MA_PRIVATE ma_bool32 ma_dr_wav__seek_forward(ma_dr_wav_seek_proc onSeek, ma_uint MA_PRIVATE ma_bool32 ma_dr_wav__seek_from_start(ma_dr_wav_seek_proc onSeek, ma_uint64 offset, void* pUserData) { if (offset <= 0x7FFFFFFF) { - return onSeek(pUserData, (int)offset, ma_dr_wav_seek_origin_start); + return onSeek(pUserData, (int)offset, MA_DR_WAV_SEEK_SET); } - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_start)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_SET)) { return MA_FALSE; } offset -= 0x7FFFFFFF; for (;;) { if (offset <= 0x7FFFFFFF) { - return onSeek(pUserData, (int)offset, ma_dr_wav_seek_origin_current); + return onSeek(pUserData, (int)offset, MA_DR_WAV_SEEK_CUR); } - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } offset -= 0x7FFFFFFF; @@ -78588,7 +80315,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav__on_seek(ma_dr_wav_seek_proc onSeek, void* pUserD if (!onSeek(pUserData, offset, origin)) { return MA_FALSE; } - if (origin == ma_dr_wav_seek_origin_start) { + if (origin == MA_DR_WAV_SEEK_SET) { *pCursor = offset; } else { *pCursor += offset; @@ -78707,12 +80434,12 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_smpl_to_metadata_obj(ma_dr_wav__metadata_pa ma_uint8 smplLoopData[MA_DR_WAV_SMPL_LOOP_BYTES]; bytesJustRead = ma_dr_wav__metadata_parser_read(pParser, smplLoopData, sizeof(smplLoopData), &totalBytesRead); if (bytesJustRead == sizeof(smplLoopData)) { - pMetadata->data.smpl.pLoops[iSampleLoop].cuePointId = ma_dr_wav_bytes_to_u32(smplLoopData + 0); - pMetadata->data.smpl.pLoops[iSampleLoop].type = ma_dr_wav_bytes_to_u32(smplLoopData + 4); - pMetadata->data.smpl.pLoops[iSampleLoop].firstSampleByteOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 8); - pMetadata->data.smpl.pLoops[iSampleLoop].lastSampleByteOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 12); - pMetadata->data.smpl.pLoops[iSampleLoop].sampleFraction = ma_dr_wav_bytes_to_u32(smplLoopData + 16); - pMetadata->data.smpl.pLoops[iSampleLoop].playCount = ma_dr_wav_bytes_to_u32(smplLoopData + 20); + pMetadata->data.smpl.pLoops[iSampleLoop].cuePointId = ma_dr_wav_bytes_to_u32(smplLoopData + 0); + pMetadata->data.smpl.pLoops[iSampleLoop].type = ma_dr_wav_bytes_to_u32(smplLoopData + 4); + pMetadata->data.smpl.pLoops[iSampleLoop].firstSampleOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 8); + pMetadata->data.smpl.pLoops[iSampleLoop].lastSampleOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 12); + pMetadata->data.smpl.pLoops[iSampleLoop].sampleFraction = ma_dr_wav_bytes_to_u32(smplLoopData + 16); + pMetadata->data.smpl.pLoops[iSampleLoop].playCount = ma_dr_wav_bytes_to_u32(smplLoopData + 20); } else { break; } @@ -78756,7 +80483,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_cue_to_metadata_obj(ma_dr_wav__metadata_par pMetadata->data.cue.pCuePoints[iCuePoint].dataChunkId[3] = cuePointData[11]; pMetadata->data.cue.pCuePoints[iCuePoint].chunkStart = ma_dr_wav_bytes_to_u32(cuePointData + 12); pMetadata->data.cue.pCuePoints[iCuePoint].blockStart = ma_dr_wav_bytes_to_u32(cuePointData + 16); - pMetadata->data.cue.pCuePoints[iCuePoint].sampleByteOffset = ma_dr_wav_bytes_to_u32(cuePointData + 20); + pMetadata->data.cue.pCuePoints[iCuePoint].sampleOffset = ma_dr_wav_bytes_to_u32(cuePointData + 20); } else { break; } @@ -79096,7 +80823,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse if (pParser->stage == ma_dr_wav__metadata_parser_stage_count) { ma_uint8 buffer[4]; size_t bytesJustRead; - if (!pParser->onSeek(pParser->pReadSeekUserData, 28, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, 28, MA_DR_WAV_SEEK_CUR)) { return bytesRead; } bytesRead += 28; @@ -79191,7 +80918,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse return bytesRead; } allocSizeNeeded += ma_dr_wav__strlen(buffer) + 1; - allocSizeNeeded += (size_t)pChunkHeader->sizeInBytes - MA_DR_WAV_BEXT_BYTES; + allocSizeNeeded += (size_t)pChunkHeader->sizeInBytes - MA_DR_WAV_BEXT_BYTES + 1; ma_dr_wav__metadata_request_extra_memory_for_stage_2(pParser, allocSizeNeeded, 1); pParser->metadataCount += 1; } else { @@ -79274,6 +81001,16 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_album); } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_tracknumber, "ITRK")) { subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_tracknumber); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_location, "IARL")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_location); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_organization, "ICMS")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_organization); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_keywords, "IKEY")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_keywords); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_medium, "IMED")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_medium); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_description, "ISBJ")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_description); } else if ((allowedMetadataTypes & ma_dr_wav_metadata_type_unknown) != 0) { subchunkBytesRead = ma_dr_wav__metadata_process_unknown_chunk(pParser, subchunkId, subchunkDataSize, listType); } @@ -79281,13 +81018,13 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse MA_DR_WAV_ASSERT(subchunkBytesRead <= subchunkDataSize); if (subchunkBytesRead < subchunkDataSize) { ma_uint64 bytesToSeek = subchunkDataSize - subchunkBytesRead; - if (!pParser->onSeek(pParser->pReadSeekUserData, (int)bytesToSeek, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, (int)bytesToSeek, MA_DR_WAV_SEEK_CUR)) { break; } bytesRead += bytesToSeek; } if ((subchunkDataSize % 2) == 1) { - if (!pParser->onSeek(pParser->pReadSeekUserData, 1, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, 1, MA_DR_WAV_SEEK_CUR)) { break; } bytesRead += 1; @@ -79324,7 +81061,7 @@ MA_API ma_uint16 ma_dr_wav_fmt_get_format(const ma_dr_wav_fmt* pFMT) return ma_dr_wav_bytes_to_u16(pFMT->subFormat); } } -MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pReadSeekUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pReadSeekTellUserData, const ma_allocation_callbacks* pAllocationCallbacks) { if (pWav == NULL || onRead == NULL || onSeek == NULL) { return MA_FALSE; @@ -79332,7 +81069,8 @@ MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRe MA_DR_WAV_ZERO_MEMORY(pWav, sizeof(*pWav)); pWav->onRead = onRead; pWav->onSeek = onSeek; - pWav->pUserData = pReadSeekUserData; + pWav->onTell = onTell; + pWav->pUserData = pReadSeekTellUserData; pWav->allocationCallbacks = ma_dr_wav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { return MA_FALSE; @@ -79546,14 +81284,14 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p fmt.channelMask = ma_dr_wav_bytes_to_u32_ex(fmtext + 2, pWav->container); ma_dr_wav_bytes_to_guid(fmtext + 6, fmt.subFormat); } else { - if (pWav->onSeek(pWav->pUserData, fmt.extendedSize, ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, fmt.extendedSize, MA_DR_WAV_SEEK_CUR) == MA_FALSE) { return MA_FALSE; } } cursor += fmt.extendedSize; bytesReadSoFar += fmt.extendedSize; } - if (pWav->onSeek(pWav->pUserData, (int)(header.sizeInBytes - bytesReadSoFar), ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, (int)(header.sizeInBytes - bytesReadSoFar), MA_DR_WAV_SEEK_CUR) == MA_FALSE) { return MA_FALSE; } cursor += (header.sizeInBytes - bytesReadSoFar); @@ -79704,15 +81442,26 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p return MA_FALSE; } offset = ma_dr_wav_bytes_to_u32_ex(offsetAndBlockSizeData + 0, pWav->container); - if (ma_dr_wav__seek_forward(pWav->onSeek, offset, pWav->pUserData) == MA_FALSE) { - return MA_FALSE; - } - cursor += offset; - pWav->dataChunkDataPos = cursor; + pWav->dataChunkDataPos = cursor + offset; dataChunkSize = chunkSize; - if (sequential || !isProcessingMetadata) { - break; + if (dataChunkSize > offset) { + dataChunkSize -= offset; } else { + dataChunkSize = 0; + } + if (sequential) { + if (foundChunk_fmt) { + if (ma_dr_wav__seek_forward(pWav->onSeek, offset, pWav->pUserData) == MA_FALSE) { + return MA_FALSE; + } + cursor += offset; + break; + } else { + return MA_FALSE; + } + } else { + chunkSize += header.paddingSize; + chunkSize -= sizeof(offsetAndBlockSizeData); if (ma_dr_wav__seek_forward(pWav->onSeek, chunkSize, pWav->pUserData) == MA_FALSE) { break; } @@ -79776,6 +81525,17 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p pWav->pMetadata = metadataParser.pMetadata; pWav->metadataCount = metadataParser.metadataCount; } + if (pWav->onTell != NULL && pWav->onSeek != NULL) { + if (pWav->onSeek(pWav->pUserData, 0, MA_DR_WAV_SEEK_END) == MA_TRUE) { + ma_int64 fileSize; + if (pWav->onTell(pWav->pUserData, &fileSize)) { + if (dataChunkSize + pWav->dataChunkDataPos > (ma_uint64)fileSize) { + dataChunkSize = (ma_uint64)fileSize - pWav->dataChunkDataPos; + } + } + } else { + } + } if (dataChunkSize == 0xFFFFFFFF && (pWav->container == ma_dr_wav_container_riff || pWav->container == ma_dr_wav_container_rifx) && pWav->isSequentialWrite == MA_FALSE) { dataChunkSize = 0; for (;;) { @@ -79795,8 +81555,14 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p pWav->sampleRate = fmt.sampleRate; pWav->channels = fmt.channels; pWav->bitsPerSample = fmt.bitsPerSample; - pWav->bytesRemaining = dataChunkSize; pWav->translatedFormatTag = translatedFormatTag; + if (!ma_dr_wav__is_compressed_format_tag(translatedFormatTag)) { + ma_uint32 bytesPerFrame = ma_dr_wav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame > 0) { + dataChunkSize -= (dataChunkSize % bytesPerFrame); + } + } + pWav->bytesRemaining = dataChunkSize; pWav->dataChunkDataSize = dataChunkSize; if (sampleCountFromFactChunk != 0) { pWav->totalPCMFrameCount = sampleCountFromFactChunk; @@ -79851,20 +81617,20 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p #endif return MA_TRUE; } -MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_wav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks); + return ma_dr_wav_init_ex(pWav, onRead, onSeek, onTell, NULL, pUserData, NULL, 0, pAllocationCallbacks); } -MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, ma_dr_wav_chunk_proc onChunk, void* pReadSeekTellUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { - if (!ma_dr_wav_preinit(pWav, onRead, onSeek, pReadSeekUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, onRead, onSeek, onTell, pReadSeekTellUserData, pAllocationCallbacks)) { return MA_FALSE; } return ma_dr_wav_init__internal(pWav, onChunk, pChunkUserData, flags); } -MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { - if (!ma_dr_wav_preinit(pWav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return MA_FALSE; } return ma_dr_wav_init__internal(pWav, NULL, NULL, flags | MA_DR_WAV_WITH_METADATA); @@ -80026,8 +81792,8 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ for (iLoop = 0; iLoop < pMetadata->data.smpl.sampleLoopCount; ++iLoop) { bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].cuePointId); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].type); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].firstSampleByteOffset); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].lastSampleByteOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].firstSampleOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].lastSampleOffset); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].sampleFraction); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].playCount); } @@ -80061,7 +81827,7 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ bytesWritten += ma_dr_wav__write_or_count(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].dataChunkId, 4); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].chunkStart); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].blockStart); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].sampleByteOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].sampleOffset); } } break; case ma_dr_wav_metadata_type_acid: @@ -80147,15 +81913,20 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ if (pMetadata->type & ma_dr_wav_metadata_type_list_all_info_strings) { const char* pID = NULL; switch (pMetadata->type) { - case ma_dr_wav_metadata_type_list_info_software: pID = "ISFT"; break; - case ma_dr_wav_metadata_type_list_info_copyright: pID = "ICOP"; break; - case ma_dr_wav_metadata_type_list_info_title: pID = "INAM"; break; - case ma_dr_wav_metadata_type_list_info_artist: pID = "IART"; break; - case ma_dr_wav_metadata_type_list_info_comment: pID = "ICMT"; break; - case ma_dr_wav_metadata_type_list_info_date: pID = "ICRD"; break; - case ma_dr_wav_metadata_type_list_info_genre: pID = "IGNR"; break; - case ma_dr_wav_metadata_type_list_info_album: pID = "IPRD"; break; - case ma_dr_wav_metadata_type_list_info_tracknumber: pID = "ITRK"; break; + case ma_dr_wav_metadata_type_list_info_software: pID = "ISFT"; break; + case ma_dr_wav_metadata_type_list_info_copyright: pID = "ICOP"; break; + case ma_dr_wav_metadata_type_list_info_title: pID = "INAM"; break; + case ma_dr_wav_metadata_type_list_info_artist: pID = "IART"; break; + case ma_dr_wav_metadata_type_list_info_comment: pID = "ICMT"; break; + case ma_dr_wav_metadata_type_list_info_date: pID = "ICRD"; break; + case ma_dr_wav_metadata_type_list_info_genre: pID = "IGNR"; break; + case ma_dr_wav_metadata_type_list_info_album: pID = "IPRD"; break; + case ma_dr_wav_metadata_type_list_info_tracknumber: pID = "ITRK"; break; + case ma_dr_wav_metadata_type_list_info_location: pID = "IARL"; break; + case ma_dr_wav_metadata_type_list_info_organization: pID = "ICMS"; break; + case ma_dr_wav_metadata_type_list_info_keywords: pID = "IKEY"; break; + case ma_dr_wav_metadata_type_list_info_medium: pID = "IMED"; break; + case ma_dr_wav_metadata_type_list_info_description: pID = "ISBJ"; break; default: break; } MA_DR_WAV_ASSERT(pID != NULL); @@ -80370,7 +82141,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init_write__internal(ma_dr_wav* pWav, const ma_dr } pWav->dataChunkDataSizeTargetWrite = initialDataChunkSize; if (pFormat->container == ma_dr_wav_container_riff) { - ma_uint32 chunkSizeRIFF = 28 + (ma_uint32)initialDataChunkSize; + ma_uint32 chunkSizeRIFF = 36 + (ma_uint32)initialDataChunkSize; runningPos += ma_dr_wav__write(pWav, "RIFF", 4); runningPos += ma_dr_wav__write_u32ne_to_le(pWav, chunkSizeRIFF); runningPos += ma_dr_wav__write(pWav, "WAVE", 4); @@ -80493,7 +82264,31 @@ MA_PRIVATE size_t ma_dr_wav__on_write_stdio(void* pUserData, const void* pData, } MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_stdio(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { - return fseek((FILE*)pUserData, offset, (origin == ma_dr_wav_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_WAV_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_WAV_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; +} +MA_PRIVATE ma_bool32 ma_dr_wav__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_WAV_ASSERT(pFileStdio != NULL); + MA_DR_WAV_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; } MA_API ma_bool32 ma_dr_wav_init_file(ma_dr_wav* pWav, const char* filename, const ma_allocation_callbacks* pAllocationCallbacks) { @@ -80502,7 +82297,7 @@ MA_API ma_bool32 ma_dr_wav_init_file(ma_dr_wav* pWav, const char* filename, cons MA_PRIVATE ma_bool32 ma_dr_wav_init_file__internal_FILE(ma_dr_wav* pWav, FILE* pFile, ma_dr_wav_chunk_proc onChunk, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; - result = ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_stdio, ma_dr_wav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_stdio, ma_dr_wav__on_seek_stdio, ma_dr_wav__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; @@ -80639,25 +82434,27 @@ MA_PRIVATE size_t ma_dr_wav__on_read_memory(void* pUserData, void* pBufferOut, s MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + ma_int64 newCursor; MA_DR_WAV_ASSERT(pWav != NULL); - if (origin == ma_dr_wav_seek_origin_current) { - if (offset > 0) { - if (pWav->memoryStream.currentReadPos + offset > pWav->memoryStream.dataSize) { - return MA_FALSE; - } - } else { - if (pWav->memoryStream.currentReadPos < (size_t)-offset) { - return MA_FALSE; - } - } - pWav->memoryStream.currentReadPos += offset; + newCursor = pWav->memoryStream.currentReadPos; + if (origin == MA_DR_WAV_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_WAV_SEEK_CUR) { + newCursor = (ma_int64)pWav->memoryStream.currentReadPos; + } else if (origin == MA_DR_WAV_SEEK_END) { + newCursor = (ma_int64)pWav->memoryStream.dataSize; } else { - if ((ma_uint32)offset <= pWav->memoryStream.dataSize) { - pWav->memoryStream.currentReadPos = offset; - } else { - return MA_FALSE; - } + MA_DR_WAV_ASSERT(!"Invalid seek origin"); + return MA_FALSE; } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pWav->memoryStream.dataSize) { + return MA_FALSE; + } + pWav->memoryStream.currentReadPos = (size_t)newCursor; return MA_TRUE; } MA_PRIVATE size_t ma_dr_wav__on_write_memory(void* pUserData, const void* pDataIn, size_t bytesToWrite) @@ -80691,25 +82488,35 @@ MA_PRIVATE size_t ma_dr_wav__on_write_memory(void* pUserData, const void* pDataI MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory_write(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + ma_int64 newCursor; MA_DR_WAV_ASSERT(pWav != NULL); - if (origin == ma_dr_wav_seek_origin_current) { - if (offset > 0) { - if (pWav->memoryStreamWrite.currentWritePos + offset > pWav->memoryStreamWrite.dataSize) { - offset = (int)(pWav->memoryStreamWrite.dataSize - pWav->memoryStreamWrite.currentWritePos); - } - } else { - if (pWav->memoryStreamWrite.currentWritePos < (size_t)-offset) { - offset = -(int)pWav->memoryStreamWrite.currentWritePos; - } - } - pWav->memoryStreamWrite.currentWritePos += offset; + newCursor = pWav->memoryStreamWrite.currentWritePos; + if (origin == MA_DR_WAV_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_WAV_SEEK_CUR) { + newCursor = (ma_int64)pWav->memoryStreamWrite.currentWritePos; + } else if (origin == MA_DR_WAV_SEEK_END) { + newCursor = (ma_int64)pWav->memoryStreamWrite.dataSize; } else { - if ((ma_uint32)offset <= pWav->memoryStreamWrite.dataSize) { - pWav->memoryStreamWrite.currentWritePos = offset; - } else { - pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; - } + MA_DR_WAV_ASSERT(!"Invalid seek origin"); + return MA_INVALID_ARGS; } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pWav->memoryStreamWrite.dataSize) { + return MA_FALSE; + } + pWav->memoryStreamWrite.currentWritePos = (size_t)newCursor; + return MA_TRUE; +} +MA_PRIVATE ma_bool32 ma_dr_wav__on_tell_memory(void* pUserData, ma_int64* pCursor) +{ + ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + MA_DR_WAV_ASSERT(pWav != NULL); + MA_DR_WAV_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)pWav->memoryStream.currentReadPos; return MA_TRUE; } MA_API ma_bool32 ma_dr_wav_init_memory(ma_dr_wav* pWav, const void* data, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) @@ -80721,7 +82528,7 @@ MA_API ma_bool32 ma_dr_wav_init_memory_ex(ma_dr_wav* pWav, const void* data, siz if (data == NULL || dataSize == 0) { return MA_FALSE; } - if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, pWav, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, ma_dr_wav__on_tell_memory, pWav, pAllocationCallbacks)) { return MA_FALSE; } pWav->memoryStream.data = (const ma_uint8*)data; @@ -80734,7 +82541,7 @@ MA_API ma_bool32 ma_dr_wav_init_memory_with_metadata(ma_dr_wav* pWav, const void if (data == NULL || dataSize == 0) { return MA_FALSE; } - if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, pWav, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, ma_dr_wav__on_tell_memory, pWav, pAllocationCallbacks)) { return MA_FALSE; } pWav->memoryStream.data = (const ma_uint8*)data; @@ -80793,30 +82600,30 @@ MA_API ma_result ma_dr_wav_uninit(ma_dr_wav* pWav) } if (pWav->onSeek && !pWav->isSequentialWrite) { if (pWav->container == ma_dr_wav_container_riff) { - if (pWav->onSeek(pWav->pUserData, 4, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, 4, MA_DR_WAV_SEEK_SET)) { ma_uint32 riffChunkSize = ma_dr_wav__riff_chunk_size_riff(pWav->dataChunkDataSize, pWav->pMetadata, pWav->metadataCount); ma_dr_wav__write_u32ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 4, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 4, MA_DR_WAV_SEEK_SET)) { ma_uint32 dataChunkSize = ma_dr_wav__data_chunk_size_riff(pWav->dataChunkDataSize); ma_dr_wav__write_u32ne_to_le(pWav, dataChunkSize); } } else if (pWav->container == ma_dr_wav_container_w64) { - if (pWav->onSeek(pWav->pUserData, 16, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, 16, MA_DR_WAV_SEEK_SET)) { ma_uint64 riffChunkSize = ma_dr_wav__riff_chunk_size_w64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 8, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 8, MA_DR_WAV_SEEK_SET)) { ma_uint64 dataChunkSize = ma_dr_wav__data_chunk_size_w64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, dataChunkSize); } } else if (pWav->container == ma_dr_wav_container_rf64) { int ds64BodyPos = 12 + 8; - if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, MA_DR_WAV_SEEK_SET)) { ma_uint64 riffChunkSize = ma_dr_wav__riff_chunk_size_rf64(pWav->dataChunkDataSize, pWav->pMetadata, pWav->metadataCount); ma_dr_wav__write_u64ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, MA_DR_WAV_SEEK_SET)) { ma_uint64 dataChunkSize = ma_dr_wav__data_chunk_size_rf64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, dataChunkSize); } @@ -80863,7 +82670,7 @@ MA_API size_t ma_dr_wav_read_raw(ma_dr_wav* pWav, size_t bytesToRead, void* pBuf if (bytesToSeek > 0x7FFFFFFF) { bytesToSeek = 0x7FFFFFFF; } - if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, MA_DR_WAV_SEEK_CUR) == MA_FALSE) { break; } bytesRead += bytesToSeek; @@ -80962,7 +82769,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav_seek_to_first_pcm_frame(ma_dr_wav* pWav) if (pWav->onWrite != NULL) { return MA_FALSE; } - if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, ma_dr_wav_seek_origin_start)) { + if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, MA_DR_WAV_SEEK_SET)) { return MA_FALSE; } if (ma_dr_wav__is_compressed_format_tag(pWav->translatedFormatTag)) { @@ -81043,7 +82850,7 @@ MA_API ma_bool32 ma_dr_wav_seek_to_pcm_frame(ma_dr_wav* pWav, ma_uint64 targetFr } while (offset > 0) { int offset32 = ((offset > INT_MAX) ? INT_MAX : (int)offset); - if (!pWav->onSeek(pWav->pUserData, offset32, ma_dr_wav_seek_origin_current)) { + if (!pWav->onSeek(pWav->pUserData, offset32, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } pWav->readCursorInPCMFrames += offset32 / bytesPerFrame; @@ -81169,12 +82976,12 @@ MA_API ma_uint64 ma_dr_wav_write_pcm_frames(ma_dr_wav* pWav, ma_uint64 framesToW MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_uint64 framesToRead, ma_int16* pBufferOut) { ma_uint64 totalFramesRead = 0; - static ma_int32 adaptationTable[] = { + static const ma_int32 adaptationTable[] = { 230, 230, 230, 230, 307, 409, 512, 614, 768, 614, 512, 409, 307, 230, 230, 230 }; - static ma_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; - static ma_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; + static const ma_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; + static const ma_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; MA_DR_WAV_ASSERT(pWav != NULL); MA_DR_WAV_ASSERT(framesToRead > 0); while (pWav->readCursorInPCMFrames < pWav->totalPCMFrameCount) { @@ -81307,11 +83114,11 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint { ma_uint64 totalFramesRead = 0; ma_uint32 iChannel; - static ma_int32 indexTable[16] = { + static const ma_int32 indexTable[16] = { -1, -1, -1, -1, 2, 4, 6, 8, -1, -1, -1, -1, 2, 4, 6, 8 }; - static ma_int32 stepTable[89] = { + static const ma_int32 stepTable[89] = { 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 19, 21, 23, 25, 28, 31, 34, 37, 41, 45, 50, 55, 60, 66, 73, 80, 88, 97, 107, 118, @@ -81334,7 +83141,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint } pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); if (header[2] >= ma_dr_wav_countof(stepTable)) { - pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, ma_dr_wav_seek_origin_current); + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, MA_DR_WAV_SEEK_CUR); pWav->ima.bytesRemainingInBlock = 0; return totalFramesRead; } @@ -81349,7 +83156,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint } pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); if (header[2] >= ma_dr_wav_countof(stepTable) || header[6] >= ma_dr_wav_countof(stepTable)) { - pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, ma_dr_wav_seek_origin_current); + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, MA_DR_WAV_SEEK_CUR); pWav->ima.bytesRemainingInBlock = 0; return totalFramesRead; } @@ -81424,7 +83231,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint return totalFramesRead; } #ifndef MA_DR_WAV_NO_CONVERSION_API -static unsigned short g_ma_dr_wavAlawTable[256] = { +static const unsigned short ma_dr_wav_gAlawTable[256] = { 0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580, 0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0, 0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600, @@ -81442,7 +83249,7 @@ static unsigned short g_ma_dr_wavAlawTable[256] = { 0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0, 0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350 }; -static unsigned short g_ma_dr_wavMulawTable[256] = { +static const unsigned short ma_dr_wav_gMulawTable[256] = { 0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84, 0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84, 0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004, @@ -81462,11 +83269,11 @@ static unsigned short g_ma_dr_wavMulawTable[256] = { }; static MA_INLINE ma_int16 ma_dr_wav__alaw_to_s16(ma_uint8 sampleIn) { - return (short)g_ma_dr_wavAlawTable[sampleIn]; + return (short)ma_dr_wav_gAlawTable[sampleIn]; } static MA_INLINE ma_int16 ma_dr_wav__mulaw_to_s16(ma_uint8 sampleIn) { - return (short)g_ma_dr_wavMulawTable[sampleIn]; + return (short)ma_dr_wav_gMulawTable[sampleIn]; } MA_PRIVATE void ma_dr_wav__pcm_to_s16(ma_int16* pOut, const ma_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) { @@ -82625,7 +84432,7 @@ MA_PRIVATE ma_int32* ma_dr_wav__read_pcm_frames_and_close_s32(ma_dr_wav* pWav, u } return pSampleData; } -MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82637,12 +84444,12 @@ MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRe if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); } -MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82654,12 +84461,12 @@ MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); } -MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82671,7 +84478,7 @@ MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRe if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); @@ -84106,23 +85913,23 @@ static ma_bool32 ma_dr_flac__seek_to_byte(ma_dr_flac_bs* bs, ma_uint64 offsetFro MA_DR_FLAC_ASSERT(offsetFromStart > 0); if (offsetFromStart > 0x7FFFFFFF) { ma_uint64 bytesRemaining = offsetFromStart; - if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_start)) { + if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } bytesRemaining -= 0x7FFFFFFF; while (bytesRemaining > 0x7FFFFFFF) { - if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_current)) { + if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } bytesRemaining -= 0x7FFFFFFF; } if (bytesRemaining > 0) { - if (!bs->onSeek(bs->pUserData, (int)bytesRemaining, ma_dr_flac_seek_origin_current)) { + if (!bs->onSeek(bs->pUserData, (int)bytesRemaining, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!bs->onSeek(bs->pUserData, (int)offsetFromStart, ma_dr_flac_seek_origin_start)) { + if (!bs->onSeek(bs->pUserData, (int)offsetFromStart, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } } @@ -86600,6 +88407,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; ma_dr_flac_meta_proc onMeta; ma_dr_flac_container container; void* pUserData; @@ -86728,11 +88536,12 @@ static void ma_dr_flac__free_from_callbacks(void* p, const ma_allocation_callbac pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); } } -static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, void* pUserDataMD, ma_uint64* pFirstFramePos, ma_uint64* pSeektablePos, ma_uint32* pSeekpointCount, ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, void* pUserDataMD, ma_uint64* pFirstFramePos, ma_uint64* pSeektablePos, ma_uint32* pSeekpointCount, ma_allocation_callbacks* pAllocationCallbacks) { ma_uint64 runningFilePos = 42; ma_uint64 seektablePos = 0; ma_uint32 seektableSize = 0; + (void)onTell; for (;;) { ma_dr_flac_metadata metadata; ma_uint8 isLastBlock = 0; @@ -86743,8 +88552,9 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea } runningFilePos += 4; metadata.type = blockType; - metadata.pRawData = NULL; metadata.rawDataSize = 0; + metadata.rawDataOffset = runningFilePos; + metadata.pRawData = NULL; switch (blockType) { case MA_DR_FLAC_METADATA_BLOCK_TYPE_APPLICATION: @@ -86944,53 +88754,123 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea return MA_FALSE; } if (onMeta) { - void* pRawData; - const char* pRunningData; - const char* pRunningDataEnd; - pRawData = ma_dr_flac__malloc_from_callbacks(blockSize, pAllocationCallbacks); - if (pRawData == NULL) { + ma_bool32 result = MA_TRUE; + ma_uint32 blockSizeRemaining = blockSize; + char* pMime = NULL; + char* pDescription = NULL; + void* pPictureData = NULL; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.type, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.type = ma_dr_flac__be2host_32(metadata.data.picture.type); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.mimeLength, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.mimeLength = ma_dr_flac__be2host_32(metadata.data.picture.mimeLength); + pMime = (char*)ma_dr_flac__malloc_from_callbacks(metadata.data.picture.mimeLength + 1, pAllocationCallbacks); + if (pMime == NULL) { + result = MA_FALSE; + goto done_flac; + } + if (blockSizeRemaining < metadata.data.picture.mimeLength || onRead(pUserData, pMime, metadata.data.picture.mimeLength) != metadata.data.picture.mimeLength) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= metadata.data.picture.mimeLength; + pMime[metadata.data.picture.mimeLength] = '\0'; + metadata.data.picture.mime = (const char*)pMime; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.descriptionLength, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.descriptionLength = ma_dr_flac__be2host_32(metadata.data.picture.descriptionLength); + pDescription = (char*)ma_dr_flac__malloc_from_callbacks(metadata.data.picture.descriptionLength + 1, pAllocationCallbacks); + if (pDescription == NULL) { + result = MA_FALSE; + goto done_flac; + } + if (blockSizeRemaining < metadata.data.picture.descriptionLength || onRead(pUserData, pDescription, metadata.data.picture.descriptionLength) != metadata.data.picture.descriptionLength) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= metadata.data.picture.descriptionLength; + pDescription[metadata.data.picture.descriptionLength] = '\0'; + metadata.data.picture.description = (const char*)pDescription; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.width, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.width = ma_dr_flac__be2host_32(metadata.data.picture.width); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.height, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.height = ma_dr_flac__be2host_32(metadata.data.picture.height); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.colorDepth, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.colorDepth = ma_dr_flac__be2host_32(metadata.data.picture.colorDepth); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.indexColorCount, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.indexColorCount = ma_dr_flac__be2host_32(metadata.data.picture.indexColorCount); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.pictureDataSize, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.pictureDataSize = ma_dr_flac__be2host_32(metadata.data.picture.pictureDataSize); + if (blockSizeRemaining < metadata.data.picture.pictureDataSize) { + result = MA_FALSE; + goto done_flac; + } + metadata.data.picture.pictureDataOffset = runningFilePos + (blockSize - blockSizeRemaining); + #ifndef MA_DR_FLAC_NO_PICTURE_METADATA_MALLOC + pPictureData = ma_dr_flac__malloc_from_callbacks(metadata.data.picture.pictureDataSize, pAllocationCallbacks); + if (pPictureData != NULL) { + if (onRead(pUserData, pPictureData, metadata.data.picture.pictureDataSize) != metadata.data.picture.pictureDataSize) { + result = MA_FALSE; + goto done_flac; + } + } else + #endif + { + if (!onSeek(pUserData, metadata.data.picture.pictureDataSize, MA_DR_FLAC_SEEK_CUR)) { + result = MA_FALSE; + goto done_flac; + } + } + blockSizeRemaining -= metadata.data.picture.pictureDataSize; + metadata.data.picture.pPictureData = (const ma_uint8*)pPictureData; + if (metadata.data.picture.pictureDataOffset != 0 || metadata.data.picture.pPictureData != NULL) { + onMeta(pUserDataMD, &metadata); + } else { + } + done_flac: + ma_dr_flac__free_from_callbacks(pMime, pAllocationCallbacks); + ma_dr_flac__free_from_callbacks(pDescription, pAllocationCallbacks); + ma_dr_flac__free_from_callbacks(pPictureData, pAllocationCallbacks); + if (result != MA_TRUE) { return MA_FALSE; } - if (onRead(pUserData, pRawData, blockSize) != blockSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; - } - metadata.pRawData = pRawData; - metadata.rawDataSize = blockSize; - pRunningData = (const char*)pRawData; - pRunningDataEnd = (const char*)pRawData + blockSize; - metadata.data.picture.type = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.mimeLength = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - if ((pRunningDataEnd - pRunningData) - 24 < (ma_int64)metadata.data.picture.mimeLength) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; - } - metadata.data.picture.mime = pRunningData; pRunningData += metadata.data.picture.mimeLength; - metadata.data.picture.descriptionLength = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - if ((pRunningDataEnd - pRunningData) - 20 < (ma_int64)metadata.data.picture.descriptionLength) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; - } - metadata.data.picture.description = pRunningData; pRunningData += metadata.data.picture.descriptionLength; - metadata.data.picture.width = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.height = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.colorDepth = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.indexColorCount = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.pictureDataSize = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.pPictureData = (const ma_uint8*)pRunningData; - if (pRunningDataEnd - pRunningData < (ma_int64)metadata.data.picture.pictureDataSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; - } - onMeta(pUserDataMD, &metadata); - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); } } break; case MA_DR_FLAC_METADATA_BLOCK_TYPE_PADDING: { if (onMeta) { metadata.data.padding.unused = 0; - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } else { onMeta(pUserDataMD, &metadata); @@ -87000,7 +88880,7 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea case MA_DR_FLAC_METADATA_BLOCK_TYPE_INVALID: { if (onMeta) { - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } } @@ -87009,12 +88889,15 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea { if (onMeta) { void* pRawData = ma_dr_flac__malloc_from_callbacks(blockSize, pAllocationCallbacks); - if (pRawData == NULL) { - return MA_FALSE; - } - if (onRead(pUserData, pRawData, blockSize) != blockSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; + if (pRawData != NULL) { + if (onRead(pUserData, pRawData, blockSize) != blockSize) { + ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); + return MA_FALSE; + } + } else { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { + return MA_FALSE; + } } metadata.pRawData = pRawData; metadata.rawDataSize = blockSize; @@ -87024,7 +88907,7 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea } break; } if (onMeta == NULL && blockSize > 0) { - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } } @@ -87288,6 +89171,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; void* pUserData; ma_uint64 currentBytePos; ma_uint64 firstBytePos; @@ -87306,29 +89190,29 @@ static size_t ma_dr_flac_oggbs__read_physical(ma_dr_flac_oggbs* oggbs, void* buf } static ma_bool32 ma_dr_flac_oggbs__seek_physical(ma_dr_flac_oggbs* oggbs, ma_uint64 offset, ma_dr_flac_seek_origin origin) { - if (origin == ma_dr_flac_seek_origin_start) { + if (origin == MA_DR_FLAC_SEEK_SET) { if (offset <= 0x7FFFFFFF) { - if (!oggbs->onSeek(oggbs->pUserData, (int)offset, ma_dr_flac_seek_origin_start)) { + if (!oggbs->onSeek(oggbs->pUserData, (int)offset, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } oggbs->currentBytePos = offset; return MA_TRUE; } else { - if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_start)) { + if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } oggbs->currentBytePos = offset; - return ma_dr_flac_oggbs__seek_physical(oggbs, offset - 0x7FFFFFFF, ma_dr_flac_seek_origin_current); + return ma_dr_flac_oggbs__seek_physical(oggbs, offset - 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR); } } else { while (offset > 0x7FFFFFFF) { - if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_current)) { + if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } oggbs->currentBytePos += 0x7FFFFFFF; offset -= 0x7FFFFFFF; } - if (!oggbs->onSeek(oggbs->pUserData, (int)offset, ma_dr_flac_seek_origin_current)) { + if (!oggbs->onSeek(oggbs->pUserData, (int)offset, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } oggbs->currentBytePos += offset; @@ -87354,7 +89238,7 @@ static ma_bool32 ma_dr_flac_oggbs__goto_next_page(ma_dr_flac_oggbs* oggbs, ma_dr continue; } if (header.serialNumber != oggbs->serialNumber) { - if (pageBodySize > 0 && !ma_dr_flac_oggbs__seek_physical(oggbs, pageBodySize, ma_dr_flac_seek_origin_current)) { + if (pageBodySize > 0 && !ma_dr_flac_oggbs__seek_physical(oggbs, pageBodySize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } continue; @@ -87416,7 +89300,7 @@ static ma_bool32 ma_dr_flac_oggbs__seek_to_next_packet(ma_dr_flac_oggbs* oggbs) } bytesToEndOfPacketOrPage += segmentSize; } - ma_dr_flac_oggbs__seek_physical(oggbs, bytesToEndOfPacketOrPage, ma_dr_flac_seek_origin_current); + ma_dr_flac_oggbs__seek_physical(oggbs, bytesToEndOfPacketOrPage, MA_DR_FLAC_SEEK_CUR); oggbs->bytesRemainingInPage -= bytesToEndOfPacketOrPage; if (atEndOfPage) { if (!ma_dr_flac_oggbs__goto_next_page(oggbs)) { @@ -87469,36 +89353,44 @@ static ma_bool32 ma_dr_flac__on_seek_ogg(void* pUserData, int offset, ma_dr_flac int bytesSeeked = 0; MA_DR_FLAC_ASSERT(oggbs != NULL); MA_DR_FLAC_ASSERT(offset >= 0); - if (origin == ma_dr_flac_seek_origin_start) { - if (!ma_dr_flac_oggbs__seek_physical(oggbs, (int)oggbs->firstBytePos, ma_dr_flac_seek_origin_start)) { + if (origin == MA_DR_FLAC_SEEK_SET) { + if (!ma_dr_flac_oggbs__seek_physical(oggbs, (int)oggbs->firstBytePos, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { return MA_FALSE; } - return ma_dr_flac__on_seek_ogg(pUserData, offset, ma_dr_flac_seek_origin_current); - } - MA_DR_FLAC_ASSERT(origin == ma_dr_flac_seek_origin_current); - while (bytesSeeked < offset) { - int bytesRemainingToSeek = offset - bytesSeeked; - MA_DR_FLAC_ASSERT(bytesRemainingToSeek >= 0); - if (oggbs->bytesRemainingInPage >= (size_t)bytesRemainingToSeek) { - bytesSeeked += bytesRemainingToSeek; - (void)bytesSeeked; - oggbs->bytesRemainingInPage -= bytesRemainingToSeek; - break; - } - if (oggbs->bytesRemainingInPage > 0) { - bytesSeeked += (int)oggbs->bytesRemainingInPage; - oggbs->bytesRemainingInPage = 0; - } - MA_DR_FLAC_ASSERT(bytesRemainingToSeek > 0); - if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { - return MA_FALSE; + return ma_dr_flac__on_seek_ogg(pUserData, offset, MA_DR_FLAC_SEEK_CUR); + } else if (origin == MA_DR_FLAC_SEEK_CUR) { + while (bytesSeeked < offset) { + int bytesRemainingToSeek = offset - bytesSeeked; + MA_DR_FLAC_ASSERT(bytesRemainingToSeek >= 0); + if (oggbs->bytesRemainingInPage >= (size_t)bytesRemainingToSeek) { + bytesSeeked += bytesRemainingToSeek; + (void)bytesSeeked; + oggbs->bytesRemainingInPage -= bytesRemainingToSeek; + break; + } + if (oggbs->bytesRemainingInPage > 0) { + bytesSeeked += (int)oggbs->bytesRemainingInPage; + oggbs->bytesRemainingInPage = 0; + } + MA_DR_FLAC_ASSERT(bytesRemainingToSeek > 0); + if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { + return MA_FALSE; + } } + } else if (origin == MA_DR_FLAC_SEEK_END) { + return MA_FALSE; } return MA_TRUE; } +static ma_bool32 ma_dr_flac__on_tell_ogg(void* pUserData, ma_int64* pCursor) +{ + (void)pUserData; + (void)pCursor; + return MA_FALSE; +} static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 pcmFrameIndex) { ma_dr_flac_oggbs* oggbs = (ma_dr_flac_oggbs*)pFlac->_oggbs; @@ -87515,7 +89407,7 @@ static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 runningGranulePosition = 0; for (;;) { if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_recover_on_crc_mismatch)) { - ma_dr_flac_oggbs__seek_physical(oggbs, originalBytePos, ma_dr_flac_seek_origin_start); + ma_dr_flac_oggbs__seek_physical(oggbs, originalBytePos, MA_DR_FLAC_SEEK_SET); return MA_FALSE; } runningFrameBytePos = oggbs->currentBytePos - ma_dr_flac_ogg__get_page_header_size(&oggbs->currentPageHeader) - oggbs->pageDataSize; @@ -87534,7 +89426,7 @@ static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 } } } - if (!ma_dr_flac_oggbs__seek_physical(oggbs, runningFrameBytePos, ma_dr_flac_seek_origin_start)) { + if (!ma_dr_flac_oggbs__seek_physical(oggbs, runningFrameBytePos, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_recover_on_crc_mismatch)) { @@ -87629,7 +89521,7 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d if (mappingVersion[0] != 1) { return MA_FALSE; } - if (!onSeek(pUserData, 2, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, 2, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } if (onRead(pUserData, sig, 4) != 4) { @@ -87674,17 +89566,17 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d return MA_FALSE; } } else { - if (!onSeek(pUserData, bytesRemainingInPage, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, bytesRemainingInPage, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!onSeek(pUserData, bytesRemainingInPage, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, bytesRemainingInPage, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!onSeek(pUserData, pageBodySize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, pageBodySize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } @@ -87698,7 +89590,7 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d return MA_TRUE; } #endif -static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD) +static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD) { ma_bool32 relaxed; ma_uint8 id[4]; @@ -87708,12 +89600,14 @@ static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_fla MA_DR_FLAC_ZERO_MEMORY(pInit, sizeof(*pInit)); pInit->onRead = onRead; pInit->onSeek = onSeek; + pInit->onTell = onTell; pInit->onMeta = onMeta; pInit->container = container; pInit->pUserData = pUserData; pInit->pUserDataMD = pUserDataMD; pInit->bs.onRead = onRead; pInit->bs.onSeek = onSeek; + pInit->bs.onTell = onTell; pInit->bs.pUserData = pUserData; ma_dr_flac__reset_cache(&pInit->bs); relaxed = container != ma_dr_flac_container_unknown; @@ -87736,7 +89630,7 @@ static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_fla if (flags & 0x10) { headerSize += 10; } - if (!onSeek(pUserData, headerSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, headerSize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } pInit->runningFilePos += headerSize; @@ -87779,7 +89673,7 @@ static void ma_dr_flac__init_from_info(ma_dr_flac* pFlac, const ma_dr_flac_init_ pFlac->totalPCMFrameCount = pInit->totalPCMFrameCount; pFlac->container = pInit->container; } -static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac_init_info init; ma_uint32 allocationSize; @@ -87794,7 +89688,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on ma_allocation_callbacks allocationCallbacks; ma_dr_flac* pFlac; ma_dr_flac__init_cpu_caps(); - if (!ma_dr_flac__init_private(&init, onRead, onSeek, onMeta, container, pUserData, pUserDataMD)) { + if (!ma_dr_flac__init_private(&init, onRead, onSeek, onTell, onMeta, container, pUserData, pUserDataMD)) { return NULL; } if (pAllocationCallbacks != NULL) { @@ -87827,6 +89721,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on MA_DR_FLAC_ZERO_MEMORY(pOggbs, sizeof(*pOggbs)); pOggbs->onRead = onRead; pOggbs->onSeek = onSeek; + pOggbs->onTell = onTell; pOggbs->pUserData = pUserData; pOggbs->currentBytePos = init.oggFirstBytePos; pOggbs->firstBytePos = init.oggFirstBytePos; @@ -87841,15 +89736,17 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on if (init.hasMetadataBlocks) { ma_dr_flac_read_proc onReadOverride = onRead; ma_dr_flac_seek_proc onSeekOverride = onSeek; + ma_dr_flac_tell_proc onTellOverride = onTell; void* pUserDataOverride = pUserData; #ifndef MA_DR_FLAC_NO_OGG if (init.container == ma_dr_flac_container_ogg) { onReadOverride = ma_dr_flac__on_read_ogg; onSeekOverride = ma_dr_flac__on_seek_ogg; + onTellOverride = ma_dr_flac__on_tell_ogg; pUserDataOverride = (void*)pOggbs; } #endif - if (!ma_dr_flac__read_and_decode_metadata(onReadOverride, onSeekOverride, onMeta, pUserDataOverride, pUserDataMD, &firstFramePos, &seektablePos, &seekpointCount, &allocationCallbacks)) { + if (!ma_dr_flac__read_and_decode_metadata(onReadOverride, onSeekOverride, onTellOverride, onMeta, pUserDataOverride, pUserDataMD, &firstFramePos, &seektablePos, &seekpointCount, &allocationCallbacks)) { #ifndef MA_DR_FLAC_NO_OGG ma_dr_flac__free_from_callbacks(pOggbs, &allocationCallbacks); #endif @@ -87875,6 +89772,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on pOggbs = NULL; pFlac->bs.onRead = ma_dr_flac__on_read_ogg; pFlac->bs.onSeek = ma_dr_flac__on_seek_ogg; + pFlac->bs.onTell = ma_dr_flac__on_tell_ogg; pFlac->bs.pUserData = (void*)pInternalOggbs; pFlac->_oggbs = (void*)pInternalOggbs; } @@ -87894,7 +89792,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on pFlac->pSeekpoints = (ma_dr_flac_seekpoint*)((ma_uint8*)pFlac->pDecodedSamples + decodedSamplesAllocationSize); MA_DR_FLAC_ASSERT(pFlac->bs.onSeek != NULL); MA_DR_FLAC_ASSERT(pFlac->bs.onRead != NULL); - if (pFlac->bs.onSeek(pFlac->bs.pUserData, (int)seektablePos, ma_dr_flac_seek_origin_start)) { + if (pFlac->bs.onSeek(pFlac->bs.pUserData, (int)seektablePos, MA_DR_FLAC_SEEK_SET)) { ma_uint32 iSeekpoint; for (iSeekpoint = 0; iSeekpoint < seekpointCount; iSeekpoint += 1) { if (pFlac->bs.onRead(pFlac->bs.pUserData, pFlac->pSeekpoints + iSeekpoint, MA_DR_FLAC_SEEKPOINT_SIZE_IN_BYTES) == MA_DR_FLAC_SEEKPOINT_SIZE_IN_BYTES) { @@ -87907,7 +89805,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on break; } } - if (!pFlac->bs.onSeek(pFlac->bs.pUserData, (int)pFlac->firstFLACFramePosInBytes, ma_dr_flac_seek_origin_start)) { + if (!pFlac->bs.onSeek(pFlac->bs.pUserData, (int)pFlac->firstFLACFramePosInBytes, MA_DR_FLAC_SEEK_SET)) { ma_dr_flac__free_from_callbacks(pFlac, &allocationCallbacks); return NULL; } @@ -87950,8 +89848,31 @@ static size_t ma_dr_flac__on_read_stdio(void* pUserData, void* bufferOut, size_t } static ma_bool32 ma_dr_flac__on_seek_stdio(void* pUserData, int offset, ma_dr_flac_seek_origin origin) { - MA_DR_FLAC_ASSERT(offset >= 0); - return fseek((FILE*)pUserData, offset, (origin == ma_dr_flac_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_FLAC_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_FLAC_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; +} +static ma_bool32 ma_dr_flac__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_FLAC_ASSERT(pFileStdio != NULL); + MA_DR_FLAC_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; } MA_API ma_dr_flac* ma_dr_flac_open_file(const char* pFileName, const ma_allocation_callbacks* pAllocationCallbacks) { @@ -87960,7 +89881,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file(const char* pFileName, const ma_allocati if (ma_fopen(&pFile, pFileName, "rb") != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return NULL; @@ -87975,7 +89896,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_w(const wchar_t* pFileName, const ma_all if (ma_wfopen(&pFile, pFileName, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return NULL; @@ -87990,7 +89911,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata(const char* pFileName, ma_ if (ma_fopen(&pFile, pFileName, "rb") != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return pFlac; @@ -88005,7 +89926,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata_w(const wchar_t* pFileName if (ma_wfopen(&pFile, pFileName, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return pFlac; @@ -88033,24 +89954,34 @@ static size_t ma_dr_flac__on_read_memory(void* pUserData, void* bufferOut, size_ static ma_bool32 ma_dr_flac__on_seek_memory(void* pUserData, int offset, ma_dr_flac_seek_origin origin) { ma_dr_flac__memory_stream* memoryStream = (ma_dr_flac__memory_stream*)pUserData; + ma_int64 newCursor; MA_DR_FLAC_ASSERT(memoryStream != NULL); - MA_DR_FLAC_ASSERT(offset >= 0); - if (offset > (ma_int64)memoryStream->dataSize) { + if (origin == MA_DR_FLAC_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_FLAC_SEEK_CUR) { + newCursor = (ma_int64)memoryStream->currentReadPos; + } else if (origin == MA_DR_FLAC_SEEK_END) { + newCursor = (ma_int64)memoryStream->dataSize; + } else { + MA_DR_FLAC_ASSERT(!"Invalid seek origin"); return MA_FALSE; } - if (origin == ma_dr_flac_seek_origin_current) { - if (memoryStream->currentReadPos + offset <= memoryStream->dataSize) { - memoryStream->currentReadPos += offset; - } else { - return MA_FALSE; - } - } else { - if ((ma_uint32)offset <= memoryStream->dataSize) { - memoryStream->currentReadPos = offset; - } else { - return MA_FALSE; - } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; } + if ((size_t)newCursor > memoryStream->dataSize) { + return MA_FALSE; + } + memoryStream->currentReadPos = (size_t)newCursor; + return MA_TRUE; +} +static ma_bool32 ma_dr_flac__on_tell_memory(void* pUserData, ma_int64* pCursor) +{ + ma_dr_flac__memory_stream* memoryStream = (ma_dr_flac__memory_stream*)pUserData; + MA_DR_FLAC_ASSERT(memoryStream != NULL); + MA_DR_FLAC_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)memoryStream->currentReadPos; return MA_TRUE; } MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) @@ -88060,7 +89991,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, co memoryStream.data = (const ma_uint8*)pData; memoryStream.dataSize = dataSize; memoryStream.currentReadPos = 0; - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, &memoryStream, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, ma_dr_flac__on_tell_memory, &memoryStream, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -88085,7 +90016,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_ memoryStream.data = (const ma_uint8*)pData; memoryStream.dataSize = dataSize; memoryStream.currentReadPos = 0; - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, onMeta, ma_dr_flac_container_unknown, &memoryStream, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, ma_dr_flac__on_tell_memory, onMeta, ma_dr_flac_container_unknown, &memoryStream, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -88103,21 +90034,21 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_ } return pFlac; } -MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, NULL, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, NULL, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, NULL, container, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, NULL, container, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onMeta, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, onMeta, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onMeta, container, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, onMeta, container, pUserData, pUserData, pAllocationCallbacks); } MA_API void ma_dr_flac_close(ma_dr_flac* pFlac) { @@ -90410,7 +92341,7 @@ on_error: MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(s32, ma_int32) MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(s16, ma_int16) MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(f32, float) -MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90422,13 +92353,13 @@ MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc on if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } return ma_dr_flac__full_read_and_close_s32(pFlac, channelsOut, sampleRateOut, totalPCMFrameCountOut); } -MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90440,13 +92371,13 @@ MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc on if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } return ma_dr_flac__full_read_and_close_s16(pFlac, channelsOut, sampleRateOut, totalPCMFrameCountOut); } -MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90458,7 +92389,7 @@ MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRea if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -90680,12 +92611,9 @@ MA_API const char* ma_dr_mp3_version_string(void) #define MA_DR_MP3_NO_SIMD #endif #define MA_DR_MP3_OFFSET_PTR(p, offset) ((void*)((ma_uint8*)(p) + (offset))) -#define MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE 2304 #ifndef MA_DR_MP3_MAX_FRAME_SYNC_MATCHES #define MA_DR_MP3_MAX_FRAME_SYNC_MATCHES 10 #endif -#define MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE -#define MA_DR_MP3_MAX_BITRESERVOIR_BYTES 511 #define MA_DR_MP3_SHORT_BLOCK_TYPE 2 #define MA_DR_MP3_STOP_BLOCK_TYPE 3 #define MA_DR_MP3_MODE_MONO 3 @@ -90735,7 +92663,7 @@ MA_API const char* ma_dr_mp3_version_string(void) #define MA_DR_MP3_VMUL_S(x, s) _mm_mul_ps(x, _mm_set1_ps(s)) #define MA_DR_MP3_VREV(x) _mm_shuffle_ps(x, x, _MM_SHUFFLE(0, 1, 2, 3)) typedef __m128 ma_dr_mp3_f4; -#if defined(_MSC_VER) || defined(MA_DR_MP3_ONLY_SIMD) +#if (defined(_MSC_VER) || defined(MA_DR_MP3_ONLY_SIMD)) && !defined(__clang__) #define ma_dr_mp3_cpuid __cpuid #else static __inline__ __attribute__((always_inline)) void ma_dr_mp3_cpuid(int CPUInfo[], const int InfoType) @@ -90851,11 +92779,6 @@ static __inline__ __attribute__((always_inline)) ma_int32 ma_dr_mp3_clip_int16_a #define MA_DR_MP3_FREE(p) free((p)) #endif typedef struct -{ - const ma_uint8 *buf; - int pos, limit; -} ma_dr_mp3_bs; -typedef struct { float scf[3*64]; ma_uint8 total_bands, stereo_bands, bitalloc[64], scfcod[64]; @@ -90864,22 +92787,6 @@ typedef struct { ma_uint8 tab_offset, code_tab_width, band_count; } ma_dr_mp3_L12_subband_alloc; -typedef struct -{ - const ma_uint8 *sfbtab; - ma_uint16 part_23_length, big_values, scalefac_compress; - ma_uint8 global_gain, block_type, mixed_block_flag, n_long_sfb, n_short_sfb; - ma_uint8 table_select[3], region_count[3], subblock_gain[3]; - ma_uint8 preflag, scalefac_scale, count1_table, scfsi; -} ma_dr_mp3_L3_gr_info; -typedef struct -{ - ma_dr_mp3_bs bs; - ma_uint8 maindata[MA_DR_MP3_MAX_BITRESERVOIR_BYTES + MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES]; - ma_dr_mp3_L3_gr_info gr_info[4]; - float grbuf[2][576], scf[40], syn[18 + 15][2*32]; - ma_uint8 ist_pos[2][39]; -} ma_dr_mp3dec_scratch; static void ma_dr_mp3_bs_init(ma_dr_mp3_bs *bs, const ma_uint8 *data, int bytes) { bs->buf = data; @@ -91262,6 +93169,10 @@ static float ma_dr_mp3_L3_ldexp_q2(float y, int exp_q2) } while ((exp_q2 -= e) > 0); return y; } +#if (defined(__GNUC__) && (__GNUC__ >= 13)) && !defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstringop-overflow" +#endif static void ma_dr_mp3_L3_decode_scalefactors(const ma_uint8 *hdr, ma_uint8 *ist_pos, ma_dr_mp3_bs *bs, const ma_dr_mp3_L3_gr_info *gr, float *scf, int ch) { static const ma_uint8 g_scf_partitions[3][28] = { @@ -91320,7 +93231,10 @@ static void ma_dr_mp3_L3_decode_scalefactors(const ma_uint8 *hdr, ma_uint8 *ist_ scf[i] = ma_dr_mp3_L3_ldexp_q2(gain, iscf[i] << scf_shift); } } -static const float g_ma_dr_mp3_pow43[129 + 16] = { +#if (defined(__GNUC__) && (__GNUC__ >= 13)) && !defined(__clang__) + #pragma GCC diagnostic pop +#endif +static const float ma_dr_mp3_g_pow43[129 + 16] = { 0,-1,-2.519842f,-4.326749f,-6.349604f,-8.549880f,-10.902724f,-13.390518f,-16.000000f,-18.720754f,-21.544347f,-24.463781f,-27.473142f,-30.567351f,-33.741992f,-36.993181f, 0,1,2.519842f,4.326749f,6.349604f,8.549880f,10.902724f,13.390518f,16.000000f,18.720754f,21.544347f,24.463781f,27.473142f,30.567351f,33.741992f,36.993181f,40.317474f,43.711787f,47.173345f,50.699631f,54.288352f,57.937408f,61.644865f,65.408941f,69.227979f,73.100443f,77.024898f,81.000000f,85.024491f,89.097188f,93.216975f,97.382800f,101.593667f,105.848633f,110.146801f,114.487321f,118.869381f,123.292209f,127.755065f,132.257246f,136.798076f,141.376907f,145.993119f,150.646117f,155.335327f,160.060199f,164.820202f,169.614826f,174.443577f,179.305980f,184.201575f,189.129918f,194.090580f,199.083145f,204.107210f,209.162385f,214.248292f,219.364564f,224.510845f,229.686789f,234.892058f,240.126328f,245.389280f,250.680604f,256.000000f,261.347174f,266.721841f,272.123723f,277.552547f,283.008049f,288.489971f,293.998060f,299.532071f,305.091761f,310.676898f,316.287249f,321.922592f,327.582707f,333.267377f,338.976394f,344.709550f,350.466646f,356.247482f,362.051866f,367.879608f,373.730522f,379.604427f,385.501143f,391.420496f,397.362314f,403.326427f,409.312672f,415.320884f,421.350905f,427.402579f,433.475750f,439.570269f,445.685987f,451.822757f,457.980436f,464.158883f,470.357960f,476.577530f,482.817459f,489.077615f,495.357868f,501.658090f,507.978156f,514.317941f,520.677324f,527.056184f,533.454404f,539.871867f,546.308458f,552.764065f,559.238575f,565.731879f,572.243870f,578.774440f,585.323483f,591.890898f,598.476581f,605.080431f,611.702349f,618.342238f,625.000000f,631.675540f,638.368763f,645.079578f }; @@ -91330,7 +93244,7 @@ static float ma_dr_mp3_L3_pow_43(int x) int sign, mult = 256; if (x < 129) { - return g_ma_dr_mp3_pow43[16 + x]; + return ma_dr_mp3_g_pow43[16 + x]; } if (x < 1024) { @@ -91339,7 +93253,7 @@ static float ma_dr_mp3_L3_pow_43(int x) } sign = 2*x & 64; frac = (float)((x & 63) - sign) / ((x & ~63) + sign); - return g_ma_dr_mp3_pow43[16 + ((x + sign) >> 6)]*(1.f + frac*((4.f/3) + frac*(2.f/9)))*mult; + return ma_dr_mp3_g_pow43[16 + ((x + sign) >> 6)]*(1.f + frac*((4.f/3) + frac*(2.f/9)))*mult; } static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L3_gr_info *gr_info, const float *scf, int layer3gr_limit) { @@ -91409,7 +93323,7 @@ static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L *dst = one*ma_dr_mp3_L3_pow_43(lsb)*((ma_int32)bs_cache < 0 ? -1: 1); } else { - *dst = g_ma_dr_mp3_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; + *dst = ma_dr_mp3_g_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; } MA_DR_MP3_FLUSH_BITS(lsb ? 1 : 0); } @@ -91437,7 +93351,7 @@ static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L for (j = 0; j < 2; j++, dst++, leaf >>= 4) { int lsb = leaf & 0x0F; - *dst = g_ma_dr_mp3_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; + *dst = ma_dr_mp3_g_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; MA_DR_MP3_FLUSH_BITS(lsb ? 1 : 0); } MA_DR_MP3_CHECK_BITS; @@ -92245,7 +94159,6 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int int i = 0, igr, frame_size = 0, success = 1; const ma_uint8 *hdr; ma_dr_mp3_bs bs_frame[1]; - ma_dr_mp3dec_scratch scratch; if (mp3_bytes > 4 && dec->header[0] == 0xff && ma_dr_mp3_hdr_compare(dec->header, mp3)) { frame_size = ma_dr_mp3_hdr_frame_bytes(mp3, dec->free_format_bytes) + ma_dr_mp3_hdr_padding(mp3); @@ -92268,7 +94181,7 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int MA_DR_MP3_COPY_MEMORY(dec->header, hdr, MA_DR_MP3_HDR_SIZE); info->frame_bytes = i + frame_size; info->channels = MA_DR_MP3_HDR_IS_MONO(hdr) ? 1 : 2; - info->hz = ma_dr_mp3_hdr_sample_rate_hz(hdr); + info->sample_rate = ma_dr_mp3_hdr_sample_rate_hz(hdr); info->layer = 4 - MA_DR_MP3_HDR_GET_LAYER(hdr); info->bitrate_kbps = ma_dr_mp3_hdr_bitrate_kbps(hdr); ma_dr_mp3_bs_init(bs_frame, hdr + MA_DR_MP3_HDR_SIZE, frame_size - MA_DR_MP3_HDR_SIZE); @@ -92278,23 +94191,23 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int } if (info->layer == 3) { - int main_data_begin = ma_dr_mp3_L3_read_side_info(bs_frame, scratch.gr_info, hdr); + int main_data_begin = ma_dr_mp3_L3_read_side_info(bs_frame, dec->scratch.gr_info, hdr); if (main_data_begin < 0 || bs_frame->pos > bs_frame->limit) { ma_dr_mp3dec_init(dec); return 0; } - success = ma_dr_mp3_L3_restore_reservoir(dec, bs_frame, &scratch, main_data_begin); + success = ma_dr_mp3_L3_restore_reservoir(dec, bs_frame, &dec->scratch, main_data_begin); if (success && pcm != NULL) { for (igr = 0; igr < (MA_DR_MP3_HDR_TEST_MPEG1(hdr) ? 2 : 1); igr++, pcm = MA_DR_MP3_OFFSET_PTR(pcm, sizeof(ma_dr_mp3d_sample_t)*576*info->channels)) { - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); - ma_dr_mp3_L3_decode(dec, &scratch, scratch.gr_info + igr*info->channels, info->channels); - ma_dr_mp3d_synth_granule(dec->qmf_state, scratch.grbuf[0], 18, info->channels, (ma_dr_mp3d_sample_t*)pcm, scratch.syn[0]); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); + ma_dr_mp3_L3_decode(dec, &dec->scratch, dec->scratch.gr_info + igr*info->channels, info->channels); + ma_dr_mp3d_synth_granule(dec->qmf_state, dec->scratch.grbuf[0], 18, info->channels, (ma_dr_mp3d_sample_t*)pcm, dec->scratch.syn[0]); } } - ma_dr_mp3_L3_save_reservoir(dec, &scratch); + ma_dr_mp3_L3_save_reservoir(dec, &dec->scratch); } else { #ifdef MA_DR_MP3_ONLY_MP3 @@ -92305,15 +94218,15 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int return ma_dr_mp3_hdr_frame_samples(hdr); } ma_dr_mp3_L12_read_scale_info(hdr, bs_frame, sci); - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); for (i = 0, igr = 0; igr < 3; igr++) { - if (12 == (i += ma_dr_mp3_L12_dequantize_granule(scratch.grbuf[0] + i, bs_frame, sci, info->layer | 1))) + if (12 == (i += ma_dr_mp3_L12_dequantize_granule(dec->scratch.grbuf[0] + i, bs_frame, sci, info->layer | 1))) { i = 0; - ma_dr_mp3_L12_apply_scf_384(sci, sci->scf + igr, scratch.grbuf[0]); - ma_dr_mp3d_synth_granule(dec->qmf_state, scratch.grbuf[0], 12, info->channels, (ma_dr_mp3d_sample_t*)pcm, scratch.syn[0]); - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); + ma_dr_mp3_L12_apply_scf_384(sci, sci->scf + igr, dec->scratch.grbuf[0]); + ma_dr_mp3d_synth_granule(dec->qmf_state, dec->scratch.grbuf[0], 12, info->channels, (ma_dr_mp3d_sample_t*)pcm, dec->scratch.syn[0]); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); pcm = MA_DR_MP3_OFFSET_PTR(pcm, sizeof(ma_dr_mp3d_sample_t)*384*info->channels); } if (bs_frame->pos > bs_frame->limit) @@ -92491,19 +94404,41 @@ static ma_allocation_callbacks ma_dr_mp3_copy_allocation_callbacks_or_defaults(c } static size_t ma_dr_mp3__on_read(ma_dr_mp3* pMP3, void* pBufferOut, size_t bytesToRead) { - size_t bytesRead = pMP3->onRead(pMP3->pUserData, pBufferOut, bytesToRead); + size_t bytesRead; + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pMP3->onRead != NULL); + if (bytesToRead == 0) { + return 0; + } + bytesRead = pMP3->onRead(pMP3->pUserData, pBufferOut, bytesToRead); pMP3->streamCursor += bytesRead; return bytesRead; } +static size_t ma_dr_mp3__on_read_clamped(ma_dr_mp3* pMP3, void* pBufferOut, size_t bytesToRead) +{ + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pMP3->onRead != NULL); + if (pMP3->streamLength == MA_UINT64_MAX) { + return ma_dr_mp3__on_read(pMP3, pBufferOut, bytesToRead); + } else { + ma_uint64 bytesRemaining; + bytesRemaining = (pMP3->streamLength - pMP3->streamCursor); + if (bytesToRead > bytesRemaining) { + bytesToRead = (size_t)bytesRemaining; + } + return ma_dr_mp3__on_read(pMP3, pBufferOut, bytesToRead); + } +} static ma_bool32 ma_dr_mp3__on_seek(ma_dr_mp3* pMP3, int offset, ma_dr_mp3_seek_origin origin) { MA_DR_MP3_ASSERT(offset >= 0); + MA_DR_MP3_ASSERT(origin == MA_DR_MP3_SEEK_SET || origin == MA_DR_MP3_SEEK_CUR); if (!pMP3->onSeek(pMP3->pUserData, offset, origin)) { return MA_FALSE; } - if (origin == ma_dr_mp3_seek_origin_start) { + if (origin == MA_DR_MP3_SEEK_SET) { pMP3->streamCursor = (ma_uint64)offset; - } else { + } else{ pMP3->streamCursor += offset; } return MA_TRUE; @@ -92513,18 +94448,18 @@ static ma_bool32 ma_dr_mp3__on_seek_64(ma_dr_mp3* pMP3, ma_uint64 offset, ma_dr_ if (offset <= 0x7FFFFFFF) { return ma_dr_mp3__on_seek(pMP3, (int)offset, origin); } - if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } offset -= 0x7FFFFFFF; while (offset > 0) { if (offset <= 0x7FFFFFFF) { - if (!ma_dr_mp3__on_seek(pMP3, (int)offset, ma_dr_mp3_seek_origin_current)) { + if (!ma_dr_mp3__on_seek(pMP3, (int)offset, MA_DR_MP3_SEEK_CUR)) { return MA_FALSE; } offset = 0; } else { - if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, ma_dr_mp3_seek_origin_current)) { + if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, MA_DR_MP3_SEEK_CUR)) { return MA_FALSE; } offset -= 0x7FFFFFFF; @@ -92532,7 +94467,18 @@ static ma_bool32 ma_dr_mp3__on_seek_64(ma_dr_mp3* pMP3, ma_uint64 offset, ma_dr_ } return MA_TRUE; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static void ma_dr_mp3__on_meta(ma_dr_mp3* pMP3, ma_dr_mp3_metadata_type type, const void* pRawData, size_t rawDataSize) +{ + if (pMP3->onMeta) { + ma_dr_mp3_metadata metadata; + MA_DR_MP3_ZERO_OBJECT(&metadata); + metadata.type = type; + metadata.pRawData = pRawData; + metadata.rawDataSize = rawDataSize; + pMP3->onMeta(pMP3->pUserDataMeta, &metadata); + } +} +static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { ma_uint32 pcmFramesRead = 0; MA_DR_MP3_ASSERT(pMP3 != NULL); @@ -92559,7 +94505,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d pMP3->pData = pNewData; pMP3->dataCapacity = newDataCap; } - bytesRead = ma_dr_mp3__on_read(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); + bytesRead = ma_dr_mp3__on_read_clamped(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); if (bytesRead == 0) { if (pMP3->dataSize == 0) { pMP3->atEnd = MA_TRUE; @@ -92578,16 +94524,20 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d return 0; } pcmFramesRead = ma_dr_mp3dec_decode_frame(&pMP3->decoder, pMP3->pData + pMP3->dataConsumed, (int)pMP3->dataSize, pPCMFrames, &info); - if (info.frame_bytes > 0) { - pMP3->dataConsumed += (size_t)info.frame_bytes; - pMP3->dataSize -= (size_t)info.frame_bytes; - } + pMP3->dataConsumed += (size_t)info.frame_bytes; + pMP3->dataSize -= (size_t)info.frame_bytes; if (pcmFramesRead > 0) { pcmFramesRead = ma_dr_mp3_hdr_frame_samples(pMP3->decoder.header); pMP3->pcmFramesConsumedInMP3Frame = 0; pMP3->pcmFramesRemainingInMP3Frame = pcmFramesRead; pMP3->mp3FrameChannels = info.channels; - pMP3->mp3FrameSampleRate = info.hz; + pMP3->mp3FrameSampleRate = info.sample_rate; + if (pMP3FrameInfo != NULL) { + *pMP3FrameInfo = info; + } + if (ppMP3FrameData != NULL) { + *ppMP3FrameData = pMP3->pData + pMP3->dataConsumed - (size_t)info.frame_bytes; + } break; } else if (info.frame_bytes == 0) { size_t bytesRead; @@ -92604,7 +94554,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d pMP3->pData = pNewData; pMP3->dataCapacity = newDataCap; } - bytesRead = ma_dr_mp3__on_read(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); + bytesRead = ma_dr_mp3__on_read_clamped(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); if (bytesRead == 0) { pMP3->atEnd = MA_TRUE; return 0; @@ -92614,7 +94564,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d }; return pcmFramesRead; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { ma_uint32 pcmFramesRead = 0; ma_dr_mp3dec_frame_info info; @@ -92630,36 +94580,44 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_m pMP3->pcmFramesConsumedInMP3Frame = 0; pMP3->pcmFramesRemainingInMP3Frame = pcmFramesRead; pMP3->mp3FrameChannels = info.channels; - pMP3->mp3FrameSampleRate = info.hz; + pMP3->mp3FrameSampleRate = info.sample_rate; + if (pMP3FrameInfo != NULL) { + *pMP3FrameInfo = info; + } + if (ppMP3FrameData != NULL) { + *ppMP3FrameData = pMP3->memory.pData + pMP3->memory.currentReadPos; + } break; } else if (info.frame_bytes > 0) { pMP3->memory.currentReadPos += (size_t)info.frame_bytes; + pMP3->streamCursor += (size_t)info.frame_bytes; } else { break; } } pMP3->memory.currentReadPos += (size_t)info.frame_bytes; + pMP3->streamCursor += (size_t)info.frame_bytes; return pcmFramesRead; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static ma_uint32 ma_dr_mp3_decode_next_frame_ex(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { if (pMP3->memory.pData != NULL && pMP3->memory.dataSize > 0) { - return ma_dr_mp3_decode_next_frame_ex__memory(pMP3, pPCMFrames); + return ma_dr_mp3_decode_next_frame_ex__memory(pMP3, pPCMFrames, pMP3FrameInfo, ppMP3FrameData); } else { - return ma_dr_mp3_decode_next_frame_ex__callbacks(pMP3, pPCMFrames); + return ma_dr_mp3_decode_next_frame_ex__callbacks(pMP3, pPCMFrames, pMP3FrameInfo, ppMP3FrameData); } } static ma_uint32 ma_dr_mp3_decode_next_frame(ma_dr_mp3* pMP3) { MA_DR_MP3_ASSERT(pMP3 != NULL); - return ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames); + return ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames, NULL, NULL); } #if 0 static ma_uint32 ma_dr_mp3_seek_next_frame(ma_dr_mp3* pMP3) { ma_uint32 pcmFrameCount; MA_DR_MP3_ASSERT(pMP3 != NULL); - pcmFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFrameCount == 0) { return 0; } @@ -92669,33 +94627,249 @@ static ma_uint32 ma_dr_mp3_seek_next_frame(ma_dr_mp3* pMP3) return pcmFrameCount; } #endif -static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { + ma_dr_mp3dec_frame_info firstFrameInfo; + const ma_uint8* pFirstFrameData; + ma_uint32 firstFramePCMFrameCount; + ma_uint32 detectedMP3FrameCount = 0xFFFFFFFF; MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(onRead != NULL); ma_dr_mp3dec_init(&pMP3->decoder); pMP3->onRead = onRead; pMP3->onSeek = onSeek; + pMP3->onMeta = onMeta; pMP3->pUserData = pUserData; + pMP3->pUserDataMeta = pUserDataMeta; pMP3->allocationCallbacks = ma_dr_mp3_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); if (pMP3->allocationCallbacks.onFree == NULL || (pMP3->allocationCallbacks.onMalloc == NULL && pMP3->allocationCallbacks.onRealloc == NULL)) { return MA_FALSE; } - if (ma_dr_mp3_decode_next_frame(pMP3) == 0) { + pMP3->streamCursor = 0; + pMP3->streamLength = MA_UINT64_MAX; + pMP3->streamStartOffset = 0; + pMP3->delayInPCMFrames = 0; + pMP3->paddingInPCMFrames = 0; + pMP3->totalPCMFrameCount = MA_UINT64_MAX; + #if 1 + if (onSeek != NULL && onTell != NULL) { + if (onSeek(pUserData, 0, MA_DR_MP3_SEEK_END)) { + ma_int64 streamLen; + int streamEndOffset = 0; + if (onTell(pUserData, &streamLen)) { + if (streamLen > 128) { + char id3[3]; + if (onSeek(pUserData, streamEndOffset - 128, MA_DR_MP3_SEEK_END)) { + if (onRead(pUserData, id3, 3) == 3 && id3[0] == 'T' && id3[1] == 'A' && id3[2] == 'G') { + streamEndOffset -= 128; + streamLen -= 128; + if (onMeta != NULL) { + ma_uint8 tag[128]; + tag[0] = 'T'; tag[1] = 'A'; tag[2] = 'G'; + if (onRead(pUserData, tag + 3, 125) == 125) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_ID3V1, tag, 128); + } + } + } else { + } + } else { + } + } else { + } + if (streamLen > 32) { + char ape[32]; + if (onSeek(pUserData, streamEndOffset - 32, MA_DR_MP3_SEEK_END)) { + if (onRead(pUserData, ape, 32) == 32 && ape[0] == 'A' && ape[1] == 'P' && ape[2] == 'E' && ape[3] == 'T' && ape[4] == 'A' && ape[5] == 'G' && ape[6] == 'E' && ape[7] == 'X') { + ma_uint32 tagSize = + ((ma_uint32)ape[24] << 0) | + ((ma_uint32)ape[25] << 8) | + ((ma_uint32)ape[26] << 16) | + ((ma_uint32)ape[27] << 24); + streamEndOffset -= 32 + tagSize; + streamLen -= 32 + tagSize; + if (onMeta != NULL) { + if (onSeek(pUserData, streamEndOffset, MA_DR_MP3_SEEK_END)) { + size_t apeTagSize = (size_t)tagSize + 32; + ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(apeTagSize, pAllocationCallbacks); + if (pTagData != NULL) { + if (onRead(pUserData, pTagData, apeTagSize) == apeTagSize) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_APE, pTagData, apeTagSize); + } + ma_dr_mp3_free(pTagData, pAllocationCallbacks); + } + } + } + } + } + } else { + } + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + pMP3->streamLength = (ma_uint64)streamLen; + if (pMP3->memory.pData != NULL) { + pMP3->memory.dataSize = (size_t)pMP3->streamLength; + } + } else { + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + } + } else { + } + } else { + } + #endif + #if 1 + { + char header[10]; + if (onRead(pUserData, header, 10) == 10) { + if (header[0] == 'I' && header[1] == 'D' && header[2] == '3') { + ma_uint32 tagSize = + (((ma_uint32)header[6] & 0x7F) << 21) | + (((ma_uint32)header[7] & 0x7F) << 14) | + (((ma_uint32)header[8] & 0x7F) << 7) | + (((ma_uint32)header[9] & 0x7F) << 0); + if (header[5] & 0x10) { + tagSize += 10; + } + if (onMeta != NULL) { + size_t tagSizeWithHeader = 10 + tagSize; + ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(tagSizeWithHeader, pAllocationCallbacks); + if (pTagData != NULL) { + MA_DR_MP3_COPY_MEMORY(pTagData, header, 10); + if (onRead(pUserData, pTagData + 10, tagSize) == tagSize) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_ID3V2, pTagData, tagSizeWithHeader); + } + ma_dr_mp3_free(pTagData, pAllocationCallbacks); + } + } else { + if (onSeek != NULL) { + if (!onSeek(pUserData, tagSize, MA_DR_MP3_SEEK_CUR)) { + return MA_FALSE; + } + } else { + char discard[1024]; + while (tagSize > 0) { + size_t bytesToRead = tagSize; + if (bytesToRead > sizeof(discard)) { + bytesToRead = sizeof(discard); + } + if (onRead(pUserData, discard, bytesToRead) != bytesToRead) { + return MA_FALSE; + } + tagSize -= (ma_uint32)bytesToRead; + } + } + } + pMP3->streamStartOffset += 10 + tagSize; + pMP3->streamCursor = pMP3->streamStartOffset; + } else { + if (onSeek != NULL) { + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + } else { + } + } + } else { + return MA_FALSE; + } + } + #endif + firstFramePCMFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames, &firstFrameInfo, &pFirstFrameData); + if (firstFramePCMFrameCount > 0) { + MA_DR_MP3_ASSERT(pFirstFrameData != NULL); + #if 1 + MA_DR_MP3_ASSERT(firstFrameInfo.frame_bytes > 0); + { + ma_dr_mp3_bs bs; + ma_dr_mp3_L3_gr_info grInfo[4]; + const ma_uint8* pTagData = pFirstFrameData; + ma_dr_mp3_bs_init(&bs, pFirstFrameData + MA_DR_MP3_HDR_SIZE, firstFrameInfo.frame_bytes - MA_DR_MP3_HDR_SIZE); + if (MA_DR_MP3_HDR_IS_CRC(pFirstFrameData)) { + ma_dr_mp3_bs_get_bits(&bs, 16); + } + if (ma_dr_mp3_L3_read_side_info(&bs, grInfo, pFirstFrameData) >= 0) { + ma_bool32 isXing = MA_FALSE; + ma_bool32 isInfo = MA_FALSE; + const ma_uint8* pTagDataBeg; + pTagDataBeg = pFirstFrameData + MA_DR_MP3_HDR_SIZE + (bs.pos/8); + pTagData = pTagDataBeg; + isXing = (pTagData[0] == 'X' && pTagData[1] == 'i' && pTagData[2] == 'n' && pTagData[3] == 'g'); + isInfo = (pTagData[0] == 'I' && pTagData[1] == 'n' && pTagData[2] == 'f' && pTagData[3] == 'o'); + if (isXing || isInfo) { + ma_uint32 bytes = 0; + ma_uint32 flags = pTagData[7]; + pTagData += 8; + if (flags & 0x01) { + detectedMP3FrameCount = (ma_uint32)pTagData[0] << 24 | (ma_uint32)pTagData[1] << 16 | (ma_uint32)pTagData[2] << 8 | (ma_uint32)pTagData[3]; + pTagData += 4; + } + if (flags & 0x02) { + bytes = (ma_uint32)pTagData[0] << 24 | (ma_uint32)pTagData[1] << 16 | (ma_uint32)pTagData[2] << 8 | (ma_uint32)pTagData[3]; + (void)bytes; + pTagData += 4; + } + if (flags & 0x04) { + pTagData += 100; + } + if (flags & 0x08) { + pTagData += 4; + } + if (pTagData[0]) { + pTagData += 21; + if (pTagData - pFirstFrameData + 14 < firstFrameInfo.frame_bytes) { + int delayInPCMFrames; + int paddingInPCMFrames; + delayInPCMFrames = (( (ma_uint32)pTagData[0] << 4) | ((ma_uint32)pTagData[1] >> 4)) + (528 + 1); + paddingInPCMFrames = ((((ma_uint32)pTagData[1] & 0xF) << 8) | ((ma_uint32)pTagData[2] )) - (528 + 1); + if (paddingInPCMFrames < 0) { + paddingInPCMFrames = 0; + } + pMP3->delayInPCMFrames = (ma_uint32)delayInPCMFrames; + pMP3->paddingInPCMFrames = (ma_uint32)paddingInPCMFrames; + } + } + if (isXing) { + pMP3->isVBR = MA_TRUE; + } else if (isInfo) { + pMP3->isCBR = MA_TRUE; + } + if (onMeta != NULL) { + ma_dr_mp3_metadata_type metadataType = isXing ? MA_DR_MP3_METADATA_TYPE_XING : MA_DR_MP3_METADATA_TYPE_VBRI; + size_t tagDataSize; + tagDataSize = (size_t)firstFrameInfo.frame_bytes; + tagDataSize -= (size_t)(pTagDataBeg - pFirstFrameData); + ma_dr_mp3__on_meta(pMP3, metadataType, pTagDataBeg, tagDataSize); + } + pMP3->pcmFramesRemainingInMP3Frame = 0; + pMP3->streamStartOffset += (ma_uint32)(firstFrameInfo.frame_bytes); + pMP3->streamCursor = pMP3->streamStartOffset; + ma_dr_mp3dec_init(&pMP3->decoder); + } + } else { + } + } + #endif + } else { ma_dr_mp3__free_from_callbacks(pMP3->pData, &pMP3->allocationCallbacks); return MA_FALSE; } + if (detectedMP3FrameCount != 0xFFFFFFFF) { + pMP3->totalPCMFrameCount = detectedMP3FrameCount * firstFramePCMFrameCount; + } pMP3->channels = pMP3->mp3FrameChannels; pMP3->sampleRate = pMP3->mp3FrameSampleRate; return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { if (pMP3 == NULL || onRead == NULL) { return MA_FALSE; } MA_DR_MP3_ZERO_OBJECT(pMP3); - return ma_dr_mp3_init_internal(pMP3, onRead, onSeek, pUserData, pAllocationCallbacks); + return ma_dr_mp3_init_internal(pMP3, onRead, onSeek, onTell, onMeta, pUserData, pUserData, pAllocationCallbacks); } static size_t ma_dr_mp3__on_read_memory(void* pUserData, void* pBufferOut, size_t bytesToRead) { @@ -92716,29 +94890,40 @@ static size_t ma_dr_mp3__on_read_memory(void* pUserData, void* pBufferOut, size_ static ma_bool32 ma_dr_mp3__on_seek_memory(void* pUserData, int byteOffset, ma_dr_mp3_seek_origin origin) { ma_dr_mp3* pMP3 = (ma_dr_mp3*)pUserData; + ma_int64 newCursor; MA_DR_MP3_ASSERT(pMP3 != NULL); - if (origin == ma_dr_mp3_seek_origin_current) { - if (byteOffset > 0) { - if (pMP3->memory.currentReadPos + byteOffset > pMP3->memory.dataSize) { - byteOffset = (int)(pMP3->memory.dataSize - pMP3->memory.currentReadPos); - } - } else { - if (pMP3->memory.currentReadPos < (size_t)-byteOffset) { - byteOffset = -(int)pMP3->memory.currentReadPos; - } - } - pMP3->memory.currentReadPos += byteOffset; + newCursor = pMP3->memory.currentReadPos; + if (origin == MA_DR_MP3_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_MP3_SEEK_CUR) { + newCursor = (ma_int64)pMP3->memory.currentReadPos; + } else if (origin == MA_DR_MP3_SEEK_END) { + newCursor = (ma_int64)pMP3->memory.dataSize; } else { - if ((ma_uint32)byteOffset <= pMP3->memory.dataSize) { - pMP3->memory.currentReadPos = byteOffset; - } else { - pMP3->memory.currentReadPos = pMP3->memory.dataSize; - } + MA_DR_MP3_ASSERT(!"Invalid seek origin"); + return MA_FALSE; } + newCursor += byteOffset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pMP3->memory.dataSize) { + return MA_FALSE; + } + pMP3->memory.currentReadPos = (size_t)newCursor; return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3__on_tell_memory(void* pUserData, ma_int64* pCursor) { + ma_dr_mp3* pMP3 = (ma_dr_mp3*)pUserData; + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)pMP3->memory.currentReadPos; + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_memory_with_metadata(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) +{ + ma_bool32 result; if (pMP3 == NULL) { return MA_FALSE; } @@ -92749,7 +94934,21 @@ MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_ pMP3->memory.pData = (const ma_uint8*)pData; pMP3->memory.dataSize = dataSize; pMP3->memory.currentReadPos = 0; - return ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_memory, ma_dr_mp3__on_seek_memory, pMP3, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_memory, ma_dr_mp3__on_seek_memory, ma_dr_mp3__on_tell_memory, onMeta, pMP3, pUserDataMeta, pAllocationCallbacks); + if (result == MA_FALSE) { + return MA_FALSE; + } + if (pMP3->streamLength <= (ma_uint64)MA_SIZE_MAX) { + pMP3->memory.dataSize = (size_t)pMP3->streamLength; + } + if (pMP3->streamStartOffset > (ma_uint64)MA_SIZE_MAX) { + return MA_FALSE; + } + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_memory_with_metadata(pMP3, pData, dataSize, NULL, NULL, pAllocationCallbacks); } #ifndef MA_DR_MP3_NO_STDIO #include @@ -92760,36 +94959,76 @@ static size_t ma_dr_mp3__on_read_stdio(void* pUserData, void* pBufferOut, size_t } static ma_bool32 ma_dr_mp3__on_seek_stdio(void* pUserData, int offset, ma_dr_mp3_seek_origin origin) { - return fseek((FILE*)pUserData, offset, (origin == ma_dr_mp3_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_MP3_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_MP3_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; } -MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_MP3_ASSERT(pFileStdio != NULL); + MA_DR_MP3_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata(ma_dr_mp3* pMP3, const char* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; FILE* pFile; + if (pMP3 == NULL) { + return MA_FALSE; + } + MA_DR_MP3_ZERO_OBJECT(pMP3); if (ma_fopen(&pFile, pFilePath, "rb") != MA_SUCCESS) { return MA_FALSE; } - result = ma_dr_mp3_init(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, ma_dr_mp3__on_tell_stdio, onMeta, (void*)pFile, pUserDataMeta, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; } return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; FILE* pFile; + if (pMP3 == NULL) { + return MA_FALSE; + } + MA_DR_MP3_ZERO_OBJECT(pMP3); if (ma_wfopen(&pFile, pFilePath, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return MA_FALSE; } - result = ma_dr_mp3_init(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, ma_dr_mp3__on_tell_stdio, onMeta, (void*)pFile, pUserDataMeta, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; } return MA_TRUE; } +MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_file_with_metadata(pMP3, pFilePath, NULL, NULL, pAllocationCallbacks); +} +MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_file_with_metadata_w(pMP3, pFilePath, NULL, NULL, pAllocationCallbacks); +} #endif MA_API void ma_dr_mp3_uninit(ma_dr_mp3* pMP3) { @@ -92859,17 +95098,38 @@ static ma_uint64 ma_dr_mp3_read_pcm_frames_raw(ma_dr_mp3* pMP3, ma_uint64 frames MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(pMP3->onRead != NULL); while (framesToRead > 0) { - ma_uint32 framesToConsume = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, framesToRead); + ma_uint32 framesToConsume; + if (pMP3->currentPCMFrame < pMP3->delayInPCMFrames) { + ma_uint32 framesToSkip = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, pMP3->delayInPCMFrames - pMP3->currentPCMFrame); + pMP3->currentPCMFrame += framesToSkip; + pMP3->pcmFramesConsumedInMP3Frame += framesToSkip; + pMP3->pcmFramesRemainingInMP3Frame -= framesToSkip; + } + framesToConsume = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, framesToRead); + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX && pMP3->totalPCMFrameCount > pMP3->paddingInPCMFrames) { + if (pMP3->currentPCMFrame < (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames)) { + ma_uint64 framesRemainigToPadding = (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames) - pMP3->currentPCMFrame; + if (framesToConsume > framesRemainigToPadding) { + framesToConsume = (ma_uint32)framesRemainigToPadding; + } + } else { + break; + } + } if (pBufferOut != NULL) { - #if defined(MA_DR_MP3_FLOAT_OUTPUT) - float* pFramesOutF32 = (float*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(float) * totalFramesRead * pMP3->channels); - float* pFramesInF32 = (float*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(float) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); - MA_DR_MP3_COPY_MEMORY(pFramesOutF32, pFramesInF32, sizeof(float) * framesToConsume * pMP3->channels); - #else - ma_int16* pFramesOutS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(ma_int16) * totalFramesRead * pMP3->channels); - ma_int16* pFramesInS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(ma_int16) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); - MA_DR_MP3_COPY_MEMORY(pFramesOutS16, pFramesInS16, sizeof(ma_int16) * framesToConsume * pMP3->channels); - #endif + #if defined(MA_DR_MP3_FLOAT_OUTPUT) + { + float* pFramesOutF32 = (float*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(float) * totalFramesRead * pMP3->channels); + float* pFramesInF32 = (float*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(float) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); + MA_DR_MP3_COPY_MEMORY(pFramesOutF32, pFramesInF32, sizeof(float) * framesToConsume * pMP3->channels); + } + #else + { + ma_int16* pFramesOutS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(ma_int16) * totalFramesRead * pMP3->channels); + ma_int16* pFramesInS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(ma_int16) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); + MA_DR_MP3_COPY_MEMORY(pFramesOutS16, pFramesInS16, sizeof(ma_int16) * framesToConsume * pMP3->channels); + } + #endif } pMP3->currentPCMFrame += framesToConsume; pMP3->pcmFramesConsumedInMP3Frame += framesToConsume; @@ -92879,6 +95139,9 @@ static ma_uint64 ma_dr_mp3_read_pcm_frames_raw(ma_dr_mp3* pMP3, ma_uint64 frames if (framesToRead == 0) { break; } + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX && pMP3->totalPCMFrameCount > pMP3->paddingInPCMFrames && pMP3->currentPCMFrame >= (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames)) { + break; + } MA_DR_MP3_ASSERT(pMP3->pcmFramesRemainingInMP3Frame == 0); if (ma_dr_mp3_decode_next_frame(pMP3) == 0) { break; @@ -92958,7 +95221,7 @@ static ma_bool32 ma_dr_mp3_seek_to_start_of_stream(ma_dr_mp3* pMP3) { MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(pMP3->onSeek != NULL); - if (!ma_dr_mp3__on_seek(pMP3, 0, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek_64(pMP3, pMP3->streamStartOffset, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } ma_dr_mp3_reset(pMP3); @@ -93024,7 +95287,7 @@ static ma_bool32 ma_dr_mp3_seek_to_pcm_frame__seek_table(ma_dr_mp3* pMP3, ma_uin seekPoint.mp3FramesToDiscard = 0; seekPoint.pcmFramesToDiscard = 0; } - if (!ma_dr_mp3__on_seek_64(pMP3, seekPoint.seekPosInBytes, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek_64(pMP3, seekPoint.seekPosInBytes, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } ma_dr_mp3_reset(pMP3); @@ -93035,7 +95298,7 @@ static ma_bool32 ma_dr_mp3_seek_to_pcm_frame__seek_table(ma_dr_mp3* pMP3, ma_uin if (iMP3Frame == seekPoint.mp3FramesToDiscard-1) { pPCMFrames = (ma_dr_mp3d_sample_t*)pMP3->pcmFrames; } - pcmFramesRead = ma_dr_mp3_decode_next_frame_ex(pMP3, pPCMFrames); + pcmFramesRead = ma_dr_mp3_decode_next_frame_ex(pMP3, pPCMFrames, NULL, NULL); if (pcmFramesRead == 0) { return MA_FALSE; } @@ -93077,7 +95340,7 @@ MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint6 totalMP3FrameCount = 0; for (;;) { ma_uint32 pcmFramesInCurrentMP3Frame; - pcmFramesInCurrentMP3Frame = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3Frame = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3Frame == 0) { break; } @@ -93101,10 +95364,26 @@ MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint6 MA_API ma_uint64 ma_dr_mp3_get_pcm_frame_count(ma_dr_mp3* pMP3) { ma_uint64 totalPCMFrameCount; - if (!ma_dr_mp3_get_mp3_and_pcm_frame_count(pMP3, NULL, &totalPCMFrameCount)) { + if (pMP3 == NULL) { return 0; } - return totalPCMFrameCount; + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX) { + totalPCMFrameCount = pMP3->totalPCMFrameCount; + if (totalPCMFrameCount >= pMP3->delayInPCMFrames) { + totalPCMFrameCount -= pMP3->delayInPCMFrames; + } else { + } + if (totalPCMFrameCount >= pMP3->paddingInPCMFrames) { + totalPCMFrameCount -= pMP3->paddingInPCMFrames; + } else { + } + return totalPCMFrameCount; + } else { + if (!ma_dr_mp3_get_mp3_and_pcm_frame_count(pMP3, NULL, &totalPCMFrameCount)) { + return 0; + } + return totalPCMFrameCount; + } } MA_API ma_uint64 ma_dr_mp3_get_mp3_frame_count(ma_dr_mp3* pMP3) { @@ -93174,7 +95453,7 @@ MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSe MA_DR_MP3_ASSERT(pMP3->streamCursor >= pMP3->dataSize); mp3FrameInfo[iMP3Frame].bytePos = pMP3->streamCursor - pMP3->dataSize; mp3FrameInfo[iMP3Frame].pcmFrameIndex = runningPCMFrameCount; - pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3FrameIn == 0) { return MA_FALSE; } @@ -93198,7 +95477,7 @@ MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSe } mp3FrameInfo[MA_DR_MP3_COUNTOF(mp3FrameInfo)-1].bytePos = pMP3->streamCursor - pMP3->dataSize; mp3FrameInfo[MA_DR_MP3_COUNTOF(mp3FrameInfo)-1].pcmFrameIndex = runningPCMFrameCount; - pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3FrameIn == 0) { pSeekPoints[iSeekPoint].seekPosInBytes = mp3FrameInfo[0].bytePos; pSeekPoints[iSeekPoint].pcmFrameIndex = nextTargetPCMFrame; @@ -93336,18 +95615,18 @@ static ma_int16* ma_dr_mp3__full_read_and_close_s16(ma_dr_mp3* pMP3, ma_dr_mp3_c } return pFrames; } -MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_mp3 mp3; - if (!ma_dr_mp3_init(&mp3, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_mp3_init(&mp3, onRead, onSeek, onTell, NULL, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_mp3__full_read_and_close_f32(&mp3, pConfig, pTotalFrameCount); } -MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_mp3 mp3; - if (!ma_dr_mp3_init(&mp3, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_mp3_init(&mp3, onRead, onSeek, onTell, NULL, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_mp3__full_read_and_close_s16(&mp3, pConfig, pTotalFrameCount); diff --git a/llama/llama.go b/llama/llama.go index 88672a03..c995b3ea 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -42,6 +42,7 @@ import ( _ "github.com/ollama/ollama/llama/llama.cpp/common" _ "github.com/ollama/ollama/llama/llama.cpp/src" _ "github.com/ollama/ollama/llama/llama.cpp/tools/mtmd" + "github.com/ollama/ollama/ml" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" ) @@ -62,16 +63,21 @@ func BackendInit() { C.llama_backend_init() } -func EnumerateGPUs() []string { - var ids []string +func EnumerateGPUs() []ml.DeviceID { + var ids []ml.DeviceID for i := range C.ggml_backend_dev_count() { device := C.ggml_backend_dev_get(i) - if C.ggml_backend_dev_type(device) == C.GGML_BACKEND_DEVICE_TYPE_GPU { + switch C.ggml_backend_dev_type(device) { + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: var props C.struct_ggml_backend_dev_props C.ggml_backend_dev_get_props(device, &props) - ids = append(ids, C.GoString(props.id)) + ids = append(ids, ml.DeviceID{ + ID: C.GoString(props.id), + Library: C.GoString(props.library), + }) } } @@ -112,7 +118,11 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla params.n_threads = C.int(threads) params.n_threads_batch = params.n_threads params.embeddings = C.bool(true) - params.flash_attn = C.bool(flashAttention) + if flashAttention { + params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED + } else { + params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED + } params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType)) params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType)) @@ -496,7 +506,12 @@ func (c *MtmdContext) Free() { C.mtmd_free(c.c) } -func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) { +type MtmdChunk struct { + Embed []float32 + Tokens []int +} + +func (c *MtmdContext) MultimodalTokenize(llamaContext *Context, data []byte) ([]MtmdChunk, error) { // Initialize the input chunks pointer ic := C.mtmd_input_chunks_init() defer C.mtmd_input_chunks_free(ic) @@ -515,35 +530,51 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, } nChunks := C.mtmd_input_chunks_size(ic) numEmbed := llamaContext.Model().NEmbd() - embed := make([][]float32, 0) + outChunks := make([]MtmdChunk, 0) for i := range int(nChunks) { chunk := C.mtmd_input_chunks_get(ic, C.size_t(i)) numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk)) slog.Debug("chunk tokens", "index", i, "numTokens", numTokens) - // Encode the chunk - if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { - return nil, errors.New("unable to encode mtmd image chunk") - } + if C.mtmd_input_chunk_get_type(chunk) == C.MTMD_INPUT_CHUNK_TYPE_TEXT { + // If this is a text chunk, add the tokens + cNumTokens := C.size_t(0) + cTokens := C.mtmd_input_chunk_get_tokens_text(chunk, &cNumTokens) + cTokensArr := unsafe.Slice(cTokens, int(cNumTokens)) + tokens := make([]int, int(cNumTokens)) + for j := range int(cNumTokens) { + tokens[j] = int(cTokensArr[j]) + } + outChunks = append(outChunks, MtmdChunk{Tokens: tokens}) + } else { + // Otherwise, encode the image chunk to embeddings - // Get the embeddings for this chunk - chunkEmbed := make([][]float32, numTokens) - chunkEmbd := C.mtmd_get_output_embd(c.c) - if nil == chunkEmbd { - continue - } + // Encode the chunk + if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { + return nil, errors.New("unable to encode mtmd image chunk") + } - // Extend the embedding array for each token - s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed) - rows := make([]float32, len(s)) - copy(rows, s) - for i := range numTokens { - chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed] + // Get the embeddings for this chunk + chunkEmbed := make([][]float32, numTokens) + chunkEmbd := C.mtmd_get_output_embd(c.c) + if nil == chunkEmbd { + return nil, errors.New("no mtmd image embedding") + } + + // Extend the embedding array for each token + s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed) + rows := make([]float32, len(s)) + copy(rows, s) + for i := range numTokens { + chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed] + } + for _, e := range chunkEmbed { + outChunks = append(outChunks, MtmdChunk{Embed: e}) + } } - embed = append(embed, chunkEmbed...) } - slog.Debug("image embeddings", "totalEmbeddings", len(embed)) - return embed, nil + slog.Debug("image tokenization chunks", "totalChunks", len(outChunks)) + return outChunks, nil } func (c *Context) Synchronize() { diff --git a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch index d62331d0..bcd60fb6 100644 --- a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch +++ b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch @@ -15,18 +15,18 @@ problem. ggml/src/ggml-backend.cpp | 9 +++++++-- ggml/src/ggml-cann/ggml-cann.cpp | 2 ++ ggml/src/ggml-cuda/ggml-cuda.cu | 3 +++ - ggml/src/ggml-metal/ggml-metal.m | 1 + + ggml/src/ggml-metal/ggml-metal.cpp | 2 ++ ggml/src/ggml-opencl/ggml-opencl.cpp | 1 + ggml/src/ggml-rpc/ggml-rpc.cpp | 1 + ggml/src/ggml-sycl/ggml-sycl.cpp | 3 +++ ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 ++ - 8 files changed, 20 insertions(+), 2 deletions(-) + 8 files changed, 21 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 1b9d29e9..97f47abd 100644 +index ff9135fe..8ba86f82 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -107,7 +107,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { +@@ -113,7 +113,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { if (buffer->iface.free_buffer != NULL) { buffer->iface.free_buffer(buffer); } @@ -34,7 +34,7 @@ index 1b9d29e9..97f47abd 100644 } size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { -@@ -529,6 +528,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) +@@ -586,6 +585,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) free(ctx->buffers); free(ctx); @@ -42,9 +42,9 @@ index 1b9d29e9..97f47abd 100644 } static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { -@@ -1890,6 +1890,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { - +@@ -2075,6 +2075,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); ggml_aligned_free(buffer->context, buffer->size); + delete buffer; +} @@ -54,7 +54,7 @@ index 1b9d29e9..97f47abd 100644 } static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { -@@ -1937,7 +1942,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { +@@ -2127,7 +2132,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { }; static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { @@ -64,10 +64,10 @@ index 1b9d29e9..97f47abd 100644 /* .init_tensor = */ NULL, // no initialization required /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp -index cf575b36..ca1addfa 100755 +index ad1adba6..7d44f74f 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp -@@ -826,6 +826,7 @@ static void ggml_backend_cann_buffer_free_buffer( +@@ -843,6 +843,7 @@ static void ggml_backend_cann_buffer_free_buffer( ggml_backend_cann_buffer_context* ctx = (ggml_backend_cann_buffer_context*)buffer->context; delete ctx; @@ -75,7 +75,7 @@ index cf575b36..ca1addfa 100755 } /** -@@ -1572,6 +1573,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf +@@ -1630,6 +1631,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf */ static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) { ACL_CHECK(aclrtFreeHost(buffer->context)); @@ -84,7 +84,7 @@ index cf575b36..ca1addfa 100755 /** diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index d9110491..37ee2a6d 100644 +index 856e9de2..c0b1e4c1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -567,6 +567,7 @@ struct ggml_backend_cuda_buffer_context { @@ -111,23 +111,31 @@ index d9110491..37ee2a6d 100644 } static void * ggml_cuda_host_malloc(size_t size) { -diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m -index cb8eff4a..7bccc7bf 100644 ---- a/ggml/src/ggml-metal/ggml-metal.m -+++ b/ggml/src/ggml-metal/ggml-metal.m -@@ -6032,6 +6032,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) - } +diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp +index 7afc881f..bf096227 100644 +--- a/ggml/src/ggml-metal/ggml-metal.cpp ++++ b/ggml/src/ggml-metal/ggml-metal.cpp +@@ -25,6 +25,7 @@ static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t b + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); - free(ctx); -+ free(buffer); + ggml_metal_buffer_free(ctx); ++ delete buffer; } - static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { + static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) { +@@ -99,6 +100,7 @@ static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_free(ctx); ++ delete buffer; + } + + static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp -index 8ba1e00d..8163e8dc 100644 +index 79d21487..38c75018 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp -@@ -2745,6 +2745,7 @@ struct ggml_backend_opencl_buffer_context { +@@ -3212,6 +3212,7 @@ struct ggml_backend_opencl_buffer_context { static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; delete ctx; @@ -136,10 +144,10 @@ index 8ba1e00d..8163e8dc 100644 static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp -index df6ba540..2e395968 100644 +index aad48d62..a46c0f52 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp -@@ -486,6 +486,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { +@@ -528,6 +528,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0); RPC_STATUS_ASSERT(status); delete ctx; @@ -148,10 +156,10 @@ index df6ba540..2e395968 100644 static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp -index 3992dad0..67503951 100644 +index 45b8c216..4ec9a592 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp -@@ -331,6 +331,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { +@@ -334,6 +334,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { ggml_sycl_set_device(ctx->device); delete ctx; @@ -159,7 +167,7 @@ index 3992dad0..67503951 100644 } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ -@@ -792,6 +793,7 @@ struct ggml_backend_sycl_split_buffer_context { +@@ -795,6 +796,7 @@ struct ggml_backend_sycl_split_buffer_context { static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; delete ctx; @@ -167,7 +175,7 @@ index 3992dad0..67503951 100644 } static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) { -@@ -1134,6 +1136,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ +@@ -1137,6 +1139,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_sycl_host_free(buffer->context); @@ -176,10 +184,10 @@ index 3992dad0..67503951 100644 static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 4070e248..394a2839 100644 +index 3cd89c71..ed83236f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -10209,6 +10209,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { +@@ -11600,6 +11600,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ggml_vk_destroy_buffer(ctx->dev_buffer); delete ctx; @@ -187,7 +195,7 @@ index 4070e248..394a2839 100644 } static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { -@@ -10352,6 +10353,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe +@@ -11743,6 +11744,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); diff --git a/llama/patches/0002-pretokenizer.patch b/llama/patches/0002-pretokenizer.patch index 15dcbc6c..aacb1566 100644 --- a/llama/patches/0002-pretokenizer.patch +++ b/llama/patches/0002-pretokenizer.patch @@ -10,10 +10,10 @@ logs instead of throwing an error 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index f7e03e70..8ebe11cf 100644 +index 7fffd171..0b6edaf4 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp -@@ -1804,16 +1804,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { +@@ -1812,16 +1812,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (type == LLAMA_VOCAB_TYPE_BPE) { add_space_prefix = false; clean_spaces = true; @@ -31,8 +31,8 @@ index f7e03e70..8ebe11cf 100644 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( tokenizer_pre == "llama3" || -@@ -1975,7 +1966,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { - pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; +@@ -1992,7 +1983,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { + pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2; clean_spaces = false; } else { - throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); diff --git a/llama/patches/0003-clip-unicode.patch b/llama/patches/0003-clip-unicode.patch index 548a0da3..3ba3742b 100644 --- a/llama/patches/0003-clip-unicode.patch +++ b/llama/patches/0003-clip-unicode.patch @@ -10,7 +10,7 @@ filesystems for paths that include wide characters 1 file changed, 39 insertions(+) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp -index 20c21733..f4f69cfc 100644 +index 98e68af2..6699b75a 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -28,6 +28,19 @@ @@ -33,7 +33,7 @@ index 20c21733..f4f69cfc 100644 struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; enum ffn_op_type { -@@ -2597,7 +2610,29 @@ struct clip_model_loader { +@@ -2762,7 +2775,29 @@ struct clip_model_loader { { std::vector read_buf; @@ -63,7 +63,7 @@ index 20c21733..f4f69cfc 100644 if (!fin) { throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str())); } -@@ -2624,7 +2659,11 @@ struct clip_model_loader { +@@ -2789,7 +2824,11 @@ struct clip_model_loader { ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); } } diff --git a/llama/patches/0004-solar-pro.patch b/llama/patches/0004-solar-pro.patch index e2ece004..631cba2a 100644 --- a/llama/patches/0004-solar-pro.patch +++ b/llama/patches/0004-solar-pro.patch @@ -9,16 +9,16 @@ adds support for the Solar Pro architecture src/llama-arch.h | 3 + src/llama-hparams.cpp | 8 ++ src/llama-hparams.h | 5 + - src/llama-model-loader.cpp | 1 + + src/llama-model-loader.cpp | 2 +- src/llama-model.cpp | 207 +++++++++++++++++++++++++++++++++++++ src/llama-model.h | 3 + - 7 files changed, 248 insertions(+) + 7 files changed, 248 insertions(+), 1 deletion(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp -index 18dcc6dd..4b285646 100644 +index 869e4dcc..9f6b6ad2 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp -@@ -78,6 +78,7 @@ static const std::map LLM_ARCH_NAMES = { +@@ -81,6 +81,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_GRANITE_HYBRID, "granitehybrid" }, { LLM_ARCH_CHAMELEON, "chameleon" }, @@ -26,15 +26,15 @@ index 18dcc6dd..4b285646 100644 { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, -@@ -164,6 +165,7 @@ static const std::map LLM_KV_NAMES = { - { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, - { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, +@@ -179,6 +180,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, + { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, + { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, -@@ -1794,6 +1796,24 @@ static const std::map> LLM_TENSOR_N +@@ -1893,6 +1895,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, @@ -59,7 +59,7 @@ index 18dcc6dd..4b285646 100644 { LLM_ARCH_WAVTOKENIZER_DEC, { -@@ -2219,6 +2239,7 @@ static const std::map LLM_TENSOR_INFOS = { +@@ -2429,6 +2449,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, @@ -68,10 +68,10 @@ index 18dcc6dd..4b285646 100644 {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h -index 7af587e7..3ea994c7 100644 +index c3ae7165..dc7a362a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h -@@ -82,6 +82,7 @@ enum llm_arch { +@@ -85,6 +85,7 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_GRANITE_HYBRID, LLM_ARCH_CHAMELEON, @@ -79,15 +79,15 @@ index 7af587e7..3ea994c7 100644 LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, LLM_ARCH_BAILINGMOE, -@@ -168,6 +169,7 @@ enum llm_kv { - LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, - LLM_KV_ATTENTION_SLIDING_WINDOW, +@@ -183,6 +184,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, + LLM_KV_ATTENTION_OUTPUT_SCALE, + LLM_KV_ATTENTION_TEMPERATURE_LENGTH, + LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, -@@ -394,6 +396,7 @@ enum llm_tensor { +@@ -432,6 +434,7 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, @@ -96,10 +96,10 @@ index 7af587e7..3ea994c7 100644 LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_NORM, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp -index 7a06368d..35fc054f 100644 +index db65d69e..b6bf6bbf 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp -@@ -146,6 +146,14 @@ uint32_t llama_hparams::n_pos_per_embd() const { +@@ -151,6 +151,14 @@ uint32_t llama_hparams::n_pos_per_embd() const { return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1; } @@ -115,10 +115,10 @@ index 7a06368d..35fc054f 100644 if (il < n_layer) { return swa_layers[il]; diff --git a/src/llama-hparams.h b/src/llama-hparams.h -index bd231224..29bd9056 100644 +index 4e7f73ec..80582728 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h -@@ -62,6 +62,8 @@ struct llama_hparams { +@@ -64,6 +64,8 @@ struct llama_hparams { std::array n_head_kv_arr; std::array n_ff_arr; @@ -127,7 +127,7 @@ index bd231224..29bd9056 100644 uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; -@@ -220,6 +222,9 @@ struct llama_hparams { +@@ -248,6 +250,9 @@ struct llama_hparams { uint32_t n_pos_per_embd() const; @@ -135,25 +135,26 @@ index bd231224..29bd9056 100644 + bool n_bskcn(uint32_t n, uint32_t il) const; + bool is_swa(uint32_t il) const; - }; + bool has_kv(uint32_t il) const; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp -index f71c40f8..7eab9b68 100644 +index aa3a65f8..ee303bd5 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp -@@ -465,6 +465,7 @@ namespace GGUFMeta { - // TODO: this is not very clever - figure out something better +@@ -466,7 +466,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); +- + template bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/src/llama-model.cpp b/src/llama-model.cpp -index 58ca7df7..280129e1 100644 +index 36d495d6..74e1d162 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp -@@ -1706,6 +1706,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { +@@ -1865,6 +1865,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; @@ -175,7 +176,7 @@ index 58ca7df7..280129e1 100644 case LLM_ARCH_WAVTOKENIZER_DEC: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -@@ -4793,6 +4808,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { +@@ -5170,6 +5185,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -210,7 +211,7 @@ index 58ca7df7..280129e1 100644 layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); -@@ -15495,6 +15538,165 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { +@@ -16392,6 +16435,165 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } }; @@ -229,7 +230,7 @@ index 58ca7df7..280129e1 100644 + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) -+ auto * inp_attn = build_attn_inp_kv_unified(); ++ auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + @@ -316,7 +317,7 @@ index 58ca7df7..280129e1 100644 + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, -+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); ++ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + @@ -376,7 +377,7 @@ index 58ca7df7..280129e1 100644 // ref: https://github.com/facebookresearch/chameleon // based on the original build_llama() function, changes: // * qk-norm -@@ -18439,6 +18641,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { +@@ -19827,6 +20029,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; @@ -387,7 +388,7 @@ index 58ca7df7..280129e1 100644 case LLM_ARCH_WAVTOKENIZER_DEC: { llm = std::make_unique(*this, params); -@@ -18652,6 +18858,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { +@@ -20057,6 +20263,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_CHAMELEON: @@ -396,10 +397,10 @@ index 58ca7df7..280129e1 100644 case LLM_ARCH_NEO_BERT: case LLM_ARCH_SMOLLM3: diff --git a/src/llama-model.h b/src/llama-model.h -index 6fcd74d5..09964533 100644 +index 7f48662f..ec3fbd33 100644 --- a/src/llama-model.h +++ b/src/llama-model.h -@@ -70,6 +70,7 @@ enum llm_type { +@@ -76,6 +76,7 @@ enum llm_type { LLM_TYPE_15B, LLM_TYPE_16B, LLM_TYPE_20B, @@ -407,9 +408,9 @@ index 6fcd74d5..09964533 100644 LLM_TYPE_27B, LLM_TYPE_30B, LLM_TYPE_32B, -@@ -367,6 +368,8 @@ struct llama_layer { - // openai-moe - struct ggml_tensor * attn_sinks = nullptr; +@@ -387,6 +388,8 @@ struct llama_layer { + struct ggml_tensor * ffn_act_beta = nullptr; + struct ggml_tensor * ffn_act_eps = nullptr; + struct ggml_tensor * bskcn_tv = nullptr; + diff --git a/llama/patches/0005-fix-deepseek-deseret-regex.patch b/llama/patches/0005-fix-deepseek-deseret-regex.patch index 1f8b5542..127fcc37 100644 --- a/llama/patches/0005-fix-deepseek-deseret-regex.patch +++ b/llama/patches/0005-fix-deepseek-deseret-regex.patch @@ -12,7 +12,7 @@ regex 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index 8ebe11cf..c011008f 100644 +index 0b6edaf4..3de95c67 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -299,7 +299,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { diff --git a/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch b/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch index 17bd3989..a923f137 100644 --- a/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch +++ b/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch @@ -8,10 +8,10 @@ Subject: [PATCH] maintain ordering for rules for grammar 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp -index 637891f5..98b8280f 100644 +index db1f0b23..f4de7e34 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp -@@ -307,7 +307,7 @@ private: +@@ -308,7 +308,7 @@ private: friend std::string build_grammar(const std::function & cb, const common_grammar_options & options); std::function _fetch_json; bool _dotall; diff --git a/llama/patches/0007-sort-devices-by-score.patch b/llama/patches/0007-sort-devices-by-score.patch index fa3522f9..22a084e8 100644 --- a/llama/patches/0007-sort-devices-by-score.patch +++ b/llama/patches/0007-sort-devices-by-score.patch @@ -11,10 +11,10 @@ with the fastest acceleration is loaded 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index 6c315137..3040b2aa 100644 +index 136afec7..f794d9cf 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp -@@ -162,7 +162,7 @@ struct ggml_backend_reg_entry { +@@ -175,7 +175,7 @@ struct ggml_backend_reg_entry { struct ggml_backend_registry { std::vector backends; @@ -23,7 +23,7 @@ index 6c315137..3040b2aa 100644 ggml_backend_registry() { #ifdef GGML_USE_CUDA -@@ -207,7 +207,7 @@ struct ggml_backend_registry { +@@ -223,7 +223,7 @@ struct ggml_backend_registry { } } @@ -32,7 +32,7 @@ index 6c315137..3040b2aa 100644 if (!reg) { return; } -@@ -218,15 +218,20 @@ struct ggml_backend_registry { +@@ -234,15 +234,20 @@ struct ggml_backend_registry { #endif backends.push_back({ reg, std::move(handle) }); for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) { @@ -56,7 +56,7 @@ index 6c315137..3040b2aa 100644 } ggml_backend_reg_t load_backend(const fs::path & path, bool silent) { -@@ -270,7 +275,7 @@ struct ggml_backend_registry { +@@ -286,7 +291,7 @@ struct ggml_backend_registry { GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str()); @@ -65,7 +65,7 @@ index 6c315137..3040b2aa 100644 return reg; } -@@ -293,7 +298,7 @@ struct ggml_backend_registry { +@@ -309,7 +314,7 @@ struct ggml_backend_registry { // remove devices devices.erase( std::remove_if(devices.begin(), devices.end(), @@ -74,7 +74,7 @@ index 6c315137..3040b2aa 100644 devices.end()); // remove backend -@@ -351,7 +356,7 @@ size_t ggml_backend_dev_count() { +@@ -367,7 +372,7 @@ size_t ggml_backend_dev_count() { ggml_backend_dev_t ggml_backend_dev_get(size_t index) { GGML_ASSERT(index < ggml_backend_dev_count()); diff --git a/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch b/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch index aa64e1ed..43fc8a0b 100644 --- a/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch +++ b/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch @@ -8,10 +8,10 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants 1 file changed, 2 insertions(+) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index 177fb282..f5a5079a 100644 +index 892c2331..09fdf5fc 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt -@@ -304,6 +304,7 @@ function(ggml_add_cpu_backend_variant tag_name) +@@ -310,6 +310,7 @@ function(ggml_add_cpu_backend_variant tag_name) endif() ggml_add_cpu_backend_variant_impl(${tag_name}) @@ -19,7 +19,7 @@ index 177fb282..f5a5079a 100644 endfunction() ggml_add_backend(CPU) -@@ -314,6 +315,7 @@ if (GGML_CPU_ALL_VARIANTS) +@@ -320,6 +321,7 @@ if (GGML_CPU_ALL_VARIANTS) elseif (GGML_CPU_ARM_ARCH) message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS") endif() diff --git a/llama/patches/0009-remove-amx.patch b/llama/patches/0009-remove-amx.patch index cde880ce..6b0b90f3 100644 --- a/llama/patches/0009-remove-amx.patch +++ b/llama/patches/0009-remove-amx.patch @@ -9,10 +9,10 @@ disable amx as it reduces performance on some systems 1 file changed, 4 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index f5a5079a..5158acd6 100644 +index 09fdf5fc..0609c650 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt -@@ -324,10 +324,6 @@ if (GGML_CPU_ALL_VARIANTS) +@@ -330,10 +330,6 @@ if (GGML_CPU_ALL_VARIANTS) ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) diff --git a/llama/patches/0010-fix-string-arr-kv-loading.patch b/llama/patches/0010-fix-string-arr-kv-loading.patch index b6cf2e91..29a31349 100644 --- a/llama/patches/0010-fix-string-arr-kv-loading.patch +++ b/llama/patches/0010-fix-string-arr-kv-loading.patch @@ -25,7 +25,7 @@ index 79ee2020..3efb22f0 100644 // get ith C string from array with given key_id GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp -index 53504399..0f71d5f3 100644 +index 8cc4ef1c..d950dbdf 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -805,10 +805,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id @@ -53,10 +53,10 @@ index 53504399..0f71d5f3 100644 } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index c011008f..fa388b03 100644 +index 3de95c67..217ede47 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp -@@ -1760,9 +1760,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { +@@ -1768,9 +1768,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); if (precompiled_charsmap_keyidx != -1) { const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx); @@ -66,4 +66,4 @@ index c011008f..fa388b03 100644 + const size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx); const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); - #ifdef IS_BIG_ENDIAN + #if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ diff --git a/llama/patches/0011-ollama-debug-tensor.patch b/llama/patches/0011-ollama-debug-tensor.patch index 5dcd6ee0..21edb8ba 100644 --- a/llama/patches/0011-ollama-debug-tensor.patch +++ b/llama/patches/0011-ollama-debug-tensor.patch @@ -8,7 +8,7 @@ Subject: [PATCH] ollama debug tensor 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c -index d89cd8f4..a5689c18 100644 +index ba2a36d9..99509b0c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -15,6 +15,8 @@ @@ -20,7 +20,7 @@ index d89cd8f4..a5689c18 100644 #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) -@@ -2858,6 +2860,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { +@@ -2887,6 +2889,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_compute_forward(¶ms, node); diff --git a/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch b/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch index 2d373123..b4ad69cf 100644 --- a/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch +++ b/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch @@ -184,10 +184,10 @@ index f8c291de..2a3a62db 100644 const char * grammar_root, bool lazy, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp -index bfbf5fa2..11f93f42 100644 +index 55d2e355..da34526b 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp -@@ -1466,7 +1466,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { +@@ -1563,7 +1563,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); } @@ -196,7 +196,7 @@ index bfbf5fa2..11f93f42 100644 ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); -@@ -1548,7 +1548,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( +@@ -1645,7 +1645,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( /* .vocab = */ vocab, /* .grammar_str = */ grammar_str, /* .grammar_root = */ grammar_root, diff --git a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch index 7e821c1e..f87c8c38 100644 --- a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch +++ b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch @@ -4,17 +4,18 @@ Date: Thu, 1 May 2025 13:45:12 -0700 Subject: [PATCH] add argsort and cuda copy for i32 --- - ggml/src/ggml-cpu/ops.cpp | 43 +++++++++++++ - ggml/src/ggml-cuda/argsort.cu | 102 ++++++++++++++++++++++++++++++- - ggml/src/ggml-cuda/cpy-utils.cuh | 6 ++ - ggml/src/ggml-cuda/cpy.cu | 43 +++++++++++++ - 4 files changed, 192 insertions(+), 2 deletions(-) + ggml/src/ggml-cpu/ops.cpp | 43 +++++++++++ + ggml/src/ggml-cuda/argsort.cu | 102 ++++++++++++++++++++++++++- + ggml/src/ggml-cuda/cpy-utils.cuh | 6 ++ + ggml/src/ggml-cuda/cpy.cu | 43 +++++++++++ + ggml/src/ggml-metal/ggml-metal.metal | 64 +++++++++++++++++ + 5 files changed, 256 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp -index 854f1c2b..a2924757 100644 +index 1c43865f..31478dd8 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp -@@ -8146,6 +8146,45 @@ static void ggml_compute_forward_argsort_f32( +@@ -7889,6 +7889,45 @@ static void ggml_compute_forward_argsort_f32( } } @@ -60,7 +61,7 @@ index 854f1c2b..a2924757 100644 void ggml_compute_forward_argsort( const ggml_compute_params * params, ggml_tensor * dst) { -@@ -8157,6 +8196,10 @@ void ggml_compute_forward_argsort( +@@ -7900,6 +7939,10 @@ void ggml_compute_forward_argsort( { ggml_compute_forward_argsort_f32(params, dst); } break; @@ -196,12 +197,12 @@ index 607ded85..53b02634 100644 + } } diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh -index 410c12b7..b8e9e107 100644 +index e621cb98..597c0c8b 100644 --- a/ggml/src/ggml-cuda/cpy-utils.cuh +++ b/ggml/src/ggml-cuda/cpy-utils.cuh -@@ -223,3 +223,9 @@ template +@@ -215,3 +215,9 @@ template static __device__ void cpy_1_flt(const char * cxi, char * cdsti) { - convert_flt((const src_t *)cxi, (dst_t *)cdsti); + *(dst_t *) cdsti = ggml_cuda_cast(*(const src_t *) cxi); } + +static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) { @@ -210,10 +211,10 @@ index 410c12b7..b8e9e107 100644 + *dst = *src; +} diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu -index f9bb0256..9c3774e5 100644 +index 746f4396..911220e9 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu -@@ -278,6 +278,47 @@ static void ggml_cpy_f32_iq4_nl_cuda( +@@ -277,6 +277,47 @@ static void ggml_cpy_f32_iq4_nl_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } @@ -261,7 +262,7 @@ index f9bb0256..9c3774e5 100644 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); -@@ -369,6 +410,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg +@@ -372,6 +413,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); @@ -270,3 +271,80 @@ index f9bb0256..9c3774e5 100644 } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { +diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal +index 74a9aa99..375a0c7f 100644 +--- a/ggml/src/ggml-metal/ggml-metal.metal ++++ b/ggml/src/ggml-metal/ggml-metal.metal +@@ -4346,8 +4346,72 @@ kernel void kernel_argsort_f32_i32( + } + } + ++typedef void (i32_argsort_t)( ++ constant ggml_metal_kargs_argsort & args, ++ device const int32_t * x, ++ device int32_t * dst, ++ threadgroup int32_t * shared_values [[threadgroup(0)]], ++ uint3 tgpig[[threadgroup_position_in_grid]], ++ uint3 tpitg[[thread_position_in_threadgroup]]); ++ ++template ++kernel void kernel_argsort_i32_i32( ++ constant ggml_metal_kargs_argsort & args, ++ device const int32_t * x, ++ device int32_t * dst, ++ threadgroup int32_t * shared_values [[threadgroup(0)]], ++ uint3 tgpig[[threadgroup_position_in_grid]], ++ uint3 tpitg[[thread_position_in_threadgroup]]) { ++ // bitonic sort ++ int col = tpitg[0]; ++ int row = tgpig[1]; ++ ++ if (col >= args.ncols_pad) return; ++ ++ device const int32_t * x_row = x + row * args.ncols; ++ threadgroup int32_t * dst_row = shared_values; ++ ++ // initialize indices ++ dst_row[col] = col; ++ ++ threadgroup_barrier(mem_flags::mem_threadgroup); ++ ++ for (int k = 2; k <= args.ncols_pad; k *= 2) { ++ for (int j = k / 2; j > 0; j /= 2) { ++ int ixj = col ^ j; ++ if (ixj > col) { ++ if ((col & k) == 0) { ++ if (dst_row[col] >= args.ncols || ++ (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ? ++ x_row[dst_row[col]] > x_row[dst_row[ixj]] : ++ x_row[dst_row[col]] < x_row[dst_row[ixj]])) ++ ) { ++ SWAP(dst_row[col], dst_row[ixj]); ++ } ++ } else { ++ if (dst_row[ixj] >= args.ncols || ++ (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ? ++ x_row[dst_row[col]] < x_row[dst_row[ixj]] : ++ x_row[dst_row[col]] > x_row[dst_row[ixj]])) ++ ) { ++ SWAP(dst_row[col], dst_row[ixj]); ++ } ++ } ++ } ++ threadgroup_barrier(mem_flags::mem_threadgroup); ++ } ++ } ++ ++ // copy the result to dst without the padding ++ if (col < args.ncols) { ++ dst[row * args.ncols + col] = dst_row[col]; ++ } ++} ++ + template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; + template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; ++template [[host_name("kernel_argsort_i32_i32_asc")]] kernel i32_argsort_t kernel_argsort_i32_i32; ++template [[host_name("kernel_argsort_i32_i32_desc")]] kernel i32_argsort_t kernel_argsort_i32_i32; + + kernel void kernel_leaky_relu_f32( + constant ggml_metal_kargs_leaky_relu & args, diff --git a/llama/patches/0014-graph-memory-reporting-on-failure.patch b/llama/patches/0014-graph-memory-reporting-on-failure.patch index 26fe8a8e..a3f0fc70 100644 --- a/llama/patches/0014-graph-memory-reporting-on-failure.patch +++ b/llama/patches/0014-graph-memory-reporting-on-failure.patch @@ -4,60 +4,50 @@ Date: Fri, 18 Apr 2025 15:58:19 -0700 Subject: [PATCH] graph memory reporting on failure --- - ggml/include/ggml-alloc.h | 6 ++++++ - ggml/include/ggml-backend.h | 6 ++++++ - ggml/src/ggml-alloc.c | 38 +++++++++++++++++++++++++++++++++---- - ggml/src/ggml-backend.cpp | 10 ++++++++++ - 4 files changed, 56 insertions(+), 4 deletions(-) + ggml/include/ggml-alloc.h | 1 + + ggml/include/ggml-backend.h | 1 + + ggml/src/ggml-alloc.c | 34 +++++++++++++++++++++++++++++++--- + ggml/src/ggml-backend.cpp | 7 +++++++ + 4 files changed, 40 insertions(+), 3 deletions(-) diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h -index 2cb150fd..781b1e10 100644 +index 2cb150fd..7ab3f019 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h -@@ -66,6 +66,12 @@ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph +@@ -65,6 +65,7 @@ GGML_API bool ggml_gallocr_reserve_n( + GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph); GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); ++GGML_API size_t ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); -+struct ggml_allocr_buffer_status { -+ size_t size; -+ bool allocated; -+}; -+GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); -+ // Utils // Create a buffer and allocate all the tensors in a ggml_context - GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index a2977ea2..8a91b381 100644 +index f1b74078..c54ff98b 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -304,6 +304,12 @@ extern "C" { +@@ -318,6 +318,7 @@ extern "C" { - GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + GGML_API ggml_backend_buffer_type_t ggml_backend_sched_get_buffer_type(ggml_backend_sched_t sched, ggml_backend_t backend); + GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); ++ GGML_API size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); -+ struct ggml_backend_buffer_status { -+ size_t size; -+ bool allocated; -+ }; -+ GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); -+ GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); - diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c -index 8b6e6028..41c8c4a2 100644 +index 929bc448..eee9d3b1 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c -@@ -350,6 +350,7 @@ struct node_alloc { +@@ -486,6 +486,7 @@ struct node_alloc { struct ggml_gallocr { ggml_backend_buffer_type_t * bufts; // [n_buffers] - ggml_backend_buffer_t * buffers; // [n_buffers] + struct vbuffer ** buffers; // [n_buffers] + size_t *buffer_sizes; // [n_buffers] struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers] int n_buffers; -@@ -373,6 +374,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs - galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t)); +@@ -509,6 +510,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs + galloc->buffers = calloc(n_bufs, sizeof(struct vbuffer *)); GGML_ASSERT(galloc->buffers != NULL); + galloc->buffer_sizes = calloc(n_bufs, sizeof(size_t)); @@ -66,7 +56,7 @@ index 8b6e6028..41c8c4a2 100644 galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); GGML_ASSERT(galloc->buf_tallocs != NULL); -@@ -439,6 +443,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { +@@ -576,6 +580,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { ggml_hash_set_free(&galloc->hash_set); free(galloc->hash_values); free(galloc->bufts); @@ -74,7 +64,7 @@ index 8b6e6028..41c8c4a2 100644 free(galloc->buffers); free(galloc->buf_tallocs); free(galloc->node_allocs); -@@ -734,6 +739,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c +@@ -869,6 +874,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } } @@ -83,23 +73,21 @@ index 8b6e6028..41c8c4a2 100644 // reallocate buffers if needed for (int i = 0; i < galloc->n_buffers; i++) { // if the buffer type is used multiple times, we reuse the same buffer -@@ -755,15 +762,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c +@@ -898,14 +905,19 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c - ggml_backend_buffer_free(galloc->buffers[i]); - galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); + ggml_vbuffer_free(galloc->buffers[i]); + galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); - if (galloc->buffers[i] == NULL) { + if (galloc->buffers[i]) { -+ galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); -+ ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); ++ galloc->buffer_sizes[i] = ggml_vbuffer_size(galloc->buffers[i]); + } else { GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); - return false; + galloc->buffer_sizes[i] = new_size; + success = false; } -- ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); + } else { -+ galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); ++ galloc->buffer_sizes[i] = ggml_vbuffer_size(galloc->buffers[i]); } } @@ -108,11 +96,11 @@ index 8b6e6028..41c8c4a2 100644 } bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { -@@ -920,6 +932,24 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { - return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); +@@ -1060,6 +1072,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { + return ggml_vbuffer_size(galloc->buffers[buffer_id]); } -+struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { ++size_t ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { + GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers); + + for (int i = 0; i < buffer_id; i++) { @@ -121,36 +109,31 @@ index 8b6e6028..41c8c4a2 100644 + // (See above.) However, we need a different check because multiple buffers might be NULL in our + // case and we still want to know the attempted size. + -+ struct ggml_allocr_buffer_status status = {0, true}; -+ return status; ++ return 0; + } + } + -+ struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL}; -+ return status; ++ return galloc->buffer_sizes[buffer_id]; +} + // utils static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 97f47abd..eded0291 100644 +index 8ba86f82..cb2b9956 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -1631,6 +1631,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe +@@ -1809,6 +1809,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); } -+struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { ++size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + -+ struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); -+ struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated}; -+ -+ return status; ++ return ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); +} + void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { + GGML_ASSERT(sched); int backend_index = ggml_backend_sched_backend_id(sched, backend); - GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); diff --git a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch b/llama/patches/0015-ggml-Export-GPU-UUIDs.patch index 22e1a724..b58d23d9 100644 --- a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch +++ b/llama/patches/0015-ggml-Export-GPU-UUIDs.patch @@ -6,28 +6,28 @@ Subject: [PATCH] ggml: Export GPU UUIDs This enables matching up devices and information reported by the backend with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml). --- - ggml/include/ggml-backend.h | 1 + - ggml/src/ggml-cuda/ggml-cuda.cu | 67 +++++++++++++++++++++++++++++--- - ggml/src/ggml-metal/ggml-metal.m | 1 + + ggml/include/ggml-backend.h | 1 + + ggml/src/ggml-cuda/ggml-cuda.cu | 67 +++++++++++++++++++++++++++--- + ggml/src/ggml-metal/ggml-metal.cpp | 1 + 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 8a91b381..9424394e 100644 +index c54ff98b..229bf387 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -152,6 +152,7 @@ extern "C" { - struct ggml_backend_dev_props { - const char * name; +@@ -158,6 +158,7 @@ extern "C" { const char * description; -+ const char * id; + // device free memory in bytes size_t memory_free; ++ const char * id; + // device total memory in bytes size_t memory_total; - enum ggml_backend_dev_type type; + // device type diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 37ee2a6d..57eae461 100644 +index c0b1e4c1..5b852f69 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -179,6 +179,51 @@ static int ggml_cuda_parse_id(char devName[]) { +@@ -183,6 +183,51 @@ static int ggml_cuda_parse_id(char devName[]) { } #endif // defined(GGML_USE_HIP) @@ -77,9 +77,9 @@ index 37ee2a6d..57eae461 100644 +} + static ggml_cuda_device_info ggml_cuda_init() { - #if defined(GGML_USE_HIP) - // Workaround for a rocBLAS bug when using multiple graphics cards: -@@ -267,22 +312,24 @@ static ggml_cuda_device_info ggml_cuda_init() { + ggml_cuda_device_info info = {}; + +@@ -249,22 +294,24 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc += prop.minor * 0x10; } } @@ -107,18 +107,18 @@ index 37ee2a6d..57eae461 100644 + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", + ggml_cuda_parse_uuid(prop, id).c_str()); - #endif // defined(GGML_USE_HIP) - } - -@@ -3144,6 +3191,7 @@ struct ggml_backend_cuda_device_context { - int device; + std::string device_name(prop.name); + if (device_name == "NVIDIA GeForce MX450") { + turing_devices_without_mma.push_back({ id, device_name }); +@@ -3276,6 +3323,7 @@ struct ggml_backend_cuda_device_context { std::string name; std::string description; + std::string pci_bus_id; + std::string id; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { -@@ -3156,6 +3204,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t +@@ -3288,6 +3336,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t return ctx->description.c_str(); } @@ -130,31 +130,31 @@ index 37ee2a6d..57eae461 100644 static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); -@@ -3170,6 +3223,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend - static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { +@@ -3304,6 +3357,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back + props->name = ggml_backend_cuda_device_get_name(dev); props->description = ggml_backend_cuda_device_get_description(dev); + props->id = ggml_backend_cuda_device_get_id(dev); props->type = ggml_backend_cuda_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); - -@@ -3767,6 +3821,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -3873,6 +3927,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; + dev_ctx->id = ggml_cuda_parse_uuid(prop, i); - ggml_backend_dev_t dev = new ggml_backend_device { - /* .iface = */ ggml_backend_cuda_device_interface, -diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m -index 7bccc7bf..fe7b2f0a 100644 ---- a/ggml/src/ggml-metal/ggml-metal.m -+++ b/ggml/src/ggml-metal/ggml-metal.m -@@ -6522,6 +6522,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen - static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + char pci_bus_id[16] = {}; + snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); +diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp +index bf096227..f2ff9f32 100644 +--- a/ggml/src/ggml-metal/ggml-metal.cpp ++++ b/ggml/src/ggml-metal/ggml-metal.cpp +@@ -538,6 +538,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen + static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { props->name = ggml_backend_metal_device_get_name(dev); props->description = ggml_backend_metal_device_get_description(dev); + props->id = "0"; props->type = ggml_backend_metal_device_get_type(dev); + ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); - props->caps = (struct ggml_backend_dev_caps) { diff --git a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch index 2c19ae6d..422d633b 100644 --- a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch +++ b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch @@ -10,11 +10,11 @@ Signed-off-by: Gabe Goodhart 2 files changed, 13 insertions(+) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp -index a05373d5..6f70f7f4 100644 +index 4d487581..35a0d25e 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -79,6 +79,16 @@ enum mtmd_slice_tmpl { - // TODO @ngxson : add support for idefics (SmolVLM) + MTMD_SLICE_TMPL_IDEFICS3, }; +mtmd_input_text* mtmd_input_text_init(const char * text, bool add_special, bool parse_special) { diff --git a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch index 14c50ca0..279e42c3 100644 --- a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch +++ b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch @@ -8,10 +8,10 @@ Subject: [PATCH] no power throttling win32 with gnuc 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c -index a5689c18..85af19a3 100644 +index 99509b0c..b13a491d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c -@@ -2412,7 +2412,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { +@@ -2437,7 +2437,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { // Newer Windows 11 versions aggresively park (offline) CPU cores and often place // all our threads onto the first 4 cores which results in terrible performance with // n_threads > 4 diff --git a/llama/patches/0018-BF16-macos-version-guard.patch b/llama/patches/0018-BF16-macos-version-guard.patch index 6ebc3376..313d51be 100644 --- a/llama/patches/0018-BF16-macos-version-guard.patch +++ b/llama/patches/0018-BF16-macos-version-guard.patch @@ -5,23 +5,24 @@ Subject: [PATCH] BF16 macos version guard Only enable BF16 on supported MacOS versions (v14+) --- - ggml/src/ggml-metal/ggml-metal.m | 6 +++++- - 1 file changed, 5 insertions(+), 1 deletion(-) + ggml/src/ggml-metal/ggml-metal-context.m | 7 ++++++- + 1 file changed, 6 insertions(+), 1 deletion(-) -diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m -index fe7b2f0a..e4c31268 100644 ---- a/ggml/src/ggml-metal/ggml-metal.m -+++ b/ggml/src/ggml-metal/ggml-metal.m -@@ -106,7 +106,11 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev - ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6]; +diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m +index 052efb7a..b47dc787 100644 +--- a/ggml/src/ggml-metal/ggml-metal-context.m ++++ b/ggml/src/ggml-metal/ggml-metal-context.m +@@ -125,7 +125,12 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { + + res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); + +- res->use_bfloat = props_dev->has_bfloat; ++ if (@available(macOS 14.0, *)) { ++ res->use_bfloat = props_dev->has_bfloat; ++ } else { ++ res->use_bfloat = false; ++ } ++ + res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; + res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil; - #if defined(GGML_METAL_USE_BF16) -- ctx->use_bfloat = ctx->has_bfloat; -+ if (@available(macOS 14.0, *)) { -+ ctx->use_bfloat = ctx->has_bfloat; -+ } else { -+ ctx->use_bfloat = false; -+ } - #else - ctx->use_bfloat = false; - #endif diff --git a/llama/patches/0019-Enable-CUDA-Graphs-for-gemma3n.patch b/llama/patches/0019-Enable-CUDA-Graphs-for-gemma3n.patch index db1303b3..85cba5b3 100644 --- a/llama/patches/0019-Enable-CUDA-Graphs-for-gemma3n.patch +++ b/llama/patches/0019-Enable-CUDA-Graphs-for-gemma3n.patch @@ -13,10 +13,10 @@ checks. 1 file changed, 18 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 57eae461..c7f9dc3a 100644 +index 5b852f69..827e3205 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -2671,12 +2671,24 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud +@@ -2689,14 +2689,26 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud // Loop over nodes in GGML graph to obtain info needed for CUDA graph cuda_ctx->cuda_graph->cpy_dest_ptrs.clear(); @@ -36,12 +36,14 @@ index 57eae461..c7f9dc3a 100644 const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; + const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out"; + const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d"; + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; -@@ -2700,6 +2712,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud +@@ -2720,6 +2732,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && diff --git a/llama/patches/0020-Disable-ggml-blas-on-macos-v13-and-older.patch b/llama/patches/0020-Disable-ggml-blas-on-macos-v13-and-older.patch index fbcbfd4f..e724663d 100644 --- a/llama/patches/0020-Disable-ggml-blas-on-macos-v13-and-older.patch +++ b/llama/patches/0020-Disable-ggml-blas-on-macos-v13-and-older.patch @@ -8,10 +8,10 @@ Subject: [PATCH] Disable ggml-blas on macos v13 and older 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp -index aeac2e57..40738d5b 100644 +index 5b888cdd..2a9ff7f6 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp -@@ -505,6 +505,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = { +@@ -506,6 +506,11 @@ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = { }; ggml_backend_reg_t ggml_backend_blas_reg(void) { diff --git a/llama/patches/0022-ggml-No-alloc-mode.patch b/llama/patches/0022-ggml-No-alloc-mode.patch index fa738452..019cb886 100644 --- a/llama/patches/0022-ggml-No-alloc-mode.patch +++ b/llama/patches/0022-ggml-No-alloc-mode.patch @@ -3,35 +3,45 @@ From: Jesse Gross Date: Wed, 23 Jul 2025 11:58:49 -0700 Subject: [PATCH] ggml: No-alloc mode -Callers can set a backend buffer type to be no-alloc, meaning that +Callers can set a scheduler to be no-alloc, meaning that it does not allocate memory for tensors or operations. This can be used for calculating memory requirements. Tensors and graphs must be recreated with no-alloc set to false before loading data. - -Defaults to false for newly created backend buffer types. --- - ggml/include/ggml-backend.h | 1 + - ggml/src/ggml-backend-impl.h | 2 ++ - ggml/src/ggml-backend.cpp | 19 ++++++++++++++++++- - 3 files changed, 21 insertions(+), 1 deletion(-) + ggml/include/ggml-backend.h | 1 + + ggml/src/ggml-backend-impl.h | 16 +++ + ggml/src/ggml-backend.cpp | 72 ++++++++++- + ggml/src/ggml-cuda/common.cuh | 48 ++++++- + ggml/src/ggml-cuda/ggml-cuda.cu | 217 ++++++++++++++++++++++++++------ + 5 files changed, 310 insertions(+), 44 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 9424394e..b602a7c7 100644 +index 229bf387..1ff53ed0 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -35,6 +35,7 @@ extern "C" { - // +@@ -305,6 +305,7 @@ extern "C" { - GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft); -+ GGML_API void ggml_backend_buft_set_alloc (ggml_backend_buffer_type_t buft, bool alloc); - GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size); - GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); - GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); + // Initialize a backend scheduler, backends with low index are given priority over backends with high index + GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); ++ GGML_API ggml_backend_sched_t ggml_backend_sched_new_ext(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload, bool alloc_buffers); + GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); + + // Initialize backend buffers from a measure graph diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h -index c36c12d6..81749a5a 100644 +index 6792ba98..3c3f22fc 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h -@@ -32,6 +32,7 @@ extern "C" { +@@ -26,12 +26,17 @@ extern "C" { + size_t (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); + // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false) + bool (*is_host) (ggml_backend_buffer_type_t buft); ++ ++ // (optional) returns a dummy buffer that is equivalent to one created by alloc_buffer but without actually being backed ++ // by memory ++ ggml_backend_buffer_t (*noalloc_buffer)(ggml_backend_buffer_type_t buft, size_t size); + }; + + struct ggml_backend_buffer_type { struct ggml_backend_buffer_type_i iface; ggml_backend_dev_t device; void * context; @@ -39,7 +49,7 @@ index c36c12d6..81749a5a 100644 }; // -@@ -63,6 +64,7 @@ extern "C" { +@@ -63,6 +68,7 @@ extern "C" { void * context; size_t size; enum ggml_backend_buffer_usage usage; @@ -47,34 +57,48 @@ index c36c12d6..81749a5a 100644 }; GGML_API ggml_backend_buffer_t ggml_backend_buffer_init( +@@ -117,6 +123,16 @@ extern "C" { + + // (optional) sort/optimize the nodes in the graph + void (*graph_optimize) (ggml_backend_t backend, struct ggml_cgraph * cgraph); ++ ++ // (optional) reserves intermediate buffers needed for the compution ++ // if alloc is true, memory is actually allocated, otherwise the required amount is just returned by buffer_size ++ enum ggml_status (*graph_reserve) (ggml_backend_t backend, struct ggml_cgraph * cgraph, bool alloc); ++ ++ // (optional) returns the memory needed after calling graph_reserve ++ size_t (*buffer_size) (ggml_backend_t backend); ++ ++ // (optional) frees memory from intermediate buffers that was allocated either by graph_compute or graph_reserve ++ void (*reset) (ggml_backend_t backend); + }; + + struct ggml_backend { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index eded0291..05a842ed 100644 +index cb2b9956..6ef5eeaf 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -35,12 +35,22 @@ const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { - return buft->iface.get_name(buft); - } - -+void ggml_backend_buft_set_alloc(ggml_backend_buffer_type_t buft, bool alloc) { -+ buft->no_alloc = !alloc; -+} -+ - ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - if (size == 0) { - // return a dummy buffer for zero-sized allocations +@@ -41,6 +41,19 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t return ggml_backend_buffer_init(buft, {}, NULL, 0); } + if (buft->no_alloc) { -+ ggml_backend_buffer_t buf = ggml_backend_buffer_init(buft, {}, NULL, size); ++ ggml_backend_buffer_t buf; ++ ++ if (buft->iface.noalloc_buffer != NULL) { ++ buf = buft->iface.noalloc_buffer(buft, size); ++ } else { ++ buf = ggml_backend_buffer_init(buft, {}, NULL, size); ++ } ++ + buf->no_alloc = true; + return buf; + } + + GGML_ASSERT(buft); return buft->iface.alloc_buffer(buft, size); } - -@@ -89,7 +99,8 @@ ggml_backend_buffer_t ggml_backend_buffer_init( +@@ -95,7 +108,8 @@ ggml_backend_buffer_t ggml_backend_buffer_init( /* .buft = */ buft, /* .context = */ context, /* .size = */ size, @@ -84,7 +108,7 @@ index eded0291..05a842ed 100644 }; return buffer; -@@ -119,6 +130,12 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { +@@ -127,6 +141,12 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { return NULL; } @@ -97,3 +121,532 @@ index eded0291..05a842ed 100644 void * base = buffer->iface.get_base(buffer); GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL"); +@@ -723,6 +743,12 @@ struct ggml_backend_sched { + bool op_offload; + + int debug; ++ ++ // allocate buffers on attached ggml_backend_buffer_type_t's and during reservation ++ // if false, dummy buffers are used for faster memory sizing calculations ++ // the scheduler needs to be recreated with allocated buffers before it can be used ++ // for computation ++ bool alloc_buffers; + }; + + #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) +@@ -1606,6 +1632,17 @@ ggml_backend_sched_t ggml_backend_sched_new( + size_t graph_size, + bool parallel, + bool op_offload) { ++ return ggml_backend_sched_new_ext(backends, bufts, n_backends, graph_size, parallel, op_offload, true); ++ } ++ ++ggml_backend_sched_t ggml_backend_sched_new_ext( ++ ggml_backend_t * backends, ++ ggml_backend_buffer_type_t * bufts, ++ int n_backends, ++ size_t graph_size, ++ bool parallel, ++ bool op_offload, ++ bool alloc_buffers) { + GGML_ASSERT(n_backends > 0); + GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); + GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); +@@ -1647,10 +1684,13 @@ ggml_backend_sched_t ggml_backend_sched_new( + sched->events[b][c] = ggml_backend_event_new(backends[b]->device); + } + } ++ ++ sched->bufts[b]->no_alloc = !alloc_buffers; + } + + sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); + sched->op_offload = op_offload; ++ sched->alloc_buffers = alloc_buffers; + + ggml_backend_sched_reset(sched); + +@@ -1665,6 +1705,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { + for (int c = 0; c < sched->n_copies; c++) { + ggml_backend_event_free(sched->events[b][c]); + } ++ ++ if (sched->backends[b]->iface.reset != NULL) { ++ sched->backends[b]->iface.reset(sched->backends[b]); ++ } + } + ggml_gallocr_free(sched->galloc); + ggml_free(sched->ctx); +@@ -1708,6 +1752,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * + return false; + } + ++ if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { ++ return false; ++ } ++ ++ struct ggml_backend_sched_split * splits = sched->splits; ++ for (int i = 0; i < sched->n_splits; i++) { ++ struct ggml_backend_sched_split * split = &splits[i]; ++ int split_backend_id = split->backend_id; ++ ggml_backend_t split_backend = sched->backends[split_backend_id]; ++ ++ if (split_backend->iface.graph_reserve != NULL) { ++ enum ggml_status ec = split_backend->iface.graph_reserve(split_backend, &split->graph, sched->alloc_buffers); ++ if (ec != GGML_STATUS_SUCCESS) { ++ return false; ++ } ++ } ++ } ++ + ggml_backend_sched_reset(sched); + + return true; +@@ -1813,7 +1875,13 @@ size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + +- return ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); ++ size_t size = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); ++ ++ if (backend->iface.buffer_size != NULL) { ++ size += backend->iface.buffer_size(backend); ++ } ++ ++ return size; + } + + void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { +diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh +index e0abde54..28d6bcd7 100644 +--- a/ggml/src/ggml-cuda/common.cuh ++++ b/ggml/src/ggml-cuda/common.cuh +@@ -35,6 +35,31 @@ + #include "vendors/cuda.h" + #endif // defined(GGML_USE_HIP) + ++extern bool reserving_graph; ++ ++// If we are reserving the graph, pointers might be invalid and will fail if cudaMemcpyAsync tries to validate them. ++// However, since we don't actually expect a result, we don't need to actually do the memcpy. ++static cudaError_t cudaMemcpyAsyncReserve ( void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream = 0 ) { ++ if (!reserving_graph) { ++ return cudaMemcpyAsync(dst, src, count, kind, stream); ++ } else { ++ return cudaSuccess; ++ } ++} ++ ++static cudaError_t cudaMemcpy2DAsyncReserve ( void* dst, size_t dpitch, const void* src, size_t spitch, size_t width, size_t height, cudaMemcpyKind kind, cudaStream_t stream = 0 ) { ++ if (!reserving_graph) { ++ return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, kind, stream); ++ } else { ++ return cudaSuccess; ++ } ++} ++ ++#undef cudaMemcpyAsync ++#define cudaMemcpyAsync cudaMemcpyAsyncReserve ++#undef cudaMemcpy2DAsync ++#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve ++ + #define STRINGIZE_IMPL(...) #__VA_ARGS__ + #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) + +@@ -856,6 +881,9 @@ struct ggml_cuda_pool { + + virtual void * alloc(size_t size, size_t * actual_size) = 0; + virtual void free(void * ptr, size_t size) = 0; ++ ++ virtual bool alloc_memory() = 0; ++ virtual size_t alloc_size() = 0; + }; + + template +@@ -999,11 +1027,11 @@ struct ggml_backend_cuda_context { + // pool + std::unique_ptr pools[GGML_CUDA_MAX_DEVICES]; + +- static std::unique_ptr new_pool_for_device(int device); ++ static std::unique_ptr new_pool_for_device(int device, bool alloc); + + ggml_cuda_pool & pool(int device) { + if (pools[device] == nullptr) { +- pools[device] = new_pool_for_device(device); ++ pools[device] = new_pool_for_device(device, true); + } + return *pools[device]; + } +@@ -1011,4 +1039,20 @@ struct ggml_backend_cuda_context { + ggml_cuda_pool & pool() { + return pool(device); + } ++ ++ void pool_set_alloc(bool alloc) { ++ GGML_ASSERT(pools[device] == nullptr || pools[device]->alloc_memory() == alloc); ++ ++ if (pools[device] == nullptr) { ++ pools[device] = new_pool_for_device(device, alloc); ++ } ++ } ++ ++ size_t pool_get_alloc_size() { ++ if (pools[device] == nullptr) { ++ return 0; ++ } ++ ++ return pools[device]->alloc_size(); ++ } + }; +diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu +index 827e3205..811462c7 100644 +--- a/ggml/src/ggml-cuda/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda/ggml-cuda.cu +@@ -350,6 +350,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { + + // #define DEBUG_CUDA_MALLOC + ++#define CUDA_ALIGNMENT 128 ++ + // buffer pool for cuda (legacy) + struct ggml_cuda_pool_leg : public ggml_cuda_pool { + static const int MAX_BUFFERS = 256; +@@ -362,9 +364,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { + + ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {}; + size_t pool_size = 0; ++ bool allocate = true; ++ size_t last_alloc = 0; + +- explicit ggml_cuda_pool_leg(int device) : +- device(device) { ++ explicit ggml_cuda_pool_leg(int device, bool alloc) : ++ device(device), ++ allocate(alloc) { + } + + ~ggml_cuda_pool_leg() { +@@ -372,7 +377,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { + for (int i = 0; i < MAX_BUFFERS; ++i) { + ggml_cuda_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { +- CUDA_CHECK(cudaFree(b.ptr)); ++ if (allocate) { ++ CUDA_CHECK(cudaFree(b.ptr)); ++ } + pool_size -= b.size; + } + } +@@ -420,8 +427,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { + void * ptr; + size_t look_ahead_size = (size_t) (1.05 * size); + look_ahead_size = 256 * ((look_ahead_size + 255)/256); +- ggml_cuda_set_device(device); +- CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); ++ if (allocate) { ++ ggml_cuda_set_device(device); ++ if (ggml_cuda_device_malloc(&ptr, look_ahead_size, device) != cudaSuccess) { ++ last_alloc = look_ahead_size; ++ throw std::bad_alloc(); ++ } ++ } else { ++ ptr = (void *)CUDA_ALIGNMENT; ++ } + *actual_size = look_ahead_size; + pool_size += look_ahead_size; + #ifdef DEBUG_CUDA_MALLOC +@@ -441,10 +455,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { + } + } + GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n"); +- ggml_cuda_set_device(device); +- CUDA_CHECK(cudaFree(ptr)); ++ if (allocate) { ++ ggml_cuda_set_device(device); ++ CUDA_CHECK(cudaFree(ptr)); ++ } + pool_size -= size; + } ++ ++ bool alloc_memory() override { ++ return allocate; ++ } ++ ++ size_t alloc_size() override { ++ return pool_size + last_alloc; ++ } + }; + + // pool with virtual memory +@@ -456,18 +480,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { + CUdeviceptr pool_addr = 0; + size_t pool_used = 0; + size_t pool_size = 0; ++ bool allocate = true; ++ size_t last_alloc = 0; + size_t granularity; + #if defined(GGML_USE_HIP) + std::vector> mappings; + #endif + +- explicit ggml_cuda_pool_vmm(int device) : ++ explicit ggml_cuda_pool_vmm(int device, bool alloc) : + device(device), +- granularity(ggml_cuda_info().devices[device].vmm_granularity) { ++ granularity(ggml_cuda_info().devices[device].vmm_granularity), ++ allocate(alloc) { ++ if (!allocate) { ++ pool_addr = (CUdeviceptr)CUDA_ALIGNMENT; ++ } + } + + ~ggml_cuda_pool_vmm() { +- if (pool_addr != 0) { ++ if (pool_addr != 0 && allocate) { + #if defined(GGML_USE_HIP) + // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 + for (std::pair & mapping : mappings) { +@@ -494,35 +524,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { + + GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE); + +- // allocate more physical memory +- CUmemAllocationProp prop = {}; +- prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; +- prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; +- prop.location.id = device; +- CUmemGenericAllocationHandle handle; +- CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0)); +- +- // reserve virtual address space (if not already reserved) +- if (pool_addr == 0) { +- CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); +- } ++ if (allocate) { ++ // allocate more physical memory ++ CUmemAllocationProp prop = {}; ++ prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; ++ prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; ++ prop.location.id = device; ++ CUmemGenericAllocationHandle handle; ++ if (cuMemCreate(&handle, reserve_size, &prop, 0) != CUDA_SUCCESS) { ++ last_alloc = reserve_size; ++ throw std::bad_alloc(); ++ } + +- // map at the end of the pool +- CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); +- CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0)); +-#if defined(GGML_USE_HIP) +- mappings.push_back({start_ptr, reserve_size}); +-#endif ++ // reserve virtual address space (if not already reserved) ++ if (pool_addr == 0) { ++ CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); ++ } + +- // the memory allocation handle is no longer needed after mapping +- CU_CHECK(cuMemRelease(handle)); ++ // map at the end of the pool ++ CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); ++ if (cuMemMap(start_ptr, reserve_size, 0, handle, 0) != CUDA_SUCCESS) { ++ last_alloc = reserve_size; ++ CU_CHECK(cuMemRelease(handle)); ++ throw std::bad_alloc(); ++ } ++ ++ // the memory allocation handle is no longer needed after mapping ++ CU_CHECK(cuMemRelease(handle)); ++ ++ // set access ++ CUmemAccessDesc access = {}; ++ access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; ++ access.location.id = device; ++ access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; ++ if (cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1) != CUDA_SUCCESS) { ++ CU_CHECK(cuMemUnmap(start_ptr, reserve_size)); ++ last_alloc = reserve_size; ++ throw std::bad_alloc(); ++ } + +- // set access +- CUmemAccessDesc access = {}; +- access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; +- access.location.id = device; +- access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; +- CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); ++ #if defined(GGML_USE_HIP) ++ mappings.push_back({start_ptr, reserve_size}); ++ #endif ++ } + + // add to the pool + pool_size += reserve_size; +@@ -555,16 +599,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { + // all deallocations must be in reverse order of the allocations + GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); + } ++ ++ bool alloc_memory() override { ++ return allocate; ++ } ++ ++ size_t alloc_size() override { ++ return pool_size + last_alloc; ++ } + }; + #endif // defined(GGML_USE_VMM) + +-std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { ++std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device, bool alloc) { + #if defined(GGML_USE_VMM) + if (ggml_cuda_info().devices[device].vmm) { +- return std::unique_ptr(new ggml_cuda_pool_vmm(device)); ++ return std::unique_ptr(new ggml_cuda_pool_vmm(device, alloc)); + } + #endif // defined(GGML_USE_VMM) +- return std::unique_ptr(new ggml_cuda_pool_leg(device)); ++ return std::unique_ptr(new ggml_cuda_pool_leg(device, alloc)); + } + + // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error +@@ -748,11 +800,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac + } + + static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { +- return 128; ++ return CUDA_ALIGNMENT; + + GGML_UNUSED(buft); + } + ++static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_noalloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ++ ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; ++ ++ void * dev_ptr = (void *)ggml_backend_cuda_buffer_type_get_alignment(buft); ++ ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); ++ ++ return ggml_backend_buffer_init(buft, {}, ctx, size); ++} ++ + static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + size_t size = ggml_nbytes(tensor); + int64_t ne0 = tensor->ne[0]; +@@ -776,6 +837,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, + /* .is_host = */ NULL, ++ /* .noalloc_buffer = */ ggml_backend_cuda_buffer_type_noalloc_buffer, + }; + + ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { +@@ -3011,6 +3073,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, + + static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, + bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { ++ + // flag used to determine whether it is an integrated_gpu + const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + +@@ -3026,6 +3089,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx + continue; + } + ++ // When reserving, we are forcing CUDA graphs but this operation is not graph-safe so we need to skip it ++ if (reserving_graph && node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { ++ continue; ++ } ++ + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); + if (!disable_fusion) { + +@@ -3152,6 +3220,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx + + static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ++ cuda_ctx->pool_set_alloc(true); + + ggml_cuda_set_device(cuda_ctx->device); + +@@ -3231,6 +3300,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, + return GGML_STATUS_SUCCESS; + } + ++// This is used to skip operations that are not graph safe during the reservation process. ++bool reserving_graph = false; ++ ++static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, ggml_cgraph * cgraph, bool alloc) { ++ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ++ cuda_ctx->pool_set_alloc(alloc); ++ ++ #ifdef USE_CUDA_GRAPH ++ if (cuda_ctx->cuda_graph == nullptr) { ++ cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); ++ } ++ #endif ++ ++ ggml_cuda_set_device(cuda_ctx->device); ++ ++ { ++ std::lock_guard lock(ggml_cuda_lock); ++ ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed); ++ } ++ ++ reserving_graph = true; ++ ++ // Create CuBLAS handles early to avoid synchronous allocations during graph capture. ++ cuda_ctx->cublas_handle(); ++ ++ CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); ++ ++ enum ggml_status result = GGML_STATUS_SUCCESS; ++ ++ try { ++ bool use_cuda_graph = false; ++ bool cuda_graph_update_required = false; ++ bool graph_evaluated_or_captured = false; ++ ++ evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); ++ } catch (const std::exception &e) { ++ result = GGML_STATUS_FAILED; ++ } ++ ++ cudaGraph_t graph; ++ CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph)); ++ CUDA_CHECK(cudaGraphDestroy(graph)); ++ ++ reserving_graph = false; ++ ++ { ++ std::lock_guard lock(ggml_cuda_lock); ++ if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) { ++ ggml_cuda_lock_cv.notify_all(); ++ } ++ } ++ ++ return result; ++} ++ ++static size_t ggml_backend_cuda_buffer_size(ggml_backend_t backend) { ++ ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; ++ return ctx->pool_get_alloc_size(); ++} ++ ++static void ggml_backend_cuda_reset(ggml_backend_t backend) { ++ ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; ++ ctx->pools[ctx->device] = NULL; ++} ++ + static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + +@@ -3271,6 +3405,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { + /* .event_record = */ ggml_backend_cuda_event_record, + /* .event_wait = */ ggml_backend_cuda_event_wait, + /* .graph_optimize = */ NULL, ++ /* .graph_reserve = */ ggml_backend_cuda_graph_reserve, ++ /* .buffer_size = */ ggml_backend_cuda_buffer_size, ++ /* .reset = */ ggml_backend_cuda_reset, + }; + + static ggml_guid_t ggml_backend_cuda_guid() { diff --git a/llama/patches/0023-decode-disable-output_all.patch b/llama/patches/0023-decode-disable-output_all.patch index dc326ae6..ddf281bb 100644 --- a/llama/patches/0023-decode-disable-output_all.patch +++ b/llama/patches/0023-decode-disable-output_all.patch @@ -8,10 +8,10 @@ Subject: [PATCH] decode: disable output_all 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp -index 26a5cf9c..6ece5263 100644 +index e7526e7d..53a5e3a9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp -@@ -962,8 +962,7 @@ int llama_context::decode(const llama_batch & batch_inp) { +@@ -974,8 +974,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const int64_t n_vocab = vocab.n_tokens(); const int64_t n_embd = hparams.n_embd; diff --git a/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch b/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch index 84aefd1d..1cb10d93 100644 --- a/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch +++ b/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch @@ -10,15 +10,16 @@ unused then it can be reset to free these data structures. ggml/include/ggml-backend.h | 1 + ggml/src/ggml-backend-impl.h | 4 ++++ ggml/src/ggml-backend.cpp | 8 ++++++++ - ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++++-- + ggml/src/ggml-cuda/ggml-cuda.cu | 16 +++++++++++++++- ggml/src/ggml-cuda/vendors/hip.h | 1 + - 5 files changed, 29 insertions(+), 2 deletions(-) + src/llama.cpp | 4 +++- + 6 files changed, 32 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index b602a7c78..fda5ceb24 100644 +index 1ff53ed03..ba181d09d 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -167,6 +167,7 @@ extern "C" { +@@ -178,6 +178,7 @@ extern "C" { GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props); GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device); GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params); @@ -27,10 +28,10 @@ index b602a7c78..fda5ceb24 100644 GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device); GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size); diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h -index 81749a5a3..6f10c353b 100644 +index 3c3f22fc0..43c91d9f2 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h -@@ -178,6 +178,10 @@ extern "C" { +@@ -195,6 +195,10 @@ extern "C" { ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev); void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event); void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event); @@ -42,10 +43,10 @@ index 81749a5a3..6f10c353b 100644 struct ggml_backend_device { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 05a842ed5..6556943b0 100644 +index 6ef5eeafa..0b757af59 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par +@@ -526,6 +526,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par return device->iface.init_backend(device, params); } @@ -58,13 +59,13 @@ index 05a842ed5..6556943b0 100644 +} + ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->iface.get_buffer_type(device); - } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index c7f9dc3a5..e43fde523 100644 +index 811462c79..87c6c34a4 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -103,6 +103,11 @@ int ggml_cuda_get_device() { +@@ -107,6 +107,11 @@ int ggml_cuda_get_device() { return id; } @@ -76,10 +77,10 @@ index c7f9dc3a5..e43fde523 100644 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); cudaError_t err; -@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back - props->description = ggml_backend_cuda_device_get_description(dev); +@@ -3515,7 +3520,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back props->id = ggml_backend_cuda_device_get_id(dev); props->type = ggml_backend_cuda_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); - ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); + + // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device). @@ -88,7 +89,7 @@ index c7f9dc3a5..e43fde523 100644 bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY -@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g +@@ -3948,6 +3956,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); } @@ -100,7 +101,7 @@ index c7f9dc3a5..e43fde523 100644 static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .get_name = */ ggml_backend_cuda_device_get_name, /* .get_description = */ ggml_backend_cuda_device_get_description, -@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { +@@ -3964,6 +3977,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .event_new = */ ggml_backend_cuda_device_event_new, /* .event_free = */ ggml_backend_cuda_device_event_free, /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, @@ -108,19 +109,11 @@ index c7f9dc3a5..e43fde523 100644 }; // backend reg -@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { - dev_ctx->device = i; - dev_ctx->name = GGML_CUDA_NAME + std::to_string(i); - -- ggml_cuda_set_device(i); - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); - dev_ctx->description = prop.name; diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h -index c31f31923..cf22e60d2 100644 +index 890c10364..1f06be80e 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h -@@ -40,6 +40,7 @@ +@@ -45,6 +45,7 @@ #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceProp hipDeviceProp_t @@ -128,3 +121,21 @@ index c31f31923..cf22e60d2 100644 #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled +diff --git a/src/llama.cpp b/src/llama.cpp +index fe5a7a835..d821a96a0 100644 +--- a/src/llama.cpp ++++ b/src/llama.cpp +@@ -267,10 +267,12 @@ static struct llama_model * llama_model_load_from_file_impl( + for (auto * dev : model->devices) { + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); ++ size_t memory_free, memory_total; ++ ggml_backend_dev_memory(dev, &memory_free, &memory_total); + LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__, + ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + props.device_id ? props.device_id : "unknown id", +- props.memory_free/1024/1024); ++ memory_free/1024/1024); + } + + const int status = llama_model_load(path_model, splits, *model, params); diff --git a/llama/patches/0026-GPU-discovery-enhancements.patch b/llama/patches/0026-GPU-discovery-enhancements.patch new file mode 100644 index 00000000..b505f900 --- /dev/null +++ b/llama/patches/0026-GPU-discovery-enhancements.patch @@ -0,0 +1,919 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen +Date: Tue, 26 Aug 2025 12:48:29 -0700 +Subject: [PATCH] GPU discovery enhancements + +Expose more information about the devices through backend props, and leverage +management libraries for more accurate VRAM usage reporting if available. +--- + ggml/include/ggml-backend.h | 9 + + ggml/src/CMakeLists.txt | 2 + + ggml/src/ggml-cuda/ggml-cuda.cu | 72 +++++ + ggml/src/ggml-cuda/vendors/hip.h | 3 + + ggml/src/ggml-impl.h | 8 + + ggml/src/ggml-metal/ggml-metal.cpp | 2 + + ggml/src/mem_hip.cpp | 449 +++++++++++++++++++++++++++++ + ggml/src/mem_nvml.cpp | 209 ++++++++++++++ + 8 files changed, 754 insertions(+) + create mode 100644 ggml/src/mem_hip.cpp + create mode 100644 ggml/src/mem_nvml.cpp + +diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h +index ba181d09..09ff75f9 100644 +--- a/ggml/include/ggml-backend.h ++++ b/ggml/include/ggml-backend.h +@@ -169,6 +169,17 @@ extern "C" { + const char * device_id; + // device capabilities + struct ggml_backend_dev_caps caps; ++ int driver_major; ++ int driver_minor; ++ int compute_major; ++ int compute_minor; ++ int integrated; ++ int pci_bus_id; ++ int pci_device_id; ++ int pci_domain_id; ++ const char *library; ++ // number with which the devices are accessed (Vulkan) ++ const char *numeric_id; + }; + + GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); +diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt +index 0609c650..aefe43bd 100644 +--- a/ggml/src/CMakeLists.txt ++++ b/ggml/src/CMakeLists.txt +@@ -209,6 +209,8 @@ add_library(ggml-base + ggml-threading.h + ggml-quants.c + ggml-quants.h ++ mem_hip.cpp ++ mem_nvml.cpp + gguf.cpp) + + target_include_directories(ggml-base PRIVATE .) +diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu +index 87c6c34a..6a278b5e 100644 +--- a/ggml/src/ggml-cuda/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda/ggml-cuda.cu +@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() { + for (int id = 0; id < info.device_count; ++id) { + int device_vmm = 0; + ++#if defined(GGML_USE_HIP) ++ if (std::getenv("GGML_CUDA_INIT") != NULL) { ++ GGML_LOG_INFO("%s: initializing rocBLAS on device %d\n", __func__, id); ++ CUDA_CHECK(cudaSetDevice(id)); ++ // rocblas_initialize will SIGABRT if the GPU isn't supported ++ rocblas_initialize(); ++ GGML_LOG_INFO("%s: rocBLAS initialized on device %d\n", __func__, id); ++ } ++#endif ++ + #if defined(GGML_USE_VMM) + CUdevice device; + CU_CHECK(cuDeviceGet(&device, id)); +@@ -314,6 +324,11 @@ static ggml_cuda_device_info ggml_cuda_init() { + #else + info.devices[id].smpbo = prop.sharedMemPerBlockOptin; + info.devices[id].cc = 100*prop.major + 10*prop.minor; ++#ifdef __CUDA_ARCH_LIST__ ++ if (std::getenv("GGML_CUDA_INIT") != NULL) { ++ GGML_ASSERT(ggml_cuda_has_arch(info.devices[id].cc) && "ggml was not compiled with support for this arch"); ++ } ++#endif // defined(__CUDA_ARCH_LIST__) + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", + ggml_cuda_parse_uuid(prop, id).c_str()); +@@ -3484,6 +3499,14 @@ struct ggml_backend_cuda_device_context { + std::string description; + std::string pci_bus_id; + std::string id; ++ int major; ++ int minor; ++ int driver_major; ++ int driver_minor; ++ int integrated; ++ int pciBusID; ++ int pciDeviceID; ++ int pciDomainID; + }; + + static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { +@@ -3504,6 +3527,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { + static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + ggml_cuda_set_device(ctx->device); ++ ++#if defined(GGML_USE_HIP) ++ if (ggml_hip_mgmt_init() == 0) { ++ int status = ggml_hip_get_device_memory(ctx->pciBusID, ctx->pciDeviceID, free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_hip_mgmt_release(); ++ return; ++ } ++ ggml_hip_mgmt_release(); ++ } ++#else ++ if (ggml_nvml_init() == 0) { ++ int status = ggml_nvml_get_device_memory(ctx->id.c_str(), free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_nvml_release(); ++ return; ++ } ++ ggml_nvml_release(); ++ } ++#endif + CUDA_CHECK(cudaMemGetInfo(free, total)); + } + +@@ -3512,6 +3557,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend + return GGML_BACKEND_DEVICE_TYPE_GPU; + } + ++#define GGML_HIP_NAME "HIP" + static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + +@@ -3525,6 +3571,22 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back + // If you need the memory data, call ggml_backend_dev_memory() explicitly. + props->memory_total = props->memory_free = 0; + ++#if defined(GGML_USE_HIP) ++ int cc = ggml_cuda_info().devices[ctx->device].cc - GGML_CUDA_CC_OFFSET_AMD; ++ props->compute_major = cc / 0x100; ++ props->compute_minor = cc - (props->compute_major * 0x100); ++#else ++ props->compute_major = ctx->major; ++ props->compute_minor = ctx->minor; ++#endif ++ props->driver_major = ctx->driver_major; ++ props->driver_minor = ctx->driver_minor; ++ props->integrated = ctx->integrated; ++ props->pci_bus_id = ctx->pciBusID; ++ props->pci_device_id = ctx->pciDeviceID; ++ props->pci_domain_id = ctx->pciDomainID; ++ props->library = GGML_CUDA_NAME; ++ + bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; + #ifdef GGML_CUDA_NO_PEER_COPY + bool events = false; +@@ -4087,6 +4149,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { + std::lock_guard lock(mutex); + if (!initialized) { + ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; ++ int driverVersion = 0; ++ CUDA_CHECK(cudaDriverGetVersion(&driverVersion)); + + for (int i = 0; i < ggml_cuda_info().device_count; i++) { + ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; +@@ -4102,6 +4166,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { + snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); + dev_ctx->pci_bus_id = pci_bus_id; + ++ dev_ctx->major = prop.major; ++ dev_ctx->minor = prop.minor; ++ dev_ctx->driver_major = driverVersion / 1000; ++ dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10; ++ dev_ctx->integrated = prop.integrated; ++ dev_ctx->pciBusID = prop.pciBusID; ++ dev_ctx->pciDeviceID = prop.pciDeviceID; ++ dev_ctx->pciDomainID = prop.pciDomainID; + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_cuda_device_interface, + /* .reg = */ ®, +diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h +index 1f06be80..2f9ef2dc 100644 +--- a/ggml/src/ggml-cuda/vendors/hip.h ++++ b/ggml/src/ggml-cuda/vendors/hip.h +@@ -5,6 +5,8 @@ + #include + #include + #include ++// for rocblas_initialize() ++#include "rocblas/rocblas.h" + + #if defined(GGML_HIP_ROCWMMA_FATTN) + #include +@@ -47,6 +49,7 @@ + #define cudaDeviceProp hipDeviceProp_t + #define cudaDeviceReset hipDeviceReset + #define cudaDeviceSynchronize hipDeviceSynchronize ++#define cudaDriverGetVersion hipDriverGetVersion + #define cudaError_t hipError_t + #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled + #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled +diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h +index d0fb3bcc..80597b6e 100644 +--- a/ggml/src/ggml-impl.h ++++ b/ggml/src/ggml-impl.h +@@ -638,6 +638,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx + return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops); + } + ++// Management libraries for fetching more accurate free VRAM data ++GGML_API int ggml_nvml_init(); ++GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total); ++GGML_API void ggml_nvml_release(); ++GGML_API int ggml_hip_mgmt_init(); ++GGML_API int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total); ++GGML_API void ggml_hip_mgmt_release(); ++ + #ifdef __cplusplus + } + #endif +diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp +index f2ff9f32..f356e4a0 100644 +--- a/ggml/src/ggml-metal/ggml-metal.cpp ++++ b/ggml/src/ggml-metal/ggml-metal.cpp +@@ -535,6 +535,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen + GGML_UNUSED(dev); + } + ++#define GGML_METAL_NAME "Metal" + static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_metal_device_get_name(dev); + props->description = ggml_backend_metal_device_get_description(dev); +@@ -543,6 +544,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac + + ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); + ++ props->library = GGML_METAL_NAME; + props->caps = { + /* .async = */ true, + /* .host_buffer = */ false, +diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp +new file mode 100644 +index 00000000..8ef19b8c +--- /dev/null ++++ b/ggml/src/mem_hip.cpp +@@ -0,0 +1,449 @@ ++#include "ggml.h" ++ ++#ifdef _WIN32 ++// AMD Device Library eXtra (ADLX) ++// ++// https://github.com/GPUOpen-LibrariesAndSDKs/ADLX ++// ++// This Windows-only library provides accurate VRAM reporting for AMD GPUs. ++// The runtime DLL is installed with every AMD Driver on Windows, however ++// the SDK isn't a part of the HIP SDK packaging. As such, we avoid including ++// the headers from the SDK to simplify building from source. ++// ++// ADLX relies heavily on function pointer tables. ++// Only the minimal set of types are defined below to facilitate ++// finding the target AMD GPU(s) and querying their current VRAM usage ++// Unused function parameters are commented out to avoid unnecessary type ++// definitions. ++ ++#include "ggml-impl.h" ++#include ++#include ++ ++#define WIN32_LEAN_AND_MEAN ++#ifndef NOMINMAX ++# define NOMINMAX ++#endif ++#include ++ ++namespace fs = std::filesystem; ++ ++#include ++#include ++ ++// Begin minimal ADLX definitions - derived from tag v1.0 (Dec 2022) ++typedef uint64_t adlx_uint64; ++typedef uint32_t adlx_uint32; ++typedef int32_t adlx_int32; ++typedef adlx_int32 adlx_int; ++typedef adlx_uint32 adlx_uint; ++typedef long adlx_long; ++typedef uint8_t adlx_uint8; ++typedef enum ++{ ++ ADLX_OK = 0, /**< @ENG_START_DOX This result indicates success. @ENG_END_DOX */ ++ ADLX_ALREADY_ENABLED, /**< @ENG_START_DOX This result indicates that the asked action is already enabled. @ENG_END_DOX */ ++ ADLX_ALREADY_INITIALIZED, /**< @ENG_START_DOX This result indicates that ADLX has a unspecified type of initialization. @ENG_END_DOX */ ++ ADLX_FAIL, /**< @ENG_START_DOX This result indicates an unspecified failure. @ENG_END_DOX */ ++ ADLX_INVALID_ARGS, /**< @ENG_START_DOX This result indicates that the arguments are invalid. @ENG_END_DOX */ ++ ADLX_BAD_VER, /**< @ENG_START_DOX This result indicates that the asked version is incompatible with the current version. @ENG_END_DOX */ ++ ADLX_UNKNOWN_INTERFACE, /**< @ENG_START_DOX This result indicates that an unknown interface was asked. @ENG_END_DOX */ ++ ADLX_TERMINATED, /**< @ENG_START_DOX This result indicates that the calls were made in an interface after ADLX was terminated. @ENG_END_DOX */ ++ ADLX_ADL_INIT_ERROR, /**< @ENG_START_DOX This result indicates that the ADL initialization failed. @ENG_END_DOX */ ++ ADLX_NOT_FOUND, /**< @ENG_START_DOX This result indicates that the item is not found. @ENG_END_DOX */ ++ ADLX_INVALID_OBJECT, /**< @ENG_START_DOX This result indicates that the method was called into an invalid object. @ENG_END_DOX */ ++ ADLX_ORPHAN_OBJECTS, /**< @ENG_START_DOX This result indicates that ADLX was terminated with outstanding ADLX objects. Any interface obtained from ADLX points to invalid memory and calls in their methods will result in unexpected behavior. @ENG_END_DOX */ ++ ADLX_NOT_SUPPORTED, /**< @ENG_START_DOX This result indicates that the asked feature is not supported. @ENG_END_DOX */ ++ ADLX_PENDING_OPERATION, /**< @ENG_START_DOX This result indicates a failure due to an operation currently in progress. @ENG_END_DOX */ ++ ADLX_GPU_INACTIVE /**< @ENG_START_DOX This result indicates that the GPU is inactive. @ENG_END_DOX */ ++} ADLX_RESULT; ++#define ADLX_SUCCEEDED(x) (ADLX_OK == (x) || ADLX_ALREADY_ENABLED == (x) || ADLX_ALREADY_INITIALIZED == (x)) ++#define ADLX_FAILED(x) (ADLX_OK != (x) && ADLX_ALREADY_ENABLED != (x) && ADLX_ALREADY_INITIALIZED != (x)) ++#define ADLX_VER_MAJOR 1 ++#define ADLX_VER_MINOR 0 ++#define ADLX_VER_RELEASE 5 ++#define ADLX_VER_BUILD_NUM 30 ++#define ADLX_MAKE_FULL_VER(VERSION_MAJOR, VERSION_MINOR, VERSION_RELEASE, VERSION_BUILD_NUM) ( ((adlx_uint64)(VERSION_MAJOR) << 48ull) | ((adlx_uint64)(VERSION_MINOR) << 32ull) | ((adlx_uint64)(VERSION_RELEASE) << 16ull) | (adlx_uint64)(VERSION_BUILD_NUM)) ++#define ADLX_FULL_VERSION ADLX_MAKE_FULL_VER(ADLX_VER_MAJOR, ADLX_VER_MINOR, ADLX_VER_RELEASE, ADLX_VER_BUILD_NUM) ++#define ADLX_CORE_LINK __declspec(dllexport) ++#define ADLX_STD_CALL __stdcall ++#define ADLX_CDECL_CALL __cdecl ++#define ADLX_FAST_CALL __fastcall ++#define ADLX_INLINE __inline ++#define ADLX_FORCEINLINE __forceinline ++#define ADLX_NO_VTABLE __declspec(novtable) ++ ++#if defined(__cplusplus) ++typedef bool adlx_bool; ++#else ++typedef adlx_uint8 adlx_bool; ++#define true 1 ++#define false 0 ++#endif ++ ++typedef struct IADLXSystem IADLXSystem; ++typedef struct IADLXGPUList IADLXGPUList; ++typedef struct IADLXGPU IADLXGPU; ++typedef struct IADLXInterface IADLXInterface; ++typedef struct IADLXPerformanceMonitoringServices IADLXPerformanceMonitoringServices; ++typedef struct IADLXGPUMetrics IADLXGPUMetrics; ++typedef struct IADLXGPUMetricsSupport IADLXGPUMetricsSupport; ++ ++typedef struct IADLXSystemVtbl ++{ ++ // IADLXSystem interface ++ ADLX_RESULT (ADLX_STD_CALL *GetHybridGraphicsType)(/* IADLXSystem* pThis, ADLX_HG_TYPE* hgType */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUs)(IADLXSystem* pThis, IADLXGPUList** ppGPUs); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXSystem* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ADLX_RESULT (ADLX_STD_CALL *GetDisplaysServices)(/* IADLXSystem* pThis, IADLXDisplayServices** ppDispServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetDesktopsServices)(/* IADLXSystem* pThis, IADLXDesktopServices** ppDeskServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUsChangedHandling)(/* IADLXSystem* pThis, IADLXGPUsChangedHandling** ppGPUsChangedHandling */); ++ ADLX_RESULT (ADLX_STD_CALL *EnableLog)(/* IADLXSystem* pThis, ADLX_LOG_DESTINATION mode, ADLX_LOG_SEVERITY severity, IADLXLog* pLogger, const wchar_t* fileName */); ++ ADLX_RESULT (ADLX_STD_CALL *Get3DSettingsServices)(/* IADLXSystem* pThis, IADLX3DSettingsServices** pp3DSettingsServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUTuningServices)(/* IADLXSystem* pThis, IADLXGPUTuningServices** ppGPUTuningServices */); ++ ADLX_RESULT (ADLX_STD_CALL *GetPerformanceMonitoringServices)(IADLXSystem* pThis, IADLXPerformanceMonitoringServices** ppPerformanceMonitoringServices); // Used ++ ADLX_RESULT (ADLX_STD_CALL *TotalSystemRAM)(/* IADLXSystem* pThis, adlx_uint* ramMB */); ++ ADLX_RESULT (ADLX_STD_CALL *GetI2C)(/* IADLXSystem* pThis, IADLXGPU* pGPU, IADLXI2C** ppI2C */); ++} IADLXSystemVtbl; ++struct IADLXSystem { const IADLXSystemVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPU* pThis */); ++ adlx_long (ADLX_STD_CALL *Release)(IADLXGPU* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPU* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXGPU ++ ADLX_RESULT (ADLX_STD_CALL *VendorId)(/* IADLXGPU* pThis, const char** vendorId */); ++ ADLX_RESULT (ADLX_STD_CALL *ASICFamilyType)(/* IADLXGPU* pThis, ADLX_ASIC_FAMILY_TYPE* asicFamilyType */); ++ ADLX_RESULT (ADLX_STD_CALL *Type)(/* IADLXGPU* pThis, ADLX_GPU_TYPE* gpuType */); ++ ADLX_RESULT (ADLX_STD_CALL *IsExternal)(/* IADLXGPU* pThis, adlx_bool* isExternal */); ++ ADLX_RESULT (ADLX_STD_CALL *Name)(/* IADLXGPU* pThis, const char** gpuName */); ++ ADLX_RESULT (ADLX_STD_CALL *DriverPath)(/* IADLXGPU* pThis, const char** driverPath */); ++ ADLX_RESULT (ADLX_STD_CALL *PNPString)(/* IADLXGPU* pThis, const char** pnpString */); ++ ADLX_RESULT (ADLX_STD_CALL *HasDesktops)(/* IADLXGPU* pThis, adlx_bool* hasDesktops */); ++ ADLX_RESULT (ADLX_STD_CALL *TotalVRAM)(IADLXGPU* pThis, adlx_uint* vramMB); // Used ++ ADLX_RESULT (ADLX_STD_CALL *VRAMType)(/* IADLXGPU* pThis, const char** type */); ++ ADLX_RESULT (ADLX_STD_CALL *BIOSInfo)(/* IADLXGPU* pThis, const char** partNumber, const char** version, const char** date */); ++ ADLX_RESULT (ADLX_STD_CALL *DeviceId)(/* IADLXGPU* pThis, const char** deviceId */); ++ ADLX_RESULT (ADLX_STD_CALL *RevisionId)(/* IADLXGPU* pThis, const char** revisionId */); ++ ADLX_RESULT (ADLX_STD_CALL *SubSystemId)(/* IADLXGPU* pThis, const char** subSystemId */); ++ ADLX_RESULT (ADLX_STD_CALL *SubSystemVendorId)(/* IADLXGPU* pThis, const char** subSystemVendorId */); ++ ADLX_RESULT (ADLX_STD_CALL *UniqueId)(IADLXGPU* pThis, adlx_int* uniqueId); // Used ++} IADLXGPUVtbl; ++struct IADLXGPU { const IADLXGPUVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUListVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPUList* pThis */); ++ adlx_long (ADLX_STD_CALL *Release)(IADLXGPUList* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPUList* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXList ++ adlx_uint (ADLX_STD_CALL *Size)(/* IADLXGPUList* pThis */); ++ adlx_uint8 (ADLX_STD_CALL *Empty)(/* IADLXGPUList* pThis */); ++ adlx_uint (ADLX_STD_CALL *Begin)(IADLXGPUList* pThis); // Used ++ adlx_uint (ADLX_STD_CALL *End)(IADLXGPUList* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *At)(/* IADLXGPUList* pThis, const adlx_uint location, IADLXInterface** ppItem */); ++ ADLX_RESULT (ADLX_STD_CALL *Clear)(/* IADLXGPUList* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *Remove_Back)(/* IADLXGPUList* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *Add_Back)(/* IADLXGPUList* pThis, IADLXInterface* pItem */); ++ ++ //IADLXGPUList ++ ADLX_RESULT (ADLX_STD_CALL *At_GPUList)(IADLXGPUList* pThis, const adlx_uint location, IADLXGPU** ppItem); // Used ++ ADLX_RESULT (ADLX_STD_CALL *Add_Back_GPUList)(/* IADLXGPUList* pThis, IADLXGPU* pItem */); ++ ++} IADLXGPUListVtbl; ++struct IADLXGPUList { const IADLXGPUListVtbl *pVtbl; }; ++ ++typedef struct IADLXPerformanceMonitoringServicesVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXPerformanceMonitoringServices* pThis */); ++ adlx_long (ADLX_STD_CALL *Release)(IADLXPerformanceMonitoringServices* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXPerformanceMonitoringServices* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXPerformanceMonitoringServices ++ ADLX_RESULT (ADLX_STD_CALL *GetSamplingIntervalRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); ++ ADLX_RESULT (ADLX_STD_CALL *SetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int intervalMs */); ++ ADLX_RESULT (ADLX_STD_CALL *GetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* intervalMs */); ++ ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySizeRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); ++ ADLX_RESULT (ADLX_STD_CALL *SetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int sizeSec */); ++ ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); ++ ADLX_RESULT (ADLX_STD_CALL *ClearPerformanceMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); ++ ADLX_RESULT (ADLX_STD_CALL *StartPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *StopPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); ++ ADLX_RESULT (ADLX_STD_CALL *GetAllMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXAllMetricsList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetGPUMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, adlx_int startMs, adlx_int stopMs, IADLXGPUMetricsList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetSystemMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXSystemMetricsList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetFPSHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXFPSList** ppMetricsList */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentAllMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXAllMetrics** ppMetrics */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetrics** ppMetrics); // Used ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetrics** ppMetrics */); ++ ADLX_RESULT (ADLX_STD_CALL *GetCurrentFPS)(/* IADLXPerformanceMonitoringServices* pThis, IADLXFPS** ppMetrics */); ++ ADLX_RESULT (ADLX_STD_CALL *GetSupportedGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetricsSupport** ppMetricsSupported); // Used ++ ADLX_RESULT (ADLX_STD_CALL *GetSupportedSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetricsSupport** ppMetricsSupported */); ++}IADLXPerformanceMonitoringServicesVtbl; ++struct IADLXPerformanceMonitoringServices { const IADLXPerformanceMonitoringServicesVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUMetricsSupportVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetricsSupport* pThis */); ++ adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetricsSupport* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetricsSupport* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXGPUMetricsSupport ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUUsage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAMClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUHotspotTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTotalBoardPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUFanSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAM)(IADLXGPUMetricsSupport* pThis, adlx_bool* supported); // Used ++ ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVoltage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); ++ ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUUsageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUHotspotTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUFanSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUVoltageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++ ADLX_RESULT (ADLX_STD_CALL* GetGPUTotalBoardPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); ++} IADLXGPUMetricsSupportVtbl; ++struct IADLXGPUMetricsSupport { const IADLXGPUMetricsSupportVtbl *pVtbl; }; ++ ++typedef struct IADLXGPUMetricsVtbl ++{ ++ //IADLXInterface ++ adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetrics* pThis */); ++ adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetrics* pThis); // Used ++ ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetrics* pThis, const wchar_t* interfaceId, void** ppInterface */); ++ ++ //IADLXGPUMetrics ++ ADLX_RESULT (ADLX_STD_CALL* TimeStamp)(/* IADLXGPUMetrics* pThis, adlx_int64* ms */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUUsage)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUVRAMClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUHotspotTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUTotalBoardPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUFanSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++ ADLX_RESULT (ADLX_STD_CALL* GPUVRAM)(IADLXGPUMetrics* pThis, adlx_int* data); // Used ++ ADLX_RESULT (ADLX_STD_CALL* GPUVoltage)(/* IADLXGPUMetrics* pThis, adlx_int* data */); ++} IADLXGPUMetricsVtbl; ++struct IADLXGPUMetrics { const IADLXGPUMetricsVtbl *pVtbl; }; ++ ++struct { ++ void *handle; ++ ADLX_RESULT (*ADLXInitialize)(adlx_uint64 version, IADLXSystem** ppSystem); ++ ADLX_RESULT (*ADLXInitializeWithIncompatibleDriver)(adlx_uint64 version, IADLXSystem** ppSystem); ++ ADLX_RESULT (*ADLXQueryVersion)(const char** version); ++ ADLX_RESULT (*ADLXTerminate)(); ++ IADLXSystem *sys; ++} adlx { NULL, NULL, NULL, NULL, NULL, NULL }; ++static std::mutex ggml_adlx_lock; ++ ++extern "C" { ++ ++int ggml_hip_mgmt_init() { ++ std::lock_guard lock(ggml_adlx_lock); ++ if (adlx.handle != NULL) { ++ // Already initialized ++ return 0; ++ } ++ DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); ++ SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); ++ fs::path libPath = fs::path("\\Windows") / fs::path("System32") / fs::path("amdadlx64.dll"); ++ ++ adlx.handle = (void*)LoadLibraryW(libPath.wstring().c_str()); ++ if (adlx.handle == NULL) { ++ return ADLX_NOT_FOUND; ++ } ++ ++ adlx.ADLXInitialize = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitialize"); ++ adlx.ADLXInitializeWithIncompatibleDriver = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitializeWithIncompatibleDriver"); ++ adlx.ADLXTerminate = (ADLX_RESULT (*)()) GetProcAddress((HMODULE)(adlx.handle), "ADLXTerminate"); ++ adlx.ADLXQueryVersion = (ADLX_RESULT (*)(const char **version)) GetProcAddress((HMODULE)(adlx.handle), "ADLXQueryVersion"); ++ if (adlx.ADLXInitialize == NULL || adlx.ADLXInitializeWithIncompatibleDriver == NULL || adlx.ADLXTerminate == NULL) { ++ GGML_LOG_INFO("%s unable to locate required symbols in amdadlx64.dll, falling back to hip free memory reporting", __func__); ++ FreeLibrary((HMODULE)(adlx.handle)); ++ adlx.handle = NULL; ++ return ADLX_NOT_FOUND; ++ } ++ ++ SetErrorMode(old_mode); ++ ++ // Aid in troubleshooting... ++ if (adlx.ADLXQueryVersion != NULL) { ++ const char *version = NULL; ++ ADLX_RESULT status = adlx.ADLXQueryVersion(&version); ++ if (ADLX_SUCCEEDED(status)) { ++ GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version); ++ } ++ } ++ ++ ADLX_RESULT status = adlx.ADLXInitialize(ADLX_FULL_VERSION, &adlx.sys); ++ if (ADLX_FAILED(status)) { ++ // GGML_LOG_DEBUG("%s failed to initialize ADLX error=%d - attempting with incompatible driver...\n", __func__, status); ++ // Try with the incompatible driver ++ status = adlx.ADLXInitializeWithIncompatibleDriver(ADLX_FULL_VERSION, &adlx.sys); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s failed to initialize ADLX error=%d\n", __func__, status); ++ FreeLibrary((HMODULE)(adlx.handle)); ++ adlx.handle = NULL; ++ adlx.sys = NULL; ++ return status; ++ } ++ // GGML_LOG_DEBUG("%s initialized ADLX with incpomatible driver\n", __func__); ++ } ++ return ADLX_OK; ++} ++ ++void ggml_hip_mgmt_release() { ++ std::lock_guard lock(ggml_adlx_lock); ++ if (adlx.handle == NULL) { ++ // Already free ++ return; ++ } ++ ADLX_RESULT status = adlx.ADLXTerminate(); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s failed to terminate Adlx %d\n", __func__, status); ++ // Unload anyway... ++ } ++ FreeLibrary((HMODULE)(adlx.handle)); ++ adlx.handle = NULL; ++} ++ ++#define adlx_gdm_cleanup \ ++ if (gpuMetricsSupport != NULL) gpuMetricsSupport->pVtbl->Release(gpuMetricsSupport); \ ++ if (gpuMetrics != NULL) gpuMetrics->pVtbl->Release(gpuMetrics); \ ++ if (perfMonitoringServices != NULL) perfMonitoringServices->pVtbl->Release(perfMonitoringServices); \ ++ if (gpus != NULL) gpus->pVtbl->Release(gpus); \ ++ if (gpu != NULL) gpu->pVtbl->Release(gpu) ++ ++int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { ++ std::lock_guard lock(ggml_adlx_lock); ++ if (adlx.handle == NULL) { ++ GGML_LOG_INFO("%s ADLX was not initialized\n", __func__); ++ return ADLX_ADL_INIT_ERROR; ++ } ++ IADLXGPUMetricsSupport *gpuMetricsSupport = NULL; ++ IADLXPerformanceMonitoringServices *perfMonitoringServices = NULL; ++ IADLXGPUList* gpus = NULL; ++ IADLXGPU* gpu = NULL; ++ IADLXGPUMetrics *gpuMetrics = NULL; ++ ADLX_RESULT status; ++ // The "UniqueID" exposed in ADLX is the PCI Bus and Device IDs ++ adlx_int target = (pci_bus_id << 8) | (pci_device_id & 0xff); ++ ++ status = adlx.sys->pVtbl->GetPerformanceMonitoringServices(adlx.sys, &perfMonitoringServices); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetPerformanceMonitoringServices failed %d\n", __func__, status); ++ return status; ++ } ++ ++ status = adlx.sys->pVtbl->GetGPUs(adlx.sys, &gpus); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetGPUs failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ // Get GPU list ++ for (adlx_uint crt = gpus->pVtbl->Begin(gpus); crt != gpus->pVtbl->End(gpus); ++crt) ++ { ++ status = gpus->pVtbl->At_GPUList(gpus, crt, &gpu); ++ if (ADLX_FAILED(status)) ++ { ++ GGML_LOG_INFO("%s %d] At_GPUList failed %d\n", __func__, crt, status); ++ continue; ++ } ++ adlx_int id; ++ status = gpu->pVtbl->UniqueId(gpu, &id); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s %d] UniqueId lookup failed %d\n", __func__, crt, status); ++ gpu->pVtbl->Release(gpu); ++ gpu = NULL; ++ continue; ++ } ++ if (id != target) { ++ GGML_LOG_DEBUG("%s %d] GPU UniqueId: %x does not match target %02x %02x\n", __func__, crt, id, pci_bus_id, pci_device_id); ++ gpu->pVtbl->Release(gpu); ++ gpu = NULL; ++ continue; ++ } ++ // Any failures at this point should cause a fall-back to other APIs ++ status = perfMonitoringServices->pVtbl->GetSupportedGPUMetrics(perfMonitoringServices, gpu, &gpuMetricsSupport); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetSupportedGPUMetrics failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ status = perfMonitoringServices->pVtbl->GetCurrentGPUMetrics(perfMonitoringServices, gpu, &gpuMetrics); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GetCurrentGPUMetrics failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ adlx_bool supported = false; ++ status = gpuMetricsSupport->pVtbl->IsSupportedGPUVRAM(gpuMetricsSupport, &supported); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s IsSupportedGPUVRAM failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ adlx_uint totalVRAM = 0; ++ status = gpu->pVtbl->TotalVRAM(gpu, &totalVRAM); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s TotalVRAM failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ ++ adlx_int usedVRAM = 0; ++ status = gpuMetrics->pVtbl->GPUVRAM(gpuMetrics, &usedVRAM); ++ if (ADLX_FAILED(status)) { ++ GGML_LOG_INFO("%s GPUVRAM failed %d\n", __func__, status); ++ adlx_gdm_cleanup; ++ return status; ++ } ++ *total = size_t(totalVRAM) * 1024 * 1024; ++ *free = size_t(totalVRAM-usedVRAM) * 1024 * 1024; ++ ++ adlx_gdm_cleanup; ++ return ADLX_OK; ++ } ++ adlx_gdm_cleanup; ++ return ADLX_NOT_FOUND; ++} ++ ++} // extern "C" ++ ++#else // #ifdef _WIN32 ++ ++extern "C" { ++ ++// TODO Linux implementation of accurate VRAM reporting ++int ggml_hip_mgmt_init() { ++ return -1; ++} ++void ggml_hip_mgmt_release() {} ++int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { ++ return -1; ++} ++ ++} // extern "C" ++ ++#endif // #ifdef _WIN32 +\ No newline at end of file +diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp +new file mode 100644 +index 00000000..c9073cef +--- /dev/null ++++ b/ggml/src/mem_nvml.cpp +@@ -0,0 +1,209 @@ ++// NVIDIA Management Library (NVML) ++// ++// https://developer.nvidia.com/management-library-nvml ++// ++// This library provides accurate VRAM reporting for NVIDIA GPUs, particularly ++// on Windows, where the cuda library provides inaccurate VRAM usage metrics. The ++// runtime DLL is installed with every driver on Windows, and most Linux ++// systems, and the headers are included in the standard CUDA SDK install. As ++// such, we can include the header here to simplify the code. ++ ++ ++#include "ggml-impl.h" ++#include ++#include ++#include ++ ++#ifdef _WIN32 ++# define WIN32_LEAN_AND_MEAN ++# ifndef NOMINMAX ++# define NOMINMAX ++# endif ++# include ++#else ++# include ++# include ++#endif ++ ++namespace fs = std::filesystem; ++ ++// Minimal definitions to avoid including the nvml.h header ++typedef enum nvmlReturn_enum ++{ ++ // cppcheck-suppress * ++ NVML_SUCCESS = 0, //!< The operation was successful ++ NVML_ERROR_UNINITIALIZED = 1, //!< NVML was not first initialized with nvmlInit() ++ NVML_ERROR_INVALID_ARGUMENT = 2, //!< A supplied argument is invalid ++ NVML_ERROR_NOT_SUPPORTED = 3, //!< The requested operation is not available on target device ++ NVML_ERROR_NO_PERMISSION = 4, //!< The current user does not have permission for operation ++ NVML_ERROR_ALREADY_INITIALIZED = 5, //!< Deprecated: Multiple initializations are now allowed through ref counting ++ NVML_ERROR_NOT_FOUND = 6, //!< A query to find an object was unsuccessful ++ NVML_ERROR_INSUFFICIENT_SIZE = 7, //!< An input argument is not large enough ++ NVML_ERROR_INSUFFICIENT_POWER = 8, //!< A device's external power cables are not properly attached ++ NVML_ERROR_DRIVER_NOT_LOADED = 9, //!< NVIDIA driver is not loaded ++ NVML_ERROR_TIMEOUT = 10, //!< User provided timeout passed ++ NVML_ERROR_IRQ_ISSUE = 11, //!< NVIDIA Kernel detected an interrupt issue with a GPU ++ NVML_ERROR_LIBRARY_NOT_FOUND = 12, //!< NVML Shared Library couldn't be found or loaded ++ NVML_ERROR_FUNCTION_NOT_FOUND = 13, //!< Local version of NVML doesn't implement this function ++ NVML_ERROR_CORRUPTED_INFOROM = 14, //!< infoROM is corrupted ++ NVML_ERROR_GPU_IS_LOST = 15, //!< The GPU has fallen off the bus or has otherwise become inaccessible ++ NVML_ERROR_RESET_REQUIRED = 16, //!< The GPU requires a reset before it can be used again ++ NVML_ERROR_OPERATING_SYSTEM = 17, //!< The GPU control device has been blocked by the operating system/cgroups ++ NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18, //!< RM detects a driver/library version mismatch ++ NVML_ERROR_IN_USE = 19, //!< An operation cannot be performed because the GPU is currently in use ++ NVML_ERROR_MEMORY = 20, //!< Insufficient memory ++ NVML_ERROR_NO_DATA = 21, //!< No data ++ NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, becasue ECC is enabled ++ NVML_ERROR_INSUFFICIENT_RESOURCES = 23, //!< Ran out of critical resources, other than memory ++ NVML_ERROR_FREQ_NOT_SUPPORTED = 24, //!< Ran out of critical resources, other than memory ++ NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25, //!< The provided version is invalid/unsupported ++ NVML_ERROR_DEPRECATED = 26, //!< The requested functionality has been deprecated ++ NVML_ERROR_NOT_READY = 27, //!< The system is not ready for the request ++ NVML_ERROR_GPU_NOT_FOUND = 28, //!< No GPUs were found ++ NVML_ERROR_INVALID_STATE = 29, //!< Resource not in correct state to perform requested operation ++ NVML_ERROR_UNKNOWN = 999 //!< An internal driver error occurred ++} nvmlReturn_t; ++typedef struct nvmlDevice_st* nvmlDevice_t; ++typedef struct nvmlMemory_st ++{ ++ unsigned long long total; //!< Total physical device memory (in bytes) ++ unsigned long long free; //!< Unallocated device memory (in bytes) ++ unsigned long long used; //!< Sum of Reserved and Allocated device memory (in bytes). ++ //!< Note that the driver/GPU always sets aside a small amount of memory for bookkeeping ++} nvmlMemory_t; ++// end nvml.h definitions ++ ++struct { ++ void *handle; ++ nvmlReturn_t (*nvmlInit_v2)(void); ++ nvmlReturn_t (*nvmlShutdown)(void); ++ nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); ++ nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); ++ const char * (*nvmlErrorString)(nvmlReturn_t result); ++} nvml { NULL, NULL, NULL, NULL, NULL }; ++static std::mutex ggml_nvml_lock; ++ ++extern "C" { ++ ++int ggml_nvml_init() { ++ std::lock_guard lock(ggml_nvml_lock); ++ if (nvml.handle != NULL) { ++ // Already initialized ++ return 0; ++ } ++#ifdef _WIN32 ++ DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); ++ SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); ++ fs::path libPath[2]; ++ const char * programDir = std::getenv("ProgramW6432"); ++ if (programDir == NULL) { ++ libPath[0] = fs::path("Program Files") / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); ++ } else { ++ libPath[0] = fs::path(programDir) / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); ++ } ++ libPath[1] = fs::path("\\Windows") / fs::path("System32") / fs::path("NVML.dll"); ++ ++ for (int i = 0; i < 2; i++) { ++ nvml.handle = (void*)LoadLibraryW(libPath[i].wstring().c_str()); ++ if (nvml.handle != NULL) { ++ break; ++ } ++ } ++ if (nvml.handle == NULL) { ++ return NVML_ERROR_NOT_FOUND; ++ } ++ ++ nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlInit_v2"); ++ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown"); ++ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID"); ++ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo"); ++ nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString"); ++ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) { ++ GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__); ++ FreeLibrary((HMODULE)(nvml.handle)); ++ nvml.handle = NULL; ++ return NVML_ERROR_NOT_FOUND; ++ } ++ ++ SetErrorMode(old_mode); ++ ++ nvmlReturn_t status = nvml.nvmlInit_v2(); ++ if (status != NVML_SUCCESS) { ++ GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status)); ++ FreeLibrary((HMODULE)(nvml.handle)); ++ nvml.handle = NULL; ++ return status; ++ } ++#else ++ constexpr std::array libPaths = { ++ "/usr/lib/wsl/lib/libnvidia-ml.so.1", // Favor WSL2 path if present ++ "libnvidia-ml.so.1" // On a non-WSL2 system, it should be in the path ++ }; ++ for (const char* path : libPaths) { ++ nvml.handle = dlopen(path, RTLD_LAZY); ++ if (nvml.handle) break; ++ } ++ if (nvml.handle == NULL) { ++ GGML_LOG_INFO("%s unable to load libnvidia-ml: %s\n", __func__, dlerror()); ++ return NVML_ERROR_NOT_FOUND; ++ } ++ nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlInit_v2"); ++ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown"); ++ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID"); ++ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo"); ++ nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString"); ++ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) { ++ GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__); ++ dlclose(nvml.handle); ++ nvml.handle = NULL; ++ return NVML_ERROR_NOT_FOUND; ++ } ++ nvmlReturn_t status = nvml.nvmlInit_v2(); ++ if (status != NVML_SUCCESS) { ++ GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status)); ++ dlclose(nvml.handle); ++ nvml.handle = NULL; ++ return status; ++ } ++#endif ++ return NVML_SUCCESS; ++} ++ ++void ggml_nvml_release() { ++ std::lock_guard lock(ggml_nvml_lock); ++ if (nvml.handle == NULL) { ++ // Already free ++ return; ++ } ++ nvmlReturn_enum status = nvml.nvmlShutdown(); ++ if (status != NVML_SUCCESS) { ++ GGML_LOG_INFO("%s failed to shutdown NVML: %s\n", __func__, nvml.nvmlErrorString(status)); ++ } ++#ifdef _WIN32 ++ FreeLibrary((HMODULE)(nvml.handle)); ++#else ++ dlclose(nvml.handle); ++#endif ++ nvml.handle = NULL; ++} ++ ++int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) { ++ std::lock_guard lock(ggml_nvml_lock); ++ if (nvml.handle == NULL) { ++ return NVML_ERROR_UNINITIALIZED; ++ } ++ nvmlDevice_t device; ++ auto status = nvml.nvmlDeviceGetHandleByUUID(uuid, &device); ++ if (status != NVML_SUCCESS) { ++ return status; ++ } ++ nvmlMemory_t memInfo = {0}; ++ status = nvml.nvmlDeviceGetMemoryInfo(device, &memInfo); ++ if (status == NVML_SUCCESS) { ++ *free = memInfo.free; ++ *total = memInfo.total; ++ } ++ return status; ++} ++ ++} +\ No newline at end of file diff --git a/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch new file mode 100644 index 00000000..997dd386 --- /dev/null +++ b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -0,0 +1,95 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Xiaodong Ye +Date: Mon, 18 Aug 2025 12:48:07 +0800 +Subject: [PATCH] vulkan: get GPU ID (ollama v0.11.5) + +Signed-off-by: Xiaodong Ye +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 37 ++++++++++++++++++++++++++++ + 1 file changed, 37 insertions(+) + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index 061cd078..adea7783 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -11588,6 +11588,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ + snprintf(description, description_size, "%s", props.deviceName.data()); + } + ++static std::string ggml_vk_get_device_id(int device) { ++ ggml_vk_instance_init(); ++ ++ std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); ++ ++ vk::PhysicalDeviceProperties2 props; ++ vk::PhysicalDeviceIDProperties deviceIDProps; ++ props.pNext = &deviceIDProps; ++ devices[device].getProperties2(&props); ++ ++ const auto& uuid = deviceIDProps.deviceUUID; ++ char id[64]; ++ snprintf(id, sizeof(id), ++ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", ++ uuid[0], uuid[1], uuid[2], uuid[3], ++ uuid[4], uuid[5], ++ uuid[6], uuid[7], ++ uuid[8], uuid[9], ++ uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15] ++ ); ++ return std::string(id); ++} ++ + // backend interface + + #define UNUSED GGML_UNUSED +@@ -12394,6 +12417,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size + ggml_vk_get_device_description(dev_idx, description, description_size); + } + ++std::string ggml_backend_vk_get_device_id(int device) { ++ GGML_ASSERT(device < (int) vk_instance.device_indices.size()); ++ int dev_idx = vk_instance.device_indices[device]; ++ return ggml_vk_get_device_id(dev_idx); ++} ++ + void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); +@@ -12481,6 +12510,7 @@ struct ggml_backend_vk_device_context { + std::string description; + bool is_integrated_gpu; + std::string pci_bus_id; ++ std::string id; + }; + + static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { +@@ -12493,6 +12523,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de + return ctx->description.c_str(); + } + ++static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ return ctx->id.c_str(); ++} ++ + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; + ggml_backend_vk_get_device_memory(ctx->device, free, total); +@@ -12519,6 +12554,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml + + props->name = ggml_backend_vk_device_get_name(dev); + props->description = ggml_backend_vk_device_get_description(dev); ++ props->id = ggml_backend_vk_device_get_id(dev); + props->type = ggml_backend_vk_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); +@@ -12965,6 +13001,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + ctx->description = desc; + ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; + ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); ++ ctx->id = ggml_backend_vk_get_device_id(i); + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, +-- +2.51.0 \ No newline at end of file diff --git a/llama/patches/0028-vulkan-pci-and-memory.patch b/llama/patches/0028-vulkan-pci-and-memory.patch new file mode 100644 index 00000000..c20ccf5c --- /dev/null +++ b/llama/patches/0028-vulkan-pci-and-memory.patch @@ -0,0 +1,254 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen +Date: Fri Sep 5 08:25:03 2025 -0700 +Subject: [PATCH] Vulkan PCI and Memory + +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 176 ++++++++++++++++++++++----- + 1 file changed, 145 insertions(+), 31 deletions(-) + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index adea7783..fb7204ce 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -12423,31 +12423,99 @@ std::string ggml_backend_vk_get_device_id(int device) { + return ggml_vk_get_device_id(dev_idx); + } + +-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { +- GGML_ASSERT(device < (int) vk_instance.device_indices.size()); +- GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); ++////////////////////////// ++ ++struct ggml_backend_vk_device_context { ++ size_t device; ++ std::string name; ++ std::string description; ++ bool is_integrated_gpu; ++ // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) ++ std::string pci_id; ++ std::string id; ++ std::string uuid; ++ int major; ++ int minor; ++ int driver_major; ++ int driver_minor; ++ int pci_bus_id; ++ int pci_device_id; ++ int pci_domain_id; ++}; ++ ++void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { ++ GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); ++ GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size()); ++ ++ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; + +- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; +- vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; +- vk::PhysicalDeviceMemoryProperties2 memprops = {}; +- bool membudget_supported = vk_instance.device_supports_membudget[device]; ++ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); ++ vk::PhysicalDeviceProperties2 props2; ++ vkdev.getProperties2(&props2); + +- if (membudget_supported) { +- memprops.pNext = &budgetprops; ++ if (!ctx->is_integrated_gpu) ++ { ++ // Use vendor specific management libraries for best VRAM reporting if available ++ switch (props2.properties.vendorID) { ++ case VK_VENDOR_ID_AMD: ++ if (ggml_hip_mgmt_init() == 0) { ++ int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_hip_mgmt_release(); ++ return; ++ } ++ ggml_hip_mgmt_release(); ++ } ++ break; ++ case VK_VENDOR_ID_NVIDIA: ++ if (ggml_nvml_init() == 0) { ++ int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_nvml_release(); ++ return; ++ } ++ ggml_nvml_release(); ++ } ++ break; ++ } + } +- vkdev.getMemoryProperties2(&memprops); ++ // else fallback to memory budget if supported + +- for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) { +- const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i]; ++ *total = 0; ++ *free = 0; ++ vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; ++ vk::PhysicalDeviceMemoryProperties2 memprops2; ++ memprops2.pNext = &mem_budget_props; ++ vkdev.getMemoryProperties2(&memprops2); ++ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { ++ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { ++ *total += memprops2.memoryProperties.memoryHeaps[i].size; ++ } else if (ctx->is_integrated_gpu) { ++ // Include shared memory on iGPUs ++ *total += memprops2.memoryProperties.memoryHeaps[i].size; ++ } ++ } ++ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { ++ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { ++ *free += mem_budget_props.heapBudget[i]; ++ } else if (ctx->is_integrated_gpu) { ++ *free += mem_budget_props.heapBudget[i]; ++ } ++ } ++ if (*total > 0 && *free > 0) { ++ return; ++ } else if (*total > 0) { ++ *free = *total; ++ return; ++ } + ++ // else just report the physical memory ++ for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { + if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total = heap.size; +- +- if (membudget_supported && i < budgetprops.heapUsage.size()) { +- *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i]; +- } else { +- *free = heap.size; +- } ++ *free = heap.size; + break; + } + } +@@ -12502,16 +12570,17 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { + return std::string(pci_bus_id); + } + +-////////////////////////// +- +-struct ggml_backend_vk_device_context { +- size_t device; +- std::string name; +- std::string description; +- bool is_integrated_gpu; +- std::string pci_bus_id; +- std::string id; +-}; ++static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { ++ if (id.empty()) return false; ++ unsigned int d = 0, b = 0, dev = 0, func = 0; ++ // Expected format: dddd:bb:dd.f (all hex) ++ int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); ++ if (n < 4) return false; ++ if (domain) *domain = (int) d; ++ if (bus) *bus = (int) b; ++ if (device) *device = (int) dev; ++ return true; ++} + + static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; +@@ -12530,7 +12599,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { + + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; +- ggml_backend_vk_get_device_memory(ctx->device, free, total); ++ ggml_backend_vk_get_device_memory(ctx, free, total); + } + + static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { +@@ -12556,7 +12625,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml + props->description = ggml_backend_vk_device_get_description(dev); + props->id = ggml_backend_vk_device_get_id(dev); + props->type = ggml_backend_vk_device_get_type(dev); +- props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ++ props->device_id = ctx->pci_id.empty() ? nullptr : ctx->pci_id.c_str(); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, +@@ -12564,6 +12633,17 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; ++ ++ props->compute_major = ctx->major; ++ props->compute_minor = ctx->minor; ++ props->driver_major = ctx->driver_major; ++ props->driver_minor = ctx->driver_minor; ++ props->integrated = ctx->is_integrated_gpu; ++ props->pci_bus_id = ctx->pci_bus_id; ++ props->pci_device_id = ctx->pci_device_id; ++ props->pci_domain_id = ctx->pci_domain_id; ++ props->library = GGML_VK_NAME; ++ props->numeric_id = ctx->id.empty() ? nullptr : ctx->id.c_str(); + } + + static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { +@@ -12992,6 +13071,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { ++ std::vector vk_devices = vk_instance.instance.enumeratePhysicalDevices(); ++ + for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { + ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; + char desc[256]; +@@ -13000,13 +13081,46 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + ctx->name = GGML_VK_NAME + std::to_string(i); + ctx->description = desc; + ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; +- ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); ++ ctx->pci_id = ggml_backend_vk_get_device_pci_id(i); + ctx->id = ggml_backend_vk_get_device_id(i); + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, + /* .context = */ ctx, + }); ++ ++ // Gather additional information about the device ++ int dev_idx = vk_instance.device_indices[i]; ++ vk::PhysicalDeviceProperties props1; ++ vk_devices[dev_idx].getProperties(&props1); ++ vk::PhysicalDeviceProperties2 props2; ++ vk::PhysicalDeviceIDProperties device_id_props; ++ vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props; ++ vk::PhysicalDeviceDriverProperties driver_props; ++ props2.pNext = &device_id_props; ++ device_id_props.pNext = &pci_bus_props; ++ pci_bus_props.pNext = &driver_props; ++ vk_devices[dev_idx].getProperties2(&props2); ++ std::ostringstream oss; ++ oss << std::hex << std::setfill('0'); ++ oss << "GPU-"; ++ int byteIdx = 0; ++ for (int i = 0; i < 16; ++i, ++byteIdx) { ++ oss << std::setw(2) << static_cast(device_id_props.deviceUUID[i]); ++ if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) { ++ oss << '-'; ++ } ++ } ++ ctx->uuid = oss.str(); ++ ctx->pci_bus_id = pci_bus_props.pciBus; ++ ctx->pci_device_id = pci_bus_props.pciDevice; ++ ctx->pci_domain_id = pci_bus_props.pciDomain; ++ ctx->id = std::to_string(i); ++ ctx->major = 0; ++ ctx->minor = 0; ++ // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string ++ ctx->driver_major = 0; ++ ctx->driver_minor = 0; + } + initialized = true; + } +-- +2.51.0 \ No newline at end of file diff --git a/llama/patches/0029-NVML-fallback-for-unified-memory-GPUs.patch b/llama/patches/0029-NVML-fallback-for-unified-memory-GPUs.patch new file mode 100644 index 00000000..9ba11168 --- /dev/null +++ b/llama/patches/0029-NVML-fallback-for-unified-memory-GPUs.patch @@ -0,0 +1,137 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Santosh Bhavani +Date: Wed, 15 Oct 2025 09:29:51 -0700 +Subject: [PATCH] NVML fallback for unified memory GPUs + +--- + ggml/src/mem_nvml.cpp | 71 +++++++++++++++++++++++++++++++++++++++++-- + 1 file changed, 68 insertions(+), 3 deletions(-) + +diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp +index c9073cef..f473a2a2 100644 +--- a/ggml/src/mem_nvml.cpp ++++ b/ggml/src/mem_nvml.cpp +@@ -13,6 +13,7 @@ + #include + #include + #include ++#include + + #ifdef _WIN32 + # define WIN32_LEAN_AND_MEAN +@@ -23,6 +24,8 @@ + #else + # include + # include ++# include ++# include + #endif + + namespace fs = std::filesystem; +@@ -79,12 +82,36 @@ struct { + nvmlReturn_t (*nvmlShutdown)(void); + nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); + nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); ++ nvmlReturn_t (*nvmlDeviceGetName)(nvmlDevice_t, char *, unsigned int); + const char * (*nvmlErrorString)(nvmlReturn_t result); +-} nvml { NULL, NULL, NULL, NULL, NULL }; ++} nvml { NULL, NULL, NULL, NULL, NULL, NULL, NULL }; + static std::mutex ggml_nvml_lock; + + extern "C" { + ++#ifndef _WIN32 ++// Helper function to get available memory from /proc/meminfo on Linux ++// Returns MemAvailable as calculated by the kernel ++static size_t get_mem_available() { ++ std::ifstream meminfo("/proc/meminfo"); ++ if (!meminfo.is_open()) { ++ return 0; ++ } ++ ++ std::string line; ++ while (std::getline(meminfo, line)) { ++ if (line.find("MemAvailable:") == 0) { ++ size_t available_kb; ++ sscanf(line.c_str(), "MemAvailable: %zu kB", &available_kb); ++ // Convert from kB to bytes ++ return available_kb * 1024; ++ } ++ } ++ ++ return 0; ++} ++#endif ++ + int ggml_nvml_init() { + std::lock_guard lock(ggml_nvml_lock); + if (nvml.handle != NULL) { +@@ -117,8 +144,9 @@ int ggml_nvml_init() { + nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown"); + nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID"); + nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo"); ++ nvml.nvmlDeviceGetName = (nvmlReturn_t (*)(nvmlDevice_t, char *, unsigned int)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetName"); + nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString"); +- if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) { ++ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlDeviceGetName == NULL || nvml.nvmlErrorString == NULL) { + GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__); + FreeLibrary((HMODULE)(nvml.handle)); + nvml.handle = NULL; +@@ -151,8 +179,9 @@ int ggml_nvml_init() { + nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown"); + nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID"); + nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo"); ++ nvml.nvmlDeviceGetName = (nvmlReturn_t (*)(nvmlDevice_t, char *, unsigned int)) dlsym(nvml.handle, "nvmlDeviceGetName"); + nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString"); +- if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) { ++ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlDeviceGetName == NULL) { + GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__); + dlclose(nvml.handle); + nvml.handle = NULL; +@@ -199,10 +228,46 @@ int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) { + } + nvmlMemory_t memInfo = {0}; + status = nvml.nvmlDeviceGetMemoryInfo(device, &memInfo); ++ + if (status == NVML_SUCCESS) { ++ // NVML working correctly, use its values + *free = memInfo.free; + *total = memInfo.total; ++ return NVML_SUCCESS; + } ++ ++#ifndef _WIN32 ++ // Handle NVML_ERROR_NOT_SUPPORTED - this indicates NVML doesn't support ++ // reporting framebuffer memory (e.g., unified memory GPUs where FB memory is 0) ++ if (status == NVML_ERROR_NOT_SUPPORTED) { ++ // Use system memory from /proc/meminfo ++ size_t mem_available = get_mem_available(); ++ size_t mem_total = 0; ++ ++ // Read MemTotal ++ std::ifstream meminfo("/proc/meminfo"); ++ if (meminfo.is_open()) { ++ std::string line; ++ while (std::getline(meminfo, line)) { ++ if (line.find("MemTotal:") == 0) { ++ size_t total_kb; ++ sscanf(line.c_str(), "MemTotal: %zu kB", &total_kb); ++ mem_total = total_kb * 1024; ++ break; ++ } ++ } ++ } ++ ++ if (mem_total > 0) { ++ *total = mem_total; ++ *free = mem_available; ++ GGML_LOG_INFO("%s NVML not supported for memory query, using system memory (total=%zu, available=%zu)\n", ++ __func__, mem_total, mem_available); ++ return NVML_SUCCESS; ++ } ++ } ++#endif ++ + return status; + } + diff --git a/llama/patches/0030-CUDA-Changing-the-CUDA-scheduling-strategy-to-spin-1.patch b/llama/patches/0030-CUDA-Changing-the-CUDA-scheduling-strategy-to-spin-1.patch new file mode 100644 index 00000000..c3c7fedf --- /dev/null +++ b/llama/patches/0030-CUDA-Changing-the-CUDA-scheduling-strategy-to-spin-1.patch @@ -0,0 +1,49 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Julius Tischbein +Date: Wed, 15 Oct 2025 13:54:15 +0200 +Subject: [PATCH] CUDA: Changing the CUDA scheduling strategy to spin (#16585) +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 +Content-Transfer-Encoding: 8bit + +* CUDA set scheduling strategy to spinning for cc121 + +* Using prop.major and prop.minor, include HIP and MUSA + +* Exclude HIP and MUSA + +* Remove trailing whitespace + +Co-authored-by: Johannes Gäßler + +* Remove empty line + +Co-authored-by: Johannes Gäßler + +--------- + +Co-authored-by: Johannes Gäßler +--- + ggml/src/ggml-cuda/ggml-cuda.cu | 9 +++++++++ + 1 file changed, 9 insertions(+) + +diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu +index 6a278b5e9..87941f872 100644 +--- a/ggml/src/ggml-cuda/ggml-cuda.cu ++++ b/ggml/src/ggml-cuda/ggml-cuda.cu +@@ -340,6 +340,15 @@ static ggml_cuda_device_info ggml_cuda_init() { + } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") { + turing_devices_without_mma.push_back({ id, device_name }); + } ++ ++ // Temporary performance fix: ++ // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls. ++ // TODO: Check for future drivers the default scheduling strategy and ++ // remove this call again when cudaDeviceScheduleSpin is default. ++ if (prop.major == 12 && prop.minor == 1) { ++ CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin)); ++ } ++ + #endif // defined(GGML_USE_HIP) + } + diff --git a/llama/patches/0031-report-LoadLibrary-failures.patch b/llama/patches/0031-report-LoadLibrary-failures.patch new file mode 100644 index 00000000..f537f6e2 --- /dev/null +++ b/llama/patches/0031-report-LoadLibrary-failures.patch @@ -0,0 +1,32 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen +Date: Fri, 17 Oct 2025 14:17:00 -0700 +Subject: [PATCH] report LoadLibrary failures + +--- + ggml/src/ggml-backend-reg.cpp | 12 ++++++++++++ + 1 file changed, 12 insertions(+) + +diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp +index f794d9cfa..3a855ab2e 100644 +--- a/ggml/src/ggml-backend-reg.cpp ++++ b/ggml/src/ggml-backend-reg.cpp +@@ -118,6 +118,18 @@ static dl_handle * dl_load_library(const fs::path & path) { + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + HMODULE handle = LoadLibraryW(path.wstring().c_str()); ++ if (!handle) { ++ DWORD error_code = GetLastError(); ++ std::string msg; ++ LPSTR lpMsgBuf = NULL; ++ DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, ++ NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL); ++ if (bufLen) { ++ msg = lpMsgBuf; ++ LocalFree(lpMsgBuf); ++ GGML_LOG_INFO("%s unable to load library %s: %s\n", __func__, path_str(path).c_str(), msg.c_str()); ++ } ++ } + + SetErrorMode(old_mode); + diff --git a/llm/memory.go b/llm/memory.go index 7a87b28f..aa4927f1 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -195,8 +195,8 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin slog.Warn("model missing blk.0 layer size") } - useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) && - discover.GetGPUInfo().FlashAttentionSupported() && + useFlashAttention := envconfig.FlashAttention(f.FlashAttention()) && + (discover.GpuInfoList)(gpus).FlashAttentionSupported() && f.SupportsFlashAttention() var kvct string @@ -231,7 +231,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin } // on metal there's no partial offload overhead - if gpus[0].Library == "metal" { + if gpus[0].Library == "Metal" { graphPartialOffload = graphFullOffload } else if len(gpus) > 1 { // multigpu should always use the partial graph size @@ -266,11 +266,18 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin } // Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least more layer if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory+2*layerSize { + var compute string + if gpus[i].Library == "ROCm" { + compute = fmt.Sprintf("gfx%x%02x", gpus[i].ComputeMajor, gpus[i].ComputeMinor) + } else { + compute = fmt.Sprintf("%d.%d", gpus[i].ComputeMajor, gpus[i].ComputeMinor) + } + slog.Debug("gpu has too little memory to allocate any layers", "id", gpus[i].ID, "library", gpus[i].Library, "variant", gpus[i].Variant, - "compute", gpus[i].Compute, + "compute", compute, "driver", fmt.Sprintf("%d.%d", gpus[i].DriverMajor, gpus[i].DriverMinor), "name", gpus[i].Name, "total", format.HumanBytes2(gpus[i].TotalMemory), diff --git a/llm/memory_test.go b/llm/memory_test.go index 49851006..553214b9 100644 --- a/llm/memory_test.go +++ b/llm/memory_test.go @@ -12,6 +12,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/discover" "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/ml" ) func TestEstimateGPULayers(t *testing.T) { @@ -55,7 +56,9 @@ func TestEstimateGPULayers(t *testing.T) { // Simple CPU scenario gpus := []discover.GpuInfo{ { - Library: "cpu", + DeviceID: ml.DeviceID{ + Library: "cpu", + }, }, } projectors := []string{} @@ -77,11 +80,15 @@ func TestEstimateGPULayers(t *testing.T) { gpuMinimumMemory := uint64(2048) gpus = []discover.GpuInfo{ { - Library: "cuda", + DeviceID: ml.DeviceID{ + Library: "cuda", + }, MinimumMemory: gpuMinimumMemory, }, { - Library: "cuda", + DeviceID: ml.DeviceID{ + Library: "cuda", + }, MinimumMemory: gpuMinimumMemory, }, } diff --git a/llm/server.go b/llm/server.go index 75f049bc..6ba8f8d2 100644 --- a/llm/server.go +++ b/llm/server.go @@ -66,7 +66,7 @@ func (e filteredEnv) LogValue() slog.Value { type LlamaServer interface { ModelPath() string - Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error + Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error @@ -76,8 +76,11 @@ type LlamaServer interface { Close() error VRAMSize() uint64 // Total VRAM across all GPUs TotalSize() uint64 - VRAMByGPU(gpuID string) uint64 + VRAMByGPU(id ml.DeviceID) uint64 Pid() int + GetPort() int + GetDeviceInfos(ctx context.Context) []ml.DeviceInfo + HasExited() bool } // llmServer is an instance of a runner hosting a single model @@ -193,14 +196,10 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a loadRequest.ProjectorPath = projectors[0] } + fa := envconfig.FlashAttention(f.FlashAttention()) + // This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset // that can handle it. - fa := envconfig.FlashAttention() - if f.FlashAttention() { - slog.Info("model wants flash attention") - fa = true - } - if fa && !gpus.FlashAttentionSupported() { slog.Warn("flash attention enabled but not supported by gpu") fa = false @@ -331,6 +330,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a if gpu.DependencyPath != nil { slog.Debug("adding gpu dependency paths", "paths", gpu.DependencyPath) libraryPaths = append(gpu.DependencyPath, libraryPaths...) + ggmlPaths = append(ggmlPaths, gpu.DependencyPath...) } } @@ -359,24 +359,22 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a s.cmd.Stderr = s.status s.cmd.SysProcAttr = LlamaServerSysProcAttr - s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator))) - - envWorkarounds := []string{} - for _, gpu := range gpus { - envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...) - } // Always filter down the set of GPUs in case there are any unsupported devices that might crash - envWorkarounds = append(envWorkarounds, gpus.GetVisibleDevicesEnv()...) + envWorkarounds := gpus.GetVisibleDevicesEnv() pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) // Update or add the path variable with our adjusted version pathNeeded := true + ollamaPathNeeded := true envWorkaroundDone := make([]bool, len(envWorkarounds)) for i := range s.cmd.Env { cmp := strings.SplitN(s.cmd.Env[i], "=", 2) if strings.EqualFold(cmp[0], pathEnv) { s.cmd.Env[i] = pathEnv + "=" + pathEnvVal pathNeeded = false + } else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") { + s.cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ggmlPaths, string(filepath.ListSeparator)) + ollamaPathNeeded = false } else if len(envWorkarounds) != 0 { for j, kv := range envWorkarounds { tmp := strings.SplitN(kv, "=", 2) @@ -390,6 +388,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a if pathNeeded { s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal) } + if ollamaPathNeeded { + s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator))) + } for i, done := range envWorkaroundDone { if !done { s.cmd.Env = append(s.cmd.Env, envWorkarounds[i]) @@ -496,7 +497,7 @@ type LoadResponse struct { var ErrLoadRequiredFull = errors.New("unable to load full model on GPU") -func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error { +func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) { systemInfo := discover.GetSystemInfo() systemTotalMemory := systemInfo.System.TotalMemory systemFreeMemory := systemInfo.System.FreeMemory @@ -509,7 +510,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel) } else { slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate) - return ErrLoadRequiredFull + return nil, ErrLoadRequiredFull } } @@ -518,13 +519,13 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi if len(gpus) > 1 || gpus[0].Library != "cpu" { switch { - case gpus[0].Library == "metal" && s.estimate.VRAMSize > systemInfo.System.TotalMemory: + case gpus[0].Library == "Metal" && s.estimate.VRAMSize > systemInfo.System.TotalMemory: // disable partial offloading when model is greater than total system memory as this // can lead to locking up the system s.options.NumGPU = 0 - case gpus[0].Library != "metal" && s.estimate.Layers == 0: + case gpus[0].Library != "Metal" && s.estimate.Layers == 0: // Don't bother loading into the GPU if no layers can fit - gpus = discover.GetCPUInfo() + gpus = discover.GpuInfoList{discover.GetCPUInfo()} case s.options.NumGPU < 0 && s.estimate.Layers > 0 && gpus[0].Library != "cpu": s.options.NumGPU = s.estimate.Layers } @@ -537,7 +538,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi available := systemInfo.System.FreeMemory + systemInfo.System.FreeSwap if systemMemoryRequired > available { slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.System.TotalMemory), "free", format.HumanBytes2(systemInfo.System.FreeMemory), "swap", format.HumanBytes2(systemInfo.System.FreeSwap)) - return fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available)) + return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available)) } } @@ -552,7 +553,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi // mmap has issues with partial offloading on metal for _, g := range gpus { - if g.Library == "metal" && + if g.Library == "Metal" && uint64(s.options.NumGPU) > 0 && uint64(s.options.NumGPU) < s.ggml.KV().BlockCount()+1 { s.options.UseMMap = new(bool) @@ -563,21 +564,22 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi // Windows CUDA should not use mmap for best performance // Linux with a model larger than free space, mmap leads to thrashing // For CPU loads we want the memory to be allocated, not FS cache - if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && s.options.UseMMap == nil) || + if (runtime.GOOS == "windows" && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) || (runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) || (gpus[0].Library == "cpu" && s.options.UseMMap == nil) || + (gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) || (s.options.UseMMap != nil && !*s.options.UseMMap) { s.loadRequest.UseMmap = false } } if err := s.waitUntilRunnerLaunched(ctx); err != nil { - return err + return nil, err } resp, err := s.initModel(ctx, s.loadRequest, LoadOperationCommit) if err != nil { - return err + return nil, err } // On the Ollama engine, we can print out a summary of the memory allocations. @@ -588,16 +590,16 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi if !resp.Success { slog.Warn("failed to allocate memory for model", "memory", resp.Memory) - return errors.New("failed to allocate memory for model") + return nil, errors.New("failed to allocate memory for model") } // The llama engine does its memory allocations together with model loading, so we // need to wait until it is done to ensure that we have accurate memory data before // loading the next model if s.textProcessor == nil { - return s.WaitUntilRunning(ctx) + return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx) } else { - return nil + return uniqueDeviceIDs(s.loadRequest.GPULayers), nil } } @@ -610,7 +612,7 @@ func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus discover.Gpu gpuLayers := make(ml.GPULayersList, len(gpus)) for i := range gpuLayers { - gpuLayers[i].ID = gpus[i].ID + gpuLayers[i].DeviceID = gpus[i].DeviceID } var sum float32 @@ -658,7 +660,9 @@ func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus discover.Gpu // // This process is repeated for higher levels of loading the model (fit, allocate, commit). The earlier levels are quicker, // allowing for faster iteration, but may return less information. -func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error { +// +// Returns the list of GPU IDs that were used in the final allocation on success +func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) { var success bool defer func() { if !success { @@ -683,7 +687,7 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory { available = 0 } - slog.Info("gpu memory", "id", gpu.ID, + slog.Info("gpu memory", "id", gpu.ID, "library", gpu.Library, "available", format.HumanBytes2(available), "free", format.HumanBytes2(gpu.FreeMemory), "minimum", format.HumanBytes2(gpu.MinimumMemory), @@ -696,11 +700,11 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, requireFull, backoff) if err != nil { - return err + return nil, err } if err := s.waitUntilRunnerLaunched(ctx); err != nil { - return err + return nil, err } nextOperation: @@ -710,7 +714,7 @@ nextOperation: s.loadRequest.GPULayers = gpuLayers resp, err := s.initModel(ctx, s.loadRequest, operation) if err != nil { - return err + return nil, err } resp.Memory.Log(slog.LevelDebug) @@ -722,7 +726,7 @@ nextOperation: for { newGPULayers, err := s.createLayout(systemInfo, gpus, s.mem, requireFull, backoff) if err != nil { - return err + return nil, err } slog.Debug("new layout created", "layers", newGPULayers) @@ -756,7 +760,7 @@ nextOperation: newGPULayers, err = s.createLayout(systemInfo, gpus, s.mem, requireFull, backoff) s.options.NumGPU = -1 if err != nil { - return err + return nil, err } slog.Debug("new layout created", "layers", newGPULayers) @@ -764,7 +768,7 @@ nextOperation: s.loadRequest.GPULayers = newGPULayers resp, err = s.initModel(ctx, s.loadRequest, operation) if err != nil { - return err + return nil, err } resp.Memory.Log(slog.LevelDebug) @@ -773,7 +777,7 @@ nextOperation: if resp.Success { verifyGPULayers, err := s.createLayout(systemInfo, gpus, &resp.Memory, requireFull, backoff) if err != nil { - return err + return nil, err } slog.Debug("verifying layout", "layers", verifyGPULayers) @@ -798,7 +802,7 @@ nextOperation: } if s.options.NumGPU >= 0 { - return fmt.Errorf("memory layout cannot be allocated with num_gpu = %v", s.options.NumGPU) + return nil, fmt.Errorf("memory layout cannot be allocated with num_gpu = %v", s.options.NumGPU) } // Memory allocation failed even though we created a layout that we thought should @@ -808,7 +812,7 @@ nextOperation: // space. if backoff > 1 { slog.Warn("memory layout cannot be allocated", "memory", resp.Memory) - return errors.New("memory layout cannot be allocated") + return nil, errors.New("memory layout cannot be allocated") } else if backoff == 0 { backoff = 0.01 } else { @@ -823,7 +827,7 @@ nextOperation: s.loadRequest.GPULayers = gpuLayers resp, err := s.initModel(ctx, s.loadRequest, LoadOperationCommit) if err != nil { - return err + return nil, err } success = resp.Success @@ -831,10 +835,27 @@ nextOperation: if !success { slog.Warn("failed to commit memory for model", "memory", resp.Memory) - return errors.New("failed to commit memory for model") + return nil, errors.New("failed to commit memory for model") } - return nil + return uniqueDeviceIDs(gpuLayers), nil +} + +func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID { + devices := []ml.DeviceID{} + for _, layer := range gpuLayers { + new := true + for _, ID := range devices { + if layer.DeviceID == ID { + new = false + break + } + } + if new { + devices = append(devices, layer.DeviceID) + } + } + return devices } // createLayout uses the current best view of memory requirements and creates a layout of model layers on GPUs. @@ -853,19 +874,19 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d if memory == nil { memory = &ml.BackendMemory{CPU: ml.DeviceMemory{ - Weights: make([]ml.Memory, s.totalLayers), - Cache: make([]ml.Memory, s.totalLayers), + Weights: make([]uint64, s.totalLayers), + Cache: make([]uint64, s.totalLayers), }} } layers := make([]uint64, len(memory.CPU.Weights)) for i := range layers { for j := range memory.GPUs { - layers[i] += memory.GPUs[j].Weights[i].Size - layers[i] += memory.GPUs[j].Cache[i].Size + layers[i] += memory.GPUs[j].Weights[i] + layers[i] += memory.GPUs[j].Cache[i] } - layers[i] += memory.CPU.Weights[i].Size - layers[i] += memory.CPU.Cache[i].Size + layers[i] += memory.CPU.Weights[i] + layers[i] += memory.CPU.Cache[i] logutil.Trace("layer to assign", "layer", i, "size", format.HumanBytes2(layers[i])) } @@ -879,23 +900,23 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d for i := range gl { found := false for j := range memory.GPUs { - if gl[i].ID == memory.GPUs[j].ID { - if memory.GPUs[j].Graph.Size != 0 { + if gl[i].DeviceID == memory.GPUs[j].DeviceID { + if memory.GPUs[j].Graph != 0 { lastUsedGPU = i } - reserved := uint64(float32(gl[i].FreeMemory)*backoff) + gl[i].MinimumMemory + envconfig.GpuOverhead() + memory.GPUs[j].Graph.Size + reserved := uint64(float32(gl[i].FreeMemory)*backoff) + gl[i].MinimumMemory + envconfig.GpuOverhead() + memory.GPUs[j].Graph if gl[i].FreeMemory > reserved { gl[i].FreeMemory -= reserved } else { gl[i].FreeMemory = 0 } - slog.Debug("available gpu", "id", gl[i].ID, + slog.Debug("available gpu", "id", gl[i].ID, "library", gl[i].Library, "available layer vram", format.HumanBytes2(gl[i].FreeMemory), "backoff", fmt.Sprintf("%.2f", backoff), "minimum", format.HumanBytes2(gl[i].MinimumMemory), "overhead", format.HumanBytes2(envconfig.GpuOverhead()), - "graph", format.HumanBytes2(memory.GPUs[j].Graph.Size)) + "graph", format.HumanBytes2(memory.GPUs[j].Graph)) found = true break @@ -907,19 +928,19 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d } } - libraryGpuLayers := assignLayers(layers, gl, s.options.NumGPU, lastUsedGPU) + libraryGpuLayers := assignLayers(layers, gl, requireFull, s.options.NumGPU, lastUsedGPU) if libraryGpuLayers.Sum() > gpuLayers.Sum() { gpuLayers = libraryGpuLayers } } // These sizes will only increase as we go through additional iterations and get additional information. - cpuSize := memory.InputWeights.Size + memory.CPU.Graph.Size + cpuSize := memory.InputWeights + memory.CPU.Graph var vramSize uint64 for _, gl := range gpuLayers { for _, gpu := range memory.GPUs { - if gl.ID == gpu.ID { - vramSize += gpu.Graph.Size + if gl.DeviceID == gpu.DeviceID { + vramSize += gpu.Graph break } } @@ -973,7 +994,7 @@ nextLayer: } // assignLayers packs the maximum number of layers onto the smallest set of GPUs and comes up with a layer assignment -func assignLayers(layers []uint64, gpus discover.GpuInfoList, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) { +func assignLayers(layers []uint64, gpus discover.GpuInfoList, requireFull bool, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) { // If we can't fit everything then prefer offloading layers other than the output layer for range 2 { // requestedLayers may be -1 if nothing was requested @@ -982,14 +1003,14 @@ func assignLayers(layers []uint64, gpus discover.GpuInfoList, requestedLayers in if !envconfig.SchedSpread() { for i := lastUsedGPU; i < len(gpus); i++ { // Try to pack things into as few GPUs as possible - forceRequest := i == len(gpus)-1 + forceRequest := i == len(gpus)-1 && !requireFull gpuLayers = findBestFit(layers, gpus[:i+1], requestedLayers, forceRequest) if gpuLayers.Sum() == len(layers) || gpuLayers.Sum() == requestedLayers { break } } } else { - gpuLayers = findBestFit(layers, gpus, requestedLayers, true) + gpuLayers = findBestFit(layers, gpus, requestedLayers, !requireFull) } // We only stop if we've gotten all of the layers - even if we got requestedLayers, we still @@ -1039,7 +1060,7 @@ func findBestFit(layers []uint64, gpus discover.GpuInfoList, requestedLayers int // greedyFit assigns layers incrementally to GPUs, spilling over as each runs out of free space func greedyFit(layers []uint64, gpus discover.GpuInfoList, capacity float32, requestedLayers int) (gpuLayers ml.GPULayersList) { device := len(gpus) - 1 - gpuLayers = ml.GPULayersList{{ID: gpus[device].ID}} + gpuLayers = ml.GPULayersList{{DeviceID: gpus[device].DeviceID}} freeSpace := uint64(float32(gpus[device].FreeMemory) * capacity) for i := len(layers) - 1; i >= 0; i-- { if requestedLayers >= 0 && len(layers)-1-i >= requestedLayers { @@ -1057,7 +1078,7 @@ func greedyFit(layers []uint64, gpus discover.GpuInfoList, capacity float32, req if device < 0 { return gpuLayers } - gpuLayers = append(ml.GPULayersList{{ID: gpus[device].ID}}, gpuLayers...) + gpuLayers = append(ml.GPULayersList{{DeviceID: gpus[device].DeviceID}}, gpuLayers...) freeSpace = uint64(float32(gpus[device].FreeMemory) * capacity) } } @@ -1312,6 +1333,17 @@ func (s *llmServer) Pid() int { return -1 } +func (s *llmServer) GetPort() int { + return s.port +} + +func (s *llmServer) HasExited() bool { + if s.cmd != nil && s.cmd.ProcessState != nil && s.cmd.ProcessState.ExitCode() >= 0 { + return true + } + return false +} + var grammarJSON = ` root ::= object value ::= object | array | string | number | ("true" | "false" | "null") ws @@ -1348,7 +1380,9 @@ type CompletionRequest struct { Images []ImageData Options *api.Options - Grammar string // set before sending the request to the subprocess + Grammar string // set before sending the request to the subprocess + Shift bool + Truncate bool } // DoneReason represents the reason why a completion response is done @@ -1386,7 +1420,7 @@ type CompletionResponse struct { func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { slog.Debug("completion request", "images", len(req.Images), "prompt", len(req.Prompt), "format", string(req.Format)) - slog.Log(ctx, logutil.LevelTrace, "completion request", "prompt", req.Prompt) + logutil.Trace("completion request", "prompt", req.Prompt) if len(req.Format) > 0 { switch string(req.Format) { @@ -1455,7 +1489,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu serverReq.Header.Set("Content-Type", "application/json") res, err := http.DefaultClient.Do(serverReq) - if err != nil { + if err != nil && errors.Is(err, context.Canceled) { + // client closed connection + return err + } else if err != nil { slog.Error("post predict", "error", err) return errors.New("model runner has unexpectedly stopped, this may be due to resource limitations or an internal error, check ollama server logs for details") } @@ -1467,7 +1504,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return fmt.Errorf("failed reading llm error response: %w", err) } log.Printf("llm predict error: %s", bodyBytes) - return fmt.Errorf("%s", bodyBytes) + return api.StatusError{StatusCode: res.StatusCode, ErrorMessage: strings.TrimSpace(string(bodyBytes))} } scanner := bufio.NewScanner(res.Body) @@ -1552,7 +1589,7 @@ type EmbeddingResponse struct { } func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) { - slog.Log(ctx, logutil.LevelTrace, "embedding request", "input", input) + logutil.Trace("embedding request", "input", input) if err := s.sem.Acquire(ctx, 1); err != nil { if errors.Is(err, context.Canceled) { @@ -1704,9 +1741,9 @@ func (s *llamaServer) TotalSize() uint64 { return s.estimate.TotalSize } -func (s *llamaServer) VRAMByGPU(gpuID string) uint64 { +func (s *llamaServer) VRAMByGPU(id ml.DeviceID) uint64 { for i, gpu := range s.gpus { - if gpu.ID == gpuID { + if gpu.DeviceID == id { if i < len(s.estimate.GPUSizes) { return s.estimate.GPUSizes[i] } @@ -1715,6 +1752,11 @@ func (s *llamaServer) VRAMByGPU(gpuID string) uint64 { return 0 } +func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { + slog.Debug("llamarunner free vram reporting not supported") + return nil +} + func (s *ollamaServer) VRAMSize() uint64 { if s.mem == nil { return 0 @@ -1723,21 +1765,21 @@ func (s *ollamaServer) VRAMSize() uint64 { var mem uint64 for _, g := range s.mem.GPUs { - mem += g.Allocated() + mem += g.Size() } // Some elements are always on CPU. However, if we have allocated all layers // on the GPU then include the CPU components as well, to represent complete offloading. noCPULayers := true for i := range s.mem.CPU.Weights { - if s.mem.CPU.Weights[i].Size != 0 || s.mem.CPU.Cache[i].Size != 0 { + if s.mem.CPU.Weights[i] != 0 || s.mem.CPU.Cache[i] != 0 { noCPULayers = false break } } if noCPULayers { - mem += s.mem.InputWeights.Size - mem += s.mem.CPU.Graph.Size + mem += s.mem.InputWeights + mem += s.mem.CPU.Graph } return mem @@ -1748,25 +1790,37 @@ func (s *ollamaServer) TotalSize() uint64 { return 0 } - mem := s.mem.InputWeights.Size - mem += s.mem.CPU.Allocated() + mem := s.mem.InputWeights + mem += s.mem.CPU.Size() for _, g := range s.mem.GPUs { - mem += g.Allocated() + mem += g.Size() } return mem } -func (s *ollamaServer) VRAMByGPU(gpuID string) uint64 { +func (s *ollamaServer) VRAMByGPU(id ml.DeviceID) uint64 { if s.mem == nil { return 0 } for _, g := range s.mem.GPUs { - if g.ID == gpuID { - return g.Allocated() + if g.DeviceID == id { + return g.Size() } } return 0 } + +func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { + devices, err := discover.GetDevicesFromRunner(ctx, s) + if err != nil { + if s.cmd != nil && s.cmd.ProcessState == nil { + // Still running but hit an error, log + slog.Debug("failure refreshing GPU information", "error", err) + } + // else no longer running so suppress logging as a failure is expected + } + return devices +} diff --git a/llm/server_test.go b/llm/server_test.go index 4eed82bc..bdedc960 100644 --- a/llm/server_test.go +++ b/llm/server_test.go @@ -16,8 +16,8 @@ import ( func TestLLMServerFitGPU(t *testing.T) { type gpu struct { - library string - free int + id ml.DeviceID + free int } tests := []struct { @@ -37,96 +37,104 @@ func TestLLMServerFitGPU(t *testing.T) { }, { name: "Full single GPU", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{0, 1, 2}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}}, }, { name: "Partial single GPU", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{1, 2}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}}, }, { name: "Single GPU with numGPU 1", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: 1, - expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{1}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}}, }, { name: "Single GPU with numGPU 0", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: 0, expected: ml.GPULayersList{}, }, { name: "Single GPU with numGPU 999", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte}, numGPU: 999, - expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{0, 1, 2, 3}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2, 3}}}, }, { name: "Multi GPU fits on one", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1, 2}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1, 2}}}, }, { name: "Multi GPU split", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0}}, {ID: "gpu0", Layers: []int{1, 2}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}}, }, { name: "Multi GPU partial", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{1}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}}, }, { name: "Multi GPU numGPU 1", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: 1, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{1}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}}, }, { name: "Multi GPU numGPU 2", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, numGPU: 2, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0}}, {ID: "gpu0", Layers: []int{1}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}}, }, { name: "Multi GPU numGPU 999", - gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte}, numGPU: 999, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1}}, {ID: "gpu0", Layers: []int{2}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}}, }, { name: "Multi GPU different libraries", - gpus: []gpu{{library: "cuda", free: 128 * format.MebiByte}, {library: "rocm", free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, free: 256 * format.MebiByte}}, layers: []int{128 * format.MebiByte, 128 * format.MebiByte, 50 * format.MebiByte}, numGPU: -1, - expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1}}}, + expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1", Library: "ROCm"}, Layers: []int{0, 1}}}, }, { name: "requireFull", - gpus: []gpu{{free: 256 * format.MebiByte}}, + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte}, numGPU: -1, requireFull: true, expectedErr: ErrLoadRequiredFull, }, + { + name: "requireFull numGPU", + gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}}, + layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte}, + numGPU: 4, + requireFull: true, + expectedErr: ErrLoadRequiredFull, + }, } for _, tt := range tests { @@ -138,8 +146,7 @@ func TestLLMServerFitGPU(t *testing.T) { gpus := make(discover.GpuInfoList, len(tt.gpus)) for i := range tt.gpus { - gpus[i].ID = fmt.Sprintf("gpu%d", i) - gpus[i].Library = tt.gpus[i].library + gpus[i].DeviceID = tt.gpus[i].id gpus[i].FreeMemory = uint64(tt.gpus[i].free) } @@ -155,18 +162,18 @@ func TestLLMServerFitGPU(t *testing.T) { } s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{ - Weights: make([]ml.Memory, s.totalLayers), - Cache: make([]ml.Memory, s.totalLayers), + Weights: make([]uint64, s.totalLayers), + Cache: make([]uint64, s.totalLayers), }, GPUs: make([]ml.DeviceMemory, len(gpus))} for i := range tt.layers { - s.mem.CPU.Weights[i].Size = uint64(tt.layers[i]) + s.mem.CPU.Weights[i] = uint64(tt.layers[i]) } for i := range s.mem.GPUs { - s.mem.GPUs[i].ID = fmt.Sprintf("gpu%d", i) - s.mem.GPUs[i].Weights = make([]ml.Memory, s.totalLayers) - s.mem.GPUs[i].Cache = make([]ml.Memory, s.totalLayers) + s.mem.GPUs[i].DeviceID = gpus[i].DeviceID + s.mem.GPUs[i].Weights = make([]uint64, s.totalLayers) + s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers) } gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, tt.requireFull, 0) diff --git a/middleware/openai.go b/middleware/openai.go new file mode 100644 index 00000000..826a2111 --- /dev/null +++ b/middleware/openai.go @@ -0,0 +1,424 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/openai" +) + +type BaseWriter struct { + gin.ResponseWriter +} + +type ChatWriter struct { + stream bool + streamOptions *openai.StreamOptions + id string + toolCallSent bool + BaseWriter +} + +type CompleteWriter struct { + stream bool + streamOptions *openai.StreamOptions + id string + BaseWriter +} + +type ListWriter struct { + BaseWriter +} + +type RetrieveWriter struct { + BaseWriter + model string +} + +type EmbedWriter struct { + BaseWriter + model string +} + +func (w *BaseWriter) writeError(data []byte) (int, error) { + var serr api.StatusError + err := json.Unmarshal(data, &serr) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.NewError(http.StatusInternalServerError, serr.Error())) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *ChatWriter) writeResponse(data []byte) (int, error) { + var chatResponse api.ChatResponse + err := json.Unmarshal(data, &chatResponse) + if err != nil { + return 0, err + } + + // chat chunk + if w.stream { + c := openai.ToChunk(w.id, chatResponse, w.toolCallSent) + d, err := json.Marshal(c) + if err != nil { + return 0, err + } + if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 { + w.toolCallSent = true + } + + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + + if chatResponse.Done { + if w.streamOptions != nil && w.streamOptions.IncludeUsage { + u := openai.ToUsage(chatResponse) + c.Usage = &u + c.Choices = []openai.ChunkChoice{} + d, err := json.Marshal(c) + if err != nil { + return 0, err + } + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + } + _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) + if err != nil { + return 0, err + } + } + + return len(data), nil + } + + // chat completion + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToChatCompletion(w.id, chatResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *ChatWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +func (w *CompleteWriter) writeResponse(data []byte) (int, error) { + var generateResponse api.GenerateResponse + err := json.Unmarshal(data, &generateResponse) + if err != nil { + return 0, err + } + + // completion chunk + if w.stream { + c := openai.ToCompleteChunk(w.id, generateResponse) + if w.streamOptions != nil && w.streamOptions.IncludeUsage { + c.Usage = &openai.Usage{} + } + d, err := json.Marshal(c) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + + if generateResponse.Done { + if w.streamOptions != nil && w.streamOptions.IncludeUsage { + u := openai.ToUsageGenerate(generateResponse) + c.Usage = &u + c.Choices = []openai.CompleteChunkChoice{} + d, err := json.Marshal(c) + if err != nil { + return 0, err + } + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + } + _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) + if err != nil { + return 0, err + } + } + + return len(data), nil + } + + // completion + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToCompletion(w.id, generateResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *CompleteWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +func (w *ListWriter) writeResponse(data []byte) (int, error) { + var listResponse api.ListResponse + err := json.Unmarshal(data, &listResponse) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToListCompletion(listResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *ListWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +func (w *RetrieveWriter) writeResponse(data []byte) (int, error) { + var showResponse api.ShowResponse + err := json.Unmarshal(data, &showResponse) + if err != nil { + return 0, err + } + + // retrieve completion + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToModel(showResponse, w.model)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *RetrieveWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +func (w *EmbedWriter) writeResponse(data []byte) (int, error) { + var embedResponse api.EmbedResponse + err := json.Unmarshal(data, &embedResponse) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *EmbedWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +func ListMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + w := &ListWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + } + + c.Writer = w + + c.Next() + } +} + +func RetrieveMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &RetrieveWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + model: c.Param("model"), + } + + c.Writer = w + + c.Next() + } +} + +func CompletionsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req openai.CompletionRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + var b bytes.Buffer + genReq, err := openai.FromCompleteRequest(req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + if err := json.NewEncoder(&b).Encode(genReq); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &CompleteWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + stream: req.Stream, + id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), + streamOptions: req.StreamOptions, + } + + c.Writer = w + c.Next() + } +} + +func EmbeddingsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req openai.EmbedRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + if req.Input == "" { + req.Input = []string{""} + } + + if req.Input == nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input")) + return + } + + if v, ok := req.Input.([]any); ok && len(v) == 0 { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input")) + return + } + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &EmbedWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + model: req.Model, + } + + c.Writer = w + + c.Next() + } +} + +func ChatMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req openai.ChatCompletionRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + if len(req.Messages) == 0 { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "[] is too short - 'messages'")) + return + } + + var b bytes.Buffer + + chatReq, err := openai.FromChatRequest(req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + if err := json.NewEncoder(&b).Encode(chatReq); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &ChatWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + stream: req.Stream, + id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), + streamOptions: req.StreamOptions, + } + + c.Writer = w + + c.Next() + } +} diff --git a/middleware/openai_test.go b/middleware/openai_test.go new file mode 100644 index 00000000..a78ee8b9 --- /dev/null +++ b/middleware/openai_test.go @@ -0,0 +1,928 @@ +package middleware + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/openai" +) + +const ( + prefix = `data:image/jpeg;base64,` + image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` +) + +var ( + False = false + True = true +) + +func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { + return func(c *gin.Context) { + bodyBytes, _ := io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + err := json.Unmarshal(bodyBytes, capturedRequest) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request") + } + c.Next() + } +} + +func TestChatMiddleware(t *testing.T) { + type testCase struct { + name string + body string + req api.ChatRequest + err openai.ErrorResponse + } + + var capturedRequest *api.ChatRequest + + testCases := []testCase{ + { + name: "chat handler", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with options", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "stream": true, + "max_tokens": 999, + "seed": 123, + "stop": ["\n", "stop"], + "temperature": 3.0, + "frequency_penalty": 4.0, + "presence_penalty": 5.0, + "top_p": 6.0, + "response_format": {"type": "json_object"} + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + }, + Options: map[string]any{ + "num_predict": 999.0, // float because JSON doesn't distinguish between float and int + "seed": 123.0, + "stop": []any{"\n", "stop"}, + "temperature": 3.0, + "frequency_penalty": 4.0, + "presence_penalty": 5.0, + "top_p": 6.0, + }, + Format: json.RawMessage(`"json"`), + Stream: &True, + }, + }, + { + name: "chat handler with streaming usage", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "stream": true, + "stream_options": {"include_usage": true}, + "max_tokens": 999, + "seed": 123, + "stop": ["\n", "stop"], + "temperature": 3.0, + "frequency_penalty": 4.0, + "presence_penalty": 5.0, + "top_p": 6.0, + "response_format": {"type": "json_object"} + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + }, + Options: map[string]any{ + "num_predict": 999.0, // float because JSON doesn't distinguish between float and int + "seed": 123.0, + "stop": []any{"\n", "stop"}, + "temperature": 3.0, + "frequency_penalty": 4.0, + "presence_penalty": 5.0, + "top_p": 6.0, + }, + Format: json.RawMessage(`"json"`), + Stream: &True, + }, + }, + { + name: "chat handler with image content", + body: `{ + "model": "test-model", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Hello" + }, + { + "type": "image_url", + "image_url": { + "url": "` + prefix + image + `" + } + } + ] + } + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + { + Role: "user", + Images: []api.ImageData{ + func() []byte { + img, _ := base64.StdEncoding.DecodeString(image) + return img + }(), + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with tools", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with tools and content", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + Content: "Let's see what the weather is like in Paris", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with tools and empty content", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with tools and thinking content", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + Thinking: "Let's see what the weather is like in Paris", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "tool response with call ID", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}, + {"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + { + Role: "tool", + Content: "The weather in Paris is 20 degrees Celsius", + ToolName: "get_current_weather", + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "tool response with name", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}, + {"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + { + Role: "tool", + Content: "The weather in Paris is 20 degrees Celsius", + ToolName: "get_current_weather", + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with streaming tools", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris?"} + ], + "stream": true, + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "required": ["location"], + "properties": { + "location": { + "type": "string", + "description": "The city and state" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + } + } + } + }] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris?", + }, + }, + Tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required"` + Properties map[string]api.ToolProperty `json:"properties"` + }{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "The city and state", + }, + "unit": { + Type: api.PropertyType{"string"}, + Enum: []any{"celsius", "fahrenheit"}, + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &True, + }, + }, + { + name: "chat handler error forwarding", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": 2} + ] + }`, + err: openai.ErrorResponse{ + Error: openai.Error{ + Message: "invalid message content type: float64", + Type: "invalid_request_error", + }, + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/chat", endpoint) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") + + defer func() { capturedRequest = nil }() + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var errResp openai.ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + return + } + if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" { + t.Fatalf("requests did not match: %+v", diff) + } + if diff := cmp.Diff(tc.err, errResp); diff != "" { + t.Fatalf("errors did not match for %s:\n%s", tc.name, diff) + } + }) + } +} + +func TestCompletionsMiddleware(t *testing.T) { + type testCase struct { + name string + body string + req api.GenerateRequest + err openai.ErrorResponse + } + + var capturedRequest *api.GenerateRequest + + testCases := []testCase{ + { + name: "completions handler", + body: `{ + "model": "test-model", + "prompt": "Hello", + "temperature": 0.8, + "stop": ["\n", "stop"], + "suffix": "suffix" + }`, + req: api.GenerateRequest{ + Model: "test-model", + Prompt: "Hello", + Options: map[string]any{ + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 0.8, + "top_p": 1.0, + "stop": []any{"\n", "stop"}, + }, + Suffix: "suffix", + Stream: &False, + }, + }, + { + name: "completions handler stream", + body: `{ + "model": "test-model", + "prompt": "Hello", + "stream": true, + "temperature": 0.8, + "stop": ["\n", "stop"], + "suffix": "suffix" + }`, + req: api.GenerateRequest{ + Model: "test-model", + Prompt: "Hello", + Options: map[string]any{ + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 0.8, + "top_p": 1.0, + "stop": []any{"\n", "stop"}, + }, + Suffix: "suffix", + Stream: &True, + }, + }, + { + name: "completions handler stream with usage", + body: `{ + "model": "test-model", + "prompt": "Hello", + "stream": true, + "stream_options": {"include_usage": true}, + "temperature": 0.8, + "stop": ["\n", "stop"], + "suffix": "suffix" + }`, + req: api.GenerateRequest{ + Model: "test-model", + Prompt: "Hello", + Options: map[string]any{ + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 0.8, + "top_p": 1.0, + "stop": []any{"\n", "stop"}, + }, + Suffix: "suffix", + Stream: &True, + }, + }, + { + name: "completions handler error forwarding", + body: `{ + "model": "test-model", + "prompt": "Hello", + "temperature": null, + "stop": [1, 2], + "suffix": "suffix" + }`, + err: openai.ErrorResponse{ + Error: openai.Error{ + Message: "invalid type for 'stop' field: float64", + Type: "invalid_request_error", + }, + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/generate", endpoint) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var errResp openai.ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + } + + if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { + t.Fatal("requests did not match") + } + + if !reflect.DeepEqual(tc.err, errResp) { + t.Fatal("errors did not match") + } + + capturedRequest = nil + }) + } +} + +func TestEmbeddingsMiddleware(t *testing.T) { + type testCase struct { + name string + body string + req api.EmbedRequest + err openai.ErrorResponse + } + + var capturedRequest *api.EmbedRequest + + testCases := []testCase{ + { + name: "embed handler single input", + body: `{ + "input": "Hello", + "model": "test-model" + }`, + req: api.EmbedRequest{ + Input: "Hello", + Model: "test-model", + }, + }, + { + name: "embed handler batch input", + body: `{ + "input": ["Hello", "World"], + "model": "test-model" + }`, + req: api.EmbedRequest{ + Input: []any{"Hello", "World"}, + Model: "test-model", + }, + }, + { + name: "embed handler error forwarding", + body: `{ + "model": "test-model" + }`, + err: openai.ErrorResponse{ + Error: openai.Error{ + Message: "invalid input", + Type: "invalid_request_error", + }, + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/embed", endpoint) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var errResp openai.ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + } + + if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { + t.Fatal("requests did not match") + } + + if !reflect.DeepEqual(tc.err, errResp) { + t.Fatal("errors did not match") + } + + capturedRequest = nil + }) + } +} + +func TestListMiddleware(t *testing.T) { + type testCase struct { + name string + endpoint func(c *gin.Context) + resp string + } + + testCases := []testCase{ + { + name: "list handler", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ListResponse{ + Models: []api.ListModelResponse{ + { + Name: "test-model", + ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), + }, + }, + }) + }, + resp: `{ + "object": "list", + "data": [ + { + "id": "test-model", + "object": "model", + "created": 1686935002, + "owned_by": "library" + } + ] + }`, + }, + { + name: "list handler empty output", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ListResponse{}) + }, + resp: `{ + "object": "list", + "data": null + }`, + }, + } + + gin.SetMode(gin.TestMode) + + for _, tc := range testCases { + router := gin.New() + router.Use(ListMiddleware()) + router.Handle(http.MethodGet, "/api/tags", tc.endpoint) + req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var expected, actual map[string]any + err := json.Unmarshal([]byte(tc.resp), &expected) + if err != nil { + t.Fatalf("failed to unmarshal expected response: %v", err) + } + + err = json.Unmarshal(resp.Body.Bytes(), &actual) + if err != nil { + t.Fatalf("failed to unmarshal actual response: %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + } + } +} + +func TestRetrieveMiddleware(t *testing.T) { + type testCase struct { + name string + endpoint func(c *gin.Context) + resp string + } + + testCases := []testCase{ + { + name: "retrieve handler", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ShowResponse{ + ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), + }) + }, + resp: `{ + "id":"test-model", + "object":"model", + "created":1686935002, + "owned_by":"library"} + `, + }, + { + name: "retrieve handler error forwarding", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"}) + }, + resp: `{ + "error": { + "code": null, + "message": "model not found", + "param": null, + "type": "api_error" + } + }`, + }, + } + + gin.SetMode(gin.TestMode) + + for _, tc := range testCases { + router := gin.New() + router.Use(RetrieveMiddleware()) + router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint) + req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var expected, actual map[string]any + err := json.Unmarshal([]byte(tc.resp), &expected) + if err != nil { + t.Fatalf("failed to unmarshal expected response: %v", err) + } + + err = json.Unmarshal(resp.Body.Bytes(), &actual) + if err != nil { + t.Fatalf("failed to unmarshal actual response: %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + } + } +} diff --git a/ml/backend.go b/ml/backend.go index 455715b0..351942d5 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -5,14 +5,11 @@ import ( "context" "encoding/binary" "fmt" - "hash/maphash" - "log/slog" "math" "slices" "strconv" "strings" - "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs" ) @@ -29,6 +26,9 @@ type Backend interface { Get(name string) Tensor NewContext() Context NewContextSize(size int) Context + + // Enumerate the devices available for inference via this backend + BackendDevices() []DeviceInfo } // BackendCacheConfig should be implemented by backends that need special output @@ -60,77 +60,6 @@ type CacheConfig struct { MaskBatchPadding int } -// GPULayers is a set of layers to be allocated on a single GPU -type GPULayers struct { - // ID is the identifier of the GPU, as reported in DeviceMemory - ID string - - // Layers is a set of layer indicies to load - Layers []int -} - -func (g GPULayers) String() string { - if len(g.Layers) == 0 { - return "" - } - - slices.Sort(g.Layers) - - contiguous := true - base := g.Layers[0] - for i := range g.Layers { - if g.Layers[i] != base+i { - contiguous = false - break - } - } - - if contiguous { - return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1]) - } else { - return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers) - } -} - -// GPULayersList is a set of layer allocations across multiple GPUs -type GPULayersList []GPULayers - -func (l GPULayersList) String() string { - if l.Sum() > 0 { - return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l)) - } else { - return fmt.Sprintf("%v", []GPULayers(l)) - } -} - -// Sum is the total number of layers assigned across all GPUs -func (l GPULayersList) Sum() int { - var sum int - - for _, g := range l { - sum += len(g.Layers) - } - - return sum -} - -var h maphash.Hash - -// Hash is an identifier of this layer assignment -func (l GPULayersList) Hash() uint64 { - h.Reset() - for _, g := range l { - if len(g.Layers) > 0 { - h.WriteString(g.ID) - for _, l := range g.Layers { - binary.Write(&h, binary.NativeEndian, int64(l)) - } - } - } - - return h.Sum64() -} - // BackendParams controls how the backend loads and executes models type BackendParams struct { // AllocMemory causes the backend to allocate memory for the model. If @@ -148,201 +77,6 @@ type BackendParams struct { FlashAttention bool } -// ErrNoMem is returned when panicing due to insufficient memory. It includes -// the attempted memory allocation. -type ErrNoMem struct { - BackendMemory -} - -func (e ErrNoMem) Error() string { - return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory) -} - -type AllocationStatus int - -const ( - // Unallocated memory - have not yet attempted to allocate - Unallocated AllocationStatus = iota - - // Failed memory - tried to allocate the memory and did not succeed - Failed - - // Allocated memory = tried and succeeded to allocate memory - Allocated -) - -// Memory is the size of an allocation and whether it was successful. -type Memory struct { - Size uint64 - Status AllocationStatus -} - -func (m Memory) String() string { - s := fmt.Sprint(m.Size) - - switch m.Status { - case Unallocated: - s += "U" - case Failed: - s += "F" - case Allocated: - s += "A" - } - - return s -} - -// DeviceMemory provides a breakdown of the memory needed -// per device, such as a CPU or GPU. -type DeviceMemory struct { - // Name is the name of the device as labeled by the backend. It - // may not be persistent across instances of the runner. - Name string - - // ID is an identifier for the device for matching with system - // management libraries. - ID string - - // Weights is the per-layer memory needed for the model weights. - Weights []Memory - - // Cache is the per-layer memory needed for the KV cache. - Cache []Memory - - // Graph is the size of the compute graph. It is not per-layer. - Graph Memory -} - -// Allocated returns the total size of the memory that has been successfully -// allocated on this device -func (m DeviceMemory) Allocated() uint64 { - var mem uint64 - - for _, w := range m.Weights { - if w.Status == Allocated { - mem += w.Size - } - } - for _, c := range m.Cache { - if c.Status == Allocated { - mem += c.Size - } - } - if m.Graph.Status == Allocated { - mem += m.Graph.Size - } - - return mem -} - -func memoryPresent(mem []Memory) bool { - return slices.ContainsFunc(mem, func(m Memory) bool { return m.Size != 0 }) -} - -func (m DeviceMemory) LogValue() slog.Value { - var attrs []slog.Attr - if memoryPresent(m.Weights) { - attrs = append(attrs, slog.Any("Weights", m.Weights)) - } - - if memoryPresent(m.Cache) { - attrs = append(attrs, slog.Any("Cache", m.Cache)) - } - - if m.Graph.Size != 0 { - attrs = append(attrs, slog.Any("Graph", m.Graph)) - } - - if len(attrs) > 0 && m.ID != "" { - attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...) - } - - return slog.GroupValue(attrs...) -} - -// BackendMemory provides the amount of memory required to load the model -// per device based on the BackendParams. In some cases, not all required -// allocations will be known at this point. However, the size of the most recent -// allocation is guaranteed to be provided so that if it failed, the caller can -// accommodate that to make forward progress. -type BackendMemory struct { - // InputWeights are always located on the CPU and cannot be moved - InputWeights Memory - - // CPU model components are located in system memory. This does not - // include unified memory allocated through the GPU. - CPU DeviceMemory - - // GPU model components are located on one or more GPUs. - GPUs []DeviceMemory -} - -func (m BackendMemory) LogValue() slog.Value { - var attrs []slog.Attr - if m.InputWeights.Size != 0 { - attrs = append(attrs, slog.Any("InputWeights", m.InputWeights)) - } - - attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU)) - for _, g := range m.GPUs { - attrs = append(attrs, slog.Any(g.Name, g)) - } - - return slog.GroupValue(attrs...) -} - -func sumMemory(mem []Memory) uint64 { - var sum uint64 - - for _, m := range mem { - sum += m.Size - } - - return sum -} - -// Log prints a high level summary of the memory (allocated or not) -func (m BackendMemory) Log(level slog.Level) { - var total uint64 - - for _, gpu := range m.GPUs { - if sum := sumMemory(gpu.Weights); sum > 0 { - slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum)) - total += sum - } - } - if sum := m.InputWeights.Size + sumMemory(m.CPU.Weights); sum > 0 { - slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) - total += sum - } - - for _, gpu := range m.GPUs { - if sum := sumMemory(gpu.Cache); sum > 0 { - slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum)) - total += sum - } - } - if sum := sumMemory(m.CPU.Cache); sum > 0 { - slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) - total += sum - } - - for _, gpu := range m.GPUs { - if sum := gpu.Graph.Size; sum > 0 { - slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum)) - total += sum - } - } - if sum := m.CPU.Graph.Size; sum > 0 { - slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) - total += sum - } - - if total > 0 { - slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total)) - } -} - var backends = make(map[string]func(string, BackendParams) (Backend, error)) func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) { diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 49dc3e1a..88078d77 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1,5 +1,7 @@ package ggml +// #cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm +// #cgo windows LDFLAGS: -lpthread // #cgo CPPFLAGS: -I${SRCDIR}/ggml/include // #include // #include @@ -55,7 +57,8 @@ var initDevices = sync.OnceFunc(func() { } case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: accels = append(accels, d) - case C.GGML_BACKEND_DEVICE_TYPE_GPU: + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: gpus = append(gpus, d) } @@ -159,7 +162,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { C.GGML_BACKEND_DEVICE_TYPE_ACCEL: bt := C.ggml_backend_dev_buffer_type(d) cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, bt) - C.ggml_backend_buft_set_alloc(bt, C.bool(params.AllocMemory)) btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU } @@ -169,8 +171,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { var props C.struct_ggml_backend_dev_props C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props) requiredMemory.CPU.ID = C.GoString(props.id) - requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1) - requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1) + requiredMemory.CPU.Library = C.GoString(props.library) + requiredMemory.CPU.Weights = make([]uint64, blocks+1) + requiredMemory.CPU.Cache = make([]uint64, blocks+1) // create list of buffer types for each gpu var gpuDeviceBufferTypes []deviceBufferType @@ -181,15 +184,15 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { d: d, bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...), }) - C.ggml_backend_buft_set_alloc(bt, C.bool(params.AllocMemory)) btDeviceMemory[bt] = &requiredMemory.GPUs[i] requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d)) var props C.struct_ggml_backend_dev_props C.ggml_backend_dev_get_props(d, &props) requiredMemory.GPUs[i].ID = C.GoString(props.id) - requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1) - requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1) + requiredMemory.GPUs[i].Library = C.GoString(props.library) + requiredMemory.GPUs[i].Weights = make([]uint64, blocks+1) + requiredMemory.GPUs[i].Cache = make([]uint64, blocks+1) } // inputs always use cpu @@ -200,7 +203,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { for _, l := range p.Layers { if l == layer { for i := range requiredMemory.GPUs { - if requiredMemory.GPUs[i].ID == p.ID { + if requiredMemory.GPUs[i].DeviceID == p.DeviceID { return gpuDeviceBufferTypes[i] } } @@ -275,13 +278,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt)) if layer == -1 { - // Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case - if params.AllocMemory { - requiredMemory.InputWeights.Status = ml.Allocated - } - requiredMemory.InputWeights.Size += uint64(size) + requiredMemory.InputWeights += uint64(size) } else { - btDeviceMemory[bt].Weights[layer].Size += uint64(size) + btDeviceMemory[bt].Weights[layer] += uint64(size) } //nolint:staticcheck // TODO: check if buffer type supports this tensor @@ -341,47 +340,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } } - // allocate buffers for each context - bbs := make(map[*C.struct_ggml_context]C.ggml_backend_buffer_t, len(ctxs)) - for bt, c := range ctxs { - if C.ggml_get_first_tensor(c) == nil { - continue - } - - b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt) - if params.AllocMemory { - for i := range btDeviceMemory[bt].Weights { - if btDeviceMemory[bt].Weights[i].Size != 0 { - if b != nil { - btDeviceMemory[bt].Weights[i].Status = ml.Allocated - } else { - btDeviceMemory[bt].Weights[i].Status = ml.Failed - } - } - } - } - - if b == nil { - for _, b := range bbs { - C.ggml_backend_buffer_free(b) - } - - for _, ctx := range ctxs { - C.ggml_free(ctx) - } - - panic(ml.ErrNoMem{BackendMemory: requiredMemory}) - } - - C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) - bbs[c] = b - } - - for bs := range maps.Values(bbs) { - logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), - "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs)))) - } - // map tensor names to tensors for easy lookup later tensors := make(map[string]*C.struct_ggml_tensor) for _, c := range ctxs { @@ -419,6 +377,46 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } maxGraphNodes := max(8192, len(meta.Tensors().Items())*5) + + sched := C.ggml_backend_sched_new_ext( + (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])), + (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), + C.int(len(schedBackends)), + C.size_t(maxGraphNodes), + C._Bool(false), + C._Bool(false), + C._Bool(params.AllocMemory), + ) + + // allocate buffers for each context + bbs := make(map[*C.struct_ggml_context]C.ggml_backend_buffer_t, len(ctxs)) + for bt, c := range ctxs { + if C.ggml_get_first_tensor(c) == nil { + continue + } + + b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt) + if b == nil { + for _, b := range bbs { + C.ggml_backend_buffer_free(b) + } + + for _, ctx := range ctxs { + C.ggml_free(ctx) + } + + panic(ml.ErrNoMem{BackendMemory: requiredMemory}) + } + + C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) + bbs[c] = b + } + + for bs := range maps.Values(bbs) { + logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), + "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs)))) + } + return &Backend{ modelPath: modelPath, allocMemory: params.AllocMemory, @@ -426,18 +424,11 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { meta: meta, tensorLoadTargets: targets, tensors: tensors, - sched: C.ggml_backend_sched_new( - (*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])), - (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])), - C.int(len(schedBackends)), - C.size_t(maxGraphNodes), - C._Bool(false), - C._Bool(false), - ), - schedBackends: schedBackends, - schedBufts: schedBufts, - input: deviceBufferTypes[input.d], - output: output.d, + sched: sched, + schedBackends: schedBackends, + schedBufts: schedBufts, + input: deviceBufferTypes[input.d], + output: output.d, layers: func() map[int]layerDevice { m := make(map[int]layerDevice) for i, layer := range layers { @@ -480,7 +471,9 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { // Mimic llama runner logs summarizing layers and memory gpuLayers := 0 for layer := range maps.Values(b.layers) { - if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU { + switch C.ggml_backend_dev_type(layer.d) { + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: gpuLayers++ } } @@ -489,7 +482,8 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { switch C.ggml_backend_dev_type(b.output) { case C.GGML_BACKEND_DEVICE_TYPE_CPU: slog.Info("offloading output layer to CPU") - case C.GGML_BACKEND_DEVICE_TYPE_GPU: + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: slog.Info("offloading output layer to GPU") gpuLayers++ case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: @@ -696,6 +690,55 @@ func (b *Backend) CacheConfig() ml.CacheConfig { } } +func (b *Backend) BackendDevices() []ml.DeviceInfo { + deviceInfos := []ml.DeviceInfo{} + for _, dev := range gpus { + // If we have a model loaded, and it's only loaded on a subset of the devices + // skip idle/unused devices to avoid initializing them and causing VRAM allocations + if b.allocMemory { + idleDev := true + for _, backend := range b.schedBackends { + if dev == C.ggml_backend_get_device(backend) { + idleDev = false + break + } + } + if idleDev { + slog.Debug("skipping unused backend device", "description", C.GoString(C.ggml_backend_dev_description(dev))) + continue + } + } + + info := ml.DeviceInfo{} + props := C.struct_ggml_backend_dev_props{} + C.ggml_backend_dev_get_props(dev, &props) + info.Name = C.GoString(props.name) + info.Description = C.GoString(props.description) + info.ID = C.GoString(props.id) + info.Library = C.GoString(props.library) + info.ComputeMajor = (int)(props.compute_major) + info.ComputeMinor = (int)(props.compute_minor) + info.DriverMajor = (int)(props.driver_major) + info.DriverMinor = (int)(props.driver_minor) + info.Integrated = props.integrated != 0 + if props.library != nil { + info.Library = C.GoString(props.library) + } + info.PCIID = fmt.Sprintf("%02x:%02x.%x", props.pci_bus_id, props.pci_device_id, props.pci_domain_id) + info.LibraryPath = ggml.LibPaths() + if props.numeric_id != nil { + info.FilteredID = C.GoString(props.numeric_id) + } + + C.ggml_backend_dev_memory(dev, &props.memory_free, &props.memory_total) + info.TotalMemory = (uint64)(props.memory_total) + info.FreeMemory = (uint64)(props.memory_free) + + deviceInfos = append(deviceInfos, info) + } + return deviceInfos +} + type Context struct { b *Backend @@ -795,24 +838,15 @@ func (c *Context) Reserve() { // Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations for _, bt := range c.b.schedBufts { - c.b.btDeviceMemory[bt].Graph = ml.Memory{} + c.b.btDeviceMemory[bt].Graph = 0 } for i := range c.b.schedBackends { - bufferStatus := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i]) - - graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph - graph.Size += uint64(bufferStatus.size) - if c.b.allocMemory { - if bufferStatus.allocated && graph.Status != ml.Failed { - graph.Status = ml.Allocated - } else { - graph.Status = ml.Failed - } - } + bufferSize := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i]) + c.b.btDeviceMemory[c.b.schedBufts[i]].Graph += uint64(bufferSize) logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), - "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size))) + "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferSize))) } if !reserved { @@ -862,16 +896,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { b := C.ggml_backend_buft_alloc_buffer(c.buft, size) if c.layer >= 0 { - cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer] - - cache.Size += uint64(size) - if c.b.allocMemory { - if b != nil { - cache.Status = ml.Allocated - } else { - cache.Status = ml.Failed - } - } + c.b.btDeviceMemory[c.buft].Cache[c.layer] += uint64(size) } if b == nil { diff --git a/ml/backend/ggml/ggml/.rsync-filter b/ml/backend/ggml/ggml/.rsync-filter index a2b1b7d9..449ec9e5 100644 --- a/ml/backend/ggml/ggml/.rsync-filter +++ b/ml/backend/ggml/ggml/.rsync-filter @@ -20,10 +20,14 @@ include /src/ggml-cuda/vendors/ include /src/ggml-cuda/template-instances/ include /src/ggml-hip/ include /src/ggml-metal/ +include src/ggml-vulkan/ +include src/ggml-vulkan/vulkan-shaders include CMakeLists.txt include *.[chm] include *.cpp include *.cu include *.cuh include *.metal +include *.comp +include *.glsl hide * diff --git a/ml/backend/ggml/ggml/include/ggml-alloc.h b/ml/backend/ggml/ggml/include/ggml-alloc.h index 781b1e10..7ab3f019 100644 --- a/ml/backend/ggml/ggml/include/ggml-alloc.h +++ b/ml/backend/ggml/ggml/include/ggml-alloc.h @@ -65,12 +65,7 @@ GGML_API bool ggml_gallocr_reserve_n( GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph); GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); - -struct ggml_allocr_buffer_status { - size_t size; - bool allocated; -}; -GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); +GGML_API size_t ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id); // Utils // Create a buffer and allocate all the tensors in a ggml_context diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h index fda5ceb2..094fc3c8 100644 --- a/ml/backend/ggml/ggml/include/ggml-backend.h +++ b/ml/backend/ggml/ggml/include/ggml-backend.h @@ -35,7 +35,6 @@ extern "C" { // GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft); - GGML_API void ggml_backend_buft_set_alloc (ggml_backend_buffer_type_t buft, bool alloc); GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size); GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); @@ -133,6 +132,8 @@ extern "C" { GGML_BACKEND_DEVICE_TYPE_CPU, // GPU device using dedicated memory GGML_BACKEND_DEVICE_TYPE_GPU, + // integrated GPU device using host memory + GGML_BACKEND_DEVICE_TYPE_IGPU, // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) GGML_BACKEND_DEVICE_TYPE_ACCEL }; @@ -151,13 +152,34 @@ extern "C" { // all the device properties struct ggml_backend_dev_props { + // device name const char * name; + // device description const char * description; - const char * id; + // device free memory in bytes size_t memory_free; + const char * id; + // device total memory in bytes size_t memory_total; + // device type enum ggml_backend_dev_type type; + // device id + // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0") + // if the id is unknown, this should be NULL + const char * device_id; + // device capabilities struct ggml_backend_dev_caps caps; + int driver_major; + int driver_minor; + int compute_major; + int compute_minor; + int integrated; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; + const char *library; + // number with which the devices are accessed (Vulkan) + const char *numeric_id; }; GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); @@ -206,6 +228,8 @@ extern "C" { // Backend registry // + GGML_API void ggml_backend_register(ggml_backend_reg_t reg); + GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); // Backend (reg) enumeration @@ -293,6 +317,7 @@ extern "C" { // Initialize a backend scheduler, backends with low index are given priority over backends with high index GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); + GGML_API ggml_backend_sched_t ggml_backend_sched_new_ext(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload, bool alloc_buffers); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); // Initialize backend buffers from a measure graph @@ -305,17 +330,16 @@ extern "C" { GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched); GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched); - GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); - - struct ggml_backend_buffer_status { - size_t size; - bool allocated; - }; - GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + GGML_API ggml_backend_buffer_type_t ggml_backend_sched_get_buffer_type(ggml_backend_sched_t sched, ggml_backend_t backend); + GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); + GGML_API size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); + // Split graph without allocating it + GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); + // Allocate and compute graph on the backend scheduler GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph); diff --git a/ml/backend/ggml/ggml/include/ggml-cpu.h b/ml/backend/ggml/ggml/include/ggml-cpu.h index be40b100..9edd4851 100644 --- a/ml/backend/ggml/ggml/include/ggml-cpu.h +++ b/ml/backend/ggml/ggml/include/ggml-cpu.h @@ -101,7 +101,6 @@ extern "C" { GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); GGML_BACKEND_API int ggml_cpu_has_vsx (void); GGML_BACKEND_API int ggml_cpu_has_vxe (void); - GGML_BACKEND_API int ggml_cpu_has_nnpa (void); GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); GGML_BACKEND_API int ggml_cpu_has_llamafile (void); @@ -135,6 +134,7 @@ extern "C" { GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); + GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t); diff --git a/ml/backend/ggml/ggml/include/ggml-metal.h b/ml/backend/ggml/ggml/include/ggml-metal.h index a6106944..433838f0 100644 --- a/ml/backend/ggml/ggml/include/ggml-metal.h +++ b/ml/backend/ggml/ggml/include/ggml-metal.h @@ -39,18 +39,13 @@ extern "C" { // user-code should use only these functions // +// TODO: remove in the future GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void); GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend); -GGML_DEPRECATED( - GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), - "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713"); - GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); - // helper to check if the device supports a specific family // ideally, the user code should be doing these checks // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf diff --git a/ml/backend/ggml/ggml/include/ggml-opt.h b/ml/backend/ggml/ggml/include/ggml-opt.h index 74ec080a..4703a05a 100644 --- a/ml/backend/ggml/ggml/include/ggml-opt.h +++ b/ml/backend/ggml/ggml/include/ggml-opt.h @@ -74,16 +74,26 @@ extern "C" { GGML_OPT_BUILD_TYPE_OPT = 30, }; + enum ggml_opt_optimizer_type { + GGML_OPT_OPTIMIZER_TYPE_ADAMW, + GGML_OPT_OPTIMIZER_TYPE_SGD, + + GGML_OPT_OPTIMIZER_TYPE_COUNT + }; + // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss struct ggml_opt_optimizer_params { - // AdamW optimizer parameters struct { float alpha; // learning rate - float beta1; - float beta2; + float beta1; // first AdamW momentum + float beta2; // second AdamW momentum float eps; // epsilon for numerical stability - float wd; // weight decay for AdamW, use 0.0f to disable + float wd; // weight decay - 0.0f to disable } adamw; + struct { + float alpha; // learning rate + float wd; // weight decay + } sgd; }; // callback to calculate optimizer parameters prior to a backward pass @@ -112,8 +122,11 @@ extern "C" { int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done - ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters - void * get_opt_pars_ud; // userdata for calculating optimizer parameters + ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters + void * get_opt_pars_ud; // userdata for calculating optimizer parameters + + // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor + enum ggml_opt_optimizer_type optimizer; }; // get parameters for an optimization context with defaults set where possible @@ -142,6 +155,10 @@ extern "C" { // get the gradient accumulator for a node from the forward graph GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); + GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme + + GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type); + // ====== Optimization Result ====== GGML_API ggml_opt_result_t ggml_opt_result_init(void); @@ -226,12 +243,14 @@ extern "C" { struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used ggml_opt_dataset_t dataset, // dataset with data and optionally also labels enum ggml_opt_loss_type loss_type, // loss to minimize + enum ggml_opt_optimizer_type optimizer, // sgd or adamw ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) int64_t nepoch, // how many times the dataset should be iterated over int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f) bool silent); // whether or not info prints to stderr should be suppressed + #ifdef __cplusplus } #endif diff --git a/ml/backend/ggml/ggml/include/ggml-rpc.h b/ml/backend/ggml/ggml/include/ggml-rpc.h index 1e674112..72eff002 100644 --- a/ml/backend/ggml/ggml/include/ggml-rpc.h +++ b/ml/backend/ggml/ggml/include/ggml-rpc.h @@ -7,26 +7,25 @@ extern "C" { #endif -#define RPC_PROTO_MAJOR_VERSION 2 +#define RPC_PROTO_MAJOR_VERSION 3 #define RPC_PROTO_MINOR_VERSION 0 #define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 // backend API -GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint); +GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device); GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend); -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device); -GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); +GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); -GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, - const char * cache_dir, - size_t free_mem, size_t total_mem); +GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, + size_t n_threads, size_t n_devices, + ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); - -GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint); +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); #ifdef __cplusplus } diff --git a/ml/backend/ggml/ggml/include/ggml-zdnn.h b/ml/backend/ggml/ggml/include/ggml-zdnn.h new file mode 100644 index 00000000..fbf45b6e --- /dev/null +++ b/ml/backend/ggml/ggml/include/ggml-zdnn.h @@ -0,0 +1,17 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// device buffer +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_zdnn_buffer_type(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zdnn_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/ml/backend/ggml/ggml/include/ggml.h b/ml/backend/ggml/ggml/include/ggml.h index 2f06e1e3..60c6b63d 100644 --- a/ml/backend/ggml/ggml/include/ggml.h +++ b/ml/backend/ggml/ggml/include/ggml.h @@ -237,11 +237,22 @@ #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 +// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726 +#define GGML_ROPE_TYPE_NORMAL 0 #define GGML_ROPE_TYPE_NEOX 2 #define GGML_ROPE_TYPE_MROPE 8 #define GGML_ROPE_TYPE_VISION 24 +#define GGML_MROPE_SECTIONS 4 + #define GGML_UNUSED(x) (void)(x) +#ifdef __CUDACC__ +template +__host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {} +#define GGML_UNUSED_VARS(...) ggml_unused_vars_impl(__VA_ARGS__) +#else +#define GGML_UNUSED_VARS(...) do { (void)sizeof((__VA_ARGS__, 0)); } while(0) +#endif // __CUDACC__ #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) @@ -275,19 +286,19 @@ // GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); // #define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \ - const type prefix##0 = (pointer)->array[0]; \ + const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \ GGML_UNUSED(prefix##0); #define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \ GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \ - const type prefix##1 = (pointer)->array[1]; \ + const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \ GGML_UNUSED(prefix##1); #define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \ GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \ - const type prefix##2 = (pointer)->array[2]; \ + const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \ GGML_UNUSED(prefix##2); #define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \ GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \ - const type prefix##3 = (pointer)->array[3]; \ + const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \ GGML_UNUSED(prefix##3); #define GGML_TENSOR_UNARY_OP_LOCALS \ @@ -502,7 +513,9 @@ extern "C" { GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, + GGML_OP_IM2COL_3D, GGML_OP_CONV_2D, + GGML_OP_CONV_3D, GGML_OP_CONV_2D_DW, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, @@ -540,6 +553,7 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS, GGML_OP_CROSS_ENTROPY_LOSS_BACK, GGML_OP_OPT_STEP_ADAMW, + GGML_OP_OPT_STEP_SGD, GGML_OP_GLU, @@ -562,6 +576,7 @@ extern "C" { GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, + GGML_UNARY_OP_XIELU, GGML_UNARY_OP_COUNT, }; @@ -1136,6 +1151,18 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // xIELU activation function + // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0) + // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions + // that constrain the positive and negative source alpha values respectively + GGML_API struct ggml_tensor * ggml_xielu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha_n, + float alpha_p, + float beta, + float eps); + // gated linear unit ops // A: n columns, r rows, // result is n / 2 columns, r rows, @@ -1392,6 +1419,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // note: casting from f32 to i32 will discard the fractional part GGML_API struct ggml_tensor * ggml_cast( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1516,7 +1544,11 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // supports 3D: a->ne[2] == b->ne[1] + // supports 4D a: + // a [n_embd, ne1, ne2, ne3] + // b I32 [n_rows, ne2, ne3, 1] + // + // return [n_embd, n_rows, ne2, ne3] GGML_API struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, struct ggml_tensor * a, // data @@ -1598,6 +1630,13 @@ extern "C" { float scale, float max_bias); + GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias); + GGML_API void ggml_soft_max_add_sinks( struct ggml_tensor * a, struct ggml_tensor * sinks); @@ -1660,7 +1699,7 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * c, int n_dims, - int sections[4], + int sections[GGML_MROPE_SECTIONS], int mode, int n_ctx_orig, float freq_base, @@ -1686,6 +1725,22 @@ extern "C" { float beta_fast, float beta_slow); + GGML_API struct ggml_tensor * ggml_rope_multi_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[GGML_MROPE_SECTIONS], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1843,6 +1898,41 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 + GGML_API struct ggml_tensor * ggml_im2col_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2, // dilation depth + enum ggml_type dst_type); + + // a: [OC*IC, KD, KH, KW] + // b: [N*IC, ID, IH, IW] + // result: [N*OC, OD, OH, OW] + GGML_API struct ggml_tensor * ggml_conv_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2 // dilation depth + ); + // kernel size is a->ne[0] x a->ne[1] // stride is equal to kernel size // padding is zero @@ -1914,6 +2004,23 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 + GGML_API struct ggml_tensor * ggml_conv_3d_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC] + struct ggml_tensor * b, // input [W, H, D, C * N] + int s0, // stride + int s1, + int s2, + int p0, // padding + int p1, + int p2, + int d0, // dilation + int d1, + int d2, + int n_channels, + int n_batch, + int n_channels_out); + enum ggml_op_pool { GGML_OP_POOL_MAX, GGML_OP_POOL_AVG, @@ -2004,6 +2111,19 @@ extern "C" { int p2, int p3); + GGML_API struct ggml_tensor * ggml_pad_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3 + ); + // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c] GGML_API struct ggml_tensor * ggml_pad_reflect_1d( struct ggml_context * ctx, @@ -2293,7 +2413,14 @@ extern "C" { struct ggml_tensor * grad, struct ggml_tensor * m, struct ggml_tensor * v, - struct ggml_tensor * adamw_params); // parameters such a the learning rate + struct ggml_tensor * adamw_params); // parameters such as the learning rate + + // stochastic gradient descent step (with weight decay) + GGML_API struct ggml_tensor * ggml_opt_step_sgd( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * grad, + struct ggml_tensor * sgd_params); // alpha, weight decay // // automatic differentiation diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt index 5158acd6..aefe43bd 100644 --- a/ml/backend/ggml/ggml/src/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/CMakeLists.txt @@ -114,6 +114,9 @@ message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}") if (NOT MSVC) if (GGML_STATIC) + if (UNIX AND NOT APPLE) + set(CMAKE_FIND_LIBRARY_SUFFIXES ".a;.so") + endif() add_link_options(-static) if (MINGW) add_link_options(-static-libgcc -static-libstdc++) @@ -142,6 +145,9 @@ endif() # which was introduced in POSIX.1-2008, forcing us to go higher if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") add_compile_definitions(_XOPEN_SOURCE=700) +elseif (CMAKE_SYSTEM_NAME MATCHES "AIX") + # Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default, + # in order to define _SC_PHYS_PAGES. else() add_compile_definitions(_XOPEN_SOURCE=600) endif() @@ -203,6 +209,8 @@ add_library(ggml-base ggml-threading.h ggml-quants.c ggml-quants.h + mem_hip.cpp + mem_nvml.cpp gguf.cpp) target_include_directories(ggml-base PRIVATE .) @@ -380,6 +388,7 @@ ggml_add_backend(RPC) ggml_add_backend(SYCL) ggml_add_backend(Vulkan) ggml_add_backend(WebGPU) +ggml_add_backend(zDNN) ggml_add_backend(OpenCL) foreach (target ggml-base ggml) diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c index 41c8c4a2..eee9d3b1 100644 --- a/ml/backend/ggml/ggml/src/ggml-alloc.c +++ b/ml/backend/ggml/ggml/src/ggml-alloc.c @@ -23,7 +23,7 @@ static bool ggml_is_view(const struct ggml_tensor * t) { } // ops that return true for this function must not use restrict pointers for their backend implementations -static bool ggml_op_can_inplace(enum ggml_op op) { +bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { case GGML_OP_SCALE: case GGML_OP_DIAG_MASK_ZERO: @@ -95,39 +95,104 @@ enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_te // dynamic tensor allocator +#define GGML_VBUFFER_MAX_CHUNKS 16 + +// relative memory address within an allocation that can be split into multiple buffers (chunks) +struct buffer_address { + int chunk; // index of a backend buffer + size_t offset; // local memory offset within the buffer +}; + +static const struct buffer_address GGML_BUFFER_ADDRESS_INVALID = { -1, SIZE_MAX }; + +static bool ggml_buffer_address_less(struct buffer_address a, struct buffer_address b) { + return a.chunk != b.chunk ? a.chunk < b.chunk : a.offset < b.offset; +} + struct free_block { size_t offset; size_t size; }; +struct tallocr_chunk { + struct free_block free_blocks[MAX_FREE_BLOCKS]; + int n_free_blocks; + size_t max_size; +}; + struct ggml_dyn_tallocr { size_t alignment; - int n_free_blocks; - struct free_block free_blocks[MAX_FREE_BLOCKS]; - size_t max_size; + size_t max_chunk_size; + struct tallocr_chunk * chunks[GGML_VBUFFER_MAX_CHUNKS]; + int n_chunks; #ifdef GGML_ALLOCATOR_DEBUG struct { const struct ggml_tensor * tensor; - size_t offset; + struct buffer_address addr; } allocated_tensors[1024]; #endif }; +static void ggml_dyn_tallocr_insert_block(struct tallocr_chunk * chunk, size_t offset, size_t size) { + GGML_ASSERT(chunk->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks"); + // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster) + int insert_pos = 0; + while (insert_pos < chunk->n_free_blocks && chunk->free_blocks[insert_pos].offset < offset) { + insert_pos++; + } + // shift all blocks from insert_pos onward to make room for the new block + for (int i = chunk->n_free_blocks; i > insert_pos; i--) { + chunk->free_blocks[i] = chunk->free_blocks[i-1]; + } + // insert the new block + chunk->free_blocks[insert_pos].offset = offset; + chunk->free_blocks[insert_pos].size = size; + chunk->n_free_blocks++; +} + +static void ggml_dyn_tallocr_remove_block(struct tallocr_chunk * chunk, int idx) { + // shift all elements after idx by 1 to the left, overwriting the element at idx + for (int i = idx; i < chunk->n_free_blocks; i++) { + chunk->free_blocks[i] = chunk->free_blocks[i+1]; + } + chunk->n_free_blocks--; +} + +static int ggml_dyn_tallocr_new_chunk(struct ggml_dyn_tallocr * alloc, size_t min_size) { + if (alloc->n_chunks >= GGML_VBUFFER_MAX_CHUNKS) { + return -1; + } + struct tallocr_chunk * chunk = calloc(1, sizeof(struct tallocr_chunk)); + chunk->n_free_blocks = 1; + chunk->free_blocks[0].offset = 0; + // available space in a chunk is limited to max_chunk_size, but can be higher if: + // 1. a single tensor exceeds the maximum, and cannot fit any other way + // 2. we are running out of chunks + // backends will either manage to allocate the larger size, or report an error. + chunk->free_blocks[0].size = MAX(min_size, alloc->max_chunk_size); + if (alloc->n_chunks == GGML_VBUFFER_MAX_CHUNKS - 1) { + chunk->free_blocks[0].size = SIZE_MAX/2; + } + alloc->chunks[alloc->n_chunks] = chunk; + alloc->n_chunks++; + return alloc->n_chunks - 1; +} + #ifdef GGML_ALLOCATOR_DEBUG -static void add_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) { +static void add_allocated_tensor(struct ggml_dyn_tallocr * alloc, struct buffer_address addr, const struct ggml_tensor * tensor) { for (int i = 0; i < 1024; i++) { if (alloc->allocated_tensors[i].tensor == NULL) { alloc->allocated_tensors[i].tensor = tensor; - alloc->allocated_tensors[i].offset = offset; + alloc->allocated_tensors[i].addr = addr; return; } } GGML_ABORT("out of allocated_tensors"); } -static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) { +static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, struct buffer_address addr, const struct ggml_tensor * tensor) { for (int i = 0; i < 1024; i++) { - if (alloc->allocated_tensors[i].offset == offset) { + if (alloc->allocated_tensors[i].addr.chunk == addr.chunk && alloc->allocated_tensors[i].addr.offset == addr.offset) { alloc->allocated_tensors[i].tensor = NULL; return; } @@ -136,76 +201,94 @@ static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offs } #endif -static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t size, const struct ggml_tensor * tensor) { +static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t size, const struct ggml_tensor * tensor) { size = aligned_offset(NULL, size, alloc->alignment); AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size); + int best_fit_chunk = -1; + int best_fit_block = -1; size_t max_avail = 0; - // find the best fitting free block besides the last block - int best_fit_block = -1; - size_t best_fit_size = SIZE_MAX; - for (int i = 0; i < alloc->n_free_blocks - 1; i++) { - struct free_block * block = &alloc->free_blocks[i]; - max_avail = MAX(max_avail, block->size); - if (block->size >= size && block->size <= best_fit_size) { - best_fit_block = i; - best_fit_size = block->size; + // find the best fitting free block besides the last block, within any chunk + for (int c = 0; c < alloc->n_chunks; ++c) { + struct tallocr_chunk * chunk = alloc->chunks[c]; + size_t best_fit_size = SIZE_MAX; + for (int i = 0; i < chunk->n_free_blocks - 1; i++) { + struct free_block * block = &chunk->free_blocks[i]; + max_avail = MAX(max_avail, block->size); + if (block->size >= size && block->size <= best_fit_size) { + best_fit_chunk = c; + best_fit_block = i; + best_fit_size = block->size; + } } } if (best_fit_block == -1) { - // the last block is our last resort - struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1]; - max_avail = MAX(max_avail, block->size); - if (block->size >= size) { - best_fit_block = alloc->n_free_blocks - 1; - } else { - // this should never happen - GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", - __func__, size, max_avail); - GGML_ABORT("not enough space in the buffer"); - } - } - - struct free_block * block = &alloc->free_blocks[best_fit_block]; - size_t offset = block->offset; - block->offset = offset + size; - block->size -= size; - if (block->size == 0) { - // remove block if empty - alloc->n_free_blocks--; - for (int j = best_fit_block; j < alloc->n_free_blocks; j++) { - alloc->free_blocks[j] = alloc->free_blocks[j+1]; - } - } - - AT_PRINTF("block %d, offset %zu\n", best_fit_block, offset); - -#ifdef GGML_ALLOCATOR_DEBUG - add_allocated_tensor(alloc, offset, tensor); - size_t cur_max = offset + size; - if (cur_max > alloc->max_size) { - // sort allocated_tensors by offset - for (int i = 0; i < 1024; i++) { - for (int j = i + 1; j < 1024; j++) { - if (alloc->allocated_tensors[i].offset > alloc->allocated_tensors[j].offset) { - const struct ggml_tensor * tmp_tensor = alloc->allocated_tensors[i].tensor; - size_t tmp_offset = alloc->allocated_tensors[i].offset; - alloc->allocated_tensors[i].tensor = alloc->allocated_tensors[j].tensor; - alloc->allocated_tensors[i].offset = alloc->allocated_tensors[j].offset; - alloc->allocated_tensors[j].tensor = tmp_tensor; - alloc->allocated_tensors[j].offset = tmp_offset; + // no suitable block found, try the last block (this will grow a chunks size) + for (int c = 0; c < alloc->n_chunks; ++c) { + struct tallocr_chunk * chunk = alloc->chunks[c]; + if (chunk->n_free_blocks > 0) { + struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1]; + max_avail = MAX(max_avail, block->size); + if (block->size >= size) { + best_fit_chunk = c; + best_fit_block = chunk->n_free_blocks - 1; + break; } } } - GGML_LOG_DEBUG("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); + } + + if (best_fit_block == -1) { + // none of the existing chunks have enough space left + best_fit_chunk = ggml_dyn_tallocr_new_chunk(alloc, size); + best_fit_block = 0; + } + if (best_fit_chunk == -1) { + // since the last chunk always has virtually endless memory, this should never happen + GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", + __func__, size, max_avail); + GGML_ABORT("graph allocation: failed to reserve memory"); + } + + struct tallocr_chunk * chunk = alloc->chunks[best_fit_chunk]; + struct free_block * block = &chunk->free_blocks[best_fit_block]; + struct buffer_address addr = {.chunk = best_fit_chunk, .offset = block->offset }; + block->offset += size; + block->size -= size; + if (block->size == 0) { + // remove block if empty + ggml_dyn_tallocr_remove_block(chunk, best_fit_block); + } + + AT_PRINTF("block %d, offset %zu, chunk %d\n", best_fit_block, addr.offset, addr.chunk); + +#ifdef GGML_ALLOCATOR_DEBUG + add_allocated_tensor(alloc, addr, tensor); + size_t cur_max = addr.offset + size; + if (cur_max > alloc->max_size[addr.chunk]) { + // sort allocated_tensors by chunk/offset + for (int i = 0; i < 1024; i++) { + for (int j = i + 1; j < 1024; j++) { + if (ggml_buffer_address_less(alloc->allocated_tensors[j].addr, alloc->allocated_tensors[i].addr)) { + const struct ggml_tensor * tmp_tensor = alloc->allocated_tensors[i].tensor; + struct buffer_address tmp_addr = alloc->allocated_tensors[i].addr; + alloc->allocated_tensors[i].tensor = alloc->allocated_tensors[j].tensor; + alloc->allocated_tensors[i].addr = alloc->allocated_tensors[j].addr; + alloc->allocated_tensors[j].tensor = tmp_tensor; + alloc->allocated_tensors[j].addr = tmp_addr; + } + } + } + GGML_LOG_DEBUG("max_size[%d] = %.2f MB: tensors: ", addr.chunk, cur_max / 1024.0 / 1024.0); for (int i = 0; i < 1024; i++) { if (alloc->allocated_tensors[i].tensor) { - GGML_LOG_DEBUG("%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name, - alloc->allocated_tensors[i].offset, - alloc->allocated_tensors[i].offset + ggml_nbytes(alloc->allocated_tensors[i].tensor), + GGML_LOG_DEBUG("%s [%d: %zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name, + alloc->allocated_tensors[i].addr.chunk, + alloc->allocated_tensors[i].addr.offset, + alloc->allocated_tensors[i].addr.offset + ggml_nbytes(alloc->allocated_tensors[i].tensor), ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0); } } @@ -213,78 +296,69 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz } #endif - alloc->max_size = MAX(alloc->max_size, offset + size); + chunk->max_size = MAX(chunk->max_size, addr.offset + size); - return offset; + return addr; GGML_UNUSED(tensor); } // this is a very naive implementation, but for our case the number of free blocks should be very small -static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, size_t size, const struct ggml_tensor * tensor) { +static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, struct buffer_address addr, size_t size, const struct ggml_tensor * tensor) { size = aligned_offset(NULL, size, alloc->alignment); - AT_PRINTF("%s: freeing %s at %zu (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, offset, size, alloc->n_free_blocks); + AT_PRINTF("%s: freeing %s at {chunk=%d, offset=%zu} (%zu bytes) - n_free_blocks = %d\n", + __func__, tensor->name, addr.chunk, addr.offset, size, alloc->chunks[addr.chunk]->n_free_blocks); #ifdef GGML_ALLOCATOR_DEBUG - remove_allocated_tensor(alloc, offset, tensor); + remove_allocated_tensor(alloc, addr, tensor); #endif + struct tallocr_chunk * chunk = alloc->chunks[addr.chunk]; + // see if we can merge with an existing block - for (int i = 0; i < alloc->n_free_blocks; i++) { - struct free_block * block = &alloc->free_blocks[i]; + for (int i = 0; i < chunk->n_free_blocks; i++) { + struct free_block * block = &chunk->free_blocks[i]; // check if ptr is at the end of the block - if (block->offset + block->size == offset) { + if (block->offset + block->size == addr.offset) { block->size += size; // check if we can merge with the next block - if (i < alloc->n_free_blocks - 1 && block->offset + block->size == alloc->free_blocks[i+1].offset) { - block->size += alloc->free_blocks[i+1].size; - alloc->n_free_blocks--; - for (int j = i+1; j < alloc->n_free_blocks; j++) { - alloc->free_blocks[j] = alloc->free_blocks[j+1]; + if (i < chunk->n_free_blocks - 1) { + struct free_block * next = &chunk->free_blocks[i+1]; + if (block->offset + block->size == next->offset) { + block->size += next->size; + ggml_dyn_tallocr_remove_block(chunk, i+1); } } return; } // check if ptr is at the beginning of the block - if (offset + size == block->offset) { - block->offset = offset; + if (addr.offset + size == block->offset) { + block->offset = addr.offset; block->size += size; // check if we can merge with the previous block - if (i > 0 && alloc->free_blocks[i-1].offset + alloc->free_blocks[i-1].size == block->offset) { - alloc->free_blocks[i-1].size += block->size; - alloc->n_free_blocks--; - for (int j = i; j < alloc->n_free_blocks; j++) { - alloc->free_blocks[j] = alloc->free_blocks[j+1]; + if (i > 0) { + struct free_block * prev = &chunk->free_blocks[i-1]; + if (prev->offset + prev->size == block->offset) { + prev->size += block->size; + ggml_dyn_tallocr_remove_block(chunk, i); } } return; } } // otherwise, add a new block - GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks"); - // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster) - int insert_pos = 0; - while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].offset < offset) { - insert_pos++; - } - // shift all blocks from insert_pos onward to make room for the new block - for (int i = alloc->n_free_blocks; i > insert_pos; i--) { - alloc->free_blocks[i] = alloc->free_blocks[i-1]; - } - // insert the new block - alloc->free_blocks[insert_pos].offset = offset; - alloc->free_blocks[insert_pos].size = size; - alloc->n_free_blocks++; + ggml_dyn_tallocr_insert_block(chunk, addr.offset, size); GGML_UNUSED(tensor); } static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) { - alloc->n_free_blocks = 1; - alloc->free_blocks[0].offset = 0; - alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows - alloc->max_size = 0; + for (int i = 0; i < GGML_VBUFFER_MAX_CHUNKS; i++) { + free(alloc->chunks[i]); + alloc->chunks[i] = NULL; + } + alloc->n_chunks = 0; #ifdef GGML_ALLOCATOR_DEBUG for (int i = 0; i < 1024; i++) { @@ -293,14 +367,14 @@ static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) { #endif } -static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) { +static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment, size_t max_buffer_size) { struct ggml_dyn_tallocr * alloc = (struct ggml_dyn_tallocr *)malloc(sizeof(struct ggml_dyn_tallocr)); *alloc = (struct ggml_dyn_tallocr) { - /*.alignment = */ alignment, - /*.n_free_blocks = */ 0, - /*.free_blocks = */ {{0}}, - /*.max_size = */ 0, + /*.alignment = */ alignment, + /*.max_chunk_size = */ MIN(max_buffer_size, SIZE_MAX/2), // clamp to avoid overflows + /*.chunks = */ {NULL}, + /*.n_chunks = */ 0, #ifdef GGML_ALLOCATOR_DEBUG /*.allocated_tensors = */ {{0}}, #endif @@ -312,11 +386,73 @@ static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) { } static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) { + for (int i = 0; i < alloc->n_chunks; ++i) { + free(alloc->chunks[i]); + } free(alloc); } -static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) { - return alloc->max_size; +static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc, int chunk) { + return chunk < alloc->n_chunks ? alloc->chunks[chunk]->max_size : 0; +} + + +// virtual buffer with contiguous memory range, split into multiple backend buffers (chunks) + +struct vbuffer { + ggml_backend_buffer_t chunks[GGML_VBUFFER_MAX_CHUNKS]; +}; + +static void ggml_vbuffer_free(struct vbuffer * buf) { + if (buf == NULL) { + return; + } + for (int i = 0; i < GGML_VBUFFER_MAX_CHUNKS; ++i) { + ggml_backend_buffer_free(buf->chunks[i]); + } + free(buf); +} + +static size_t ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) { + return buf->chunks[chunk] ? ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0; +} + +static size_t ggml_vbuffer_size(struct vbuffer * buf) { + size_t size = 0; + for (int i = 0; i < GGML_VBUFFER_MAX_CHUNKS && buf->chunks[i]; ++i) { + size += ggml_backend_buffer_get_size(buf->chunks[i]); + } + return size; +} + +static struct vbuffer * ggml_vbuffer_alloc(ggml_backend_buffer_type_t buft, const struct ggml_dyn_tallocr * talloc, enum ggml_backend_buffer_usage usage) { + struct vbuffer * buf = (struct vbuffer *)calloc(1, sizeof(struct vbuffer)); + if (buf == NULL) { + return NULL; + } + + for (int n = 0; n < talloc->n_chunks; n++) { + size_t chunk_size = talloc->chunks[n]->max_size; + buf->chunks[n] = ggml_backend_buft_alloc_buffer(buft, chunk_size); + if (buf->chunks[n] == NULL) { + ggml_vbuffer_free(buf); + return NULL; + } + ggml_backend_buffer_set_usage(buf->chunks[n], usage); + } + return buf; +} + +static void ggml_vbuffer_tensor_alloc(struct vbuffer * buf, struct ggml_tensor * tensor, struct buffer_address buf_addr) { + void * base = ggml_backend_buffer_get_base(buf->chunks[buf_addr.chunk]); + void * addr = (char *)base + buf_addr.offset; + ggml_backend_tensor_alloc(buf->chunks[buf_addr.chunk], tensor, addr); +} + +static void ggml_vbuffer_reset(struct vbuffer * buf) { + for (int i = 0; i < GGML_VBUFFER_MAX_CHUNKS && buf->chunks[i]; ++i) { + ggml_backend_buffer_reset(buf->chunks[i]); + } } @@ -328,13 +464,13 @@ struct hash_node { int n_children; int n_views; int buffer_id; - size_t offset; // offset within the buffer + struct buffer_address addr; bool allocated; }; struct tensor_alloc { int buffer_id; - size_t offset; + struct buffer_address addr; size_t size_max; // 0 = pre-allocated, unused, or view }; @@ -349,7 +485,7 @@ struct node_alloc { struct ggml_gallocr { ggml_backend_buffer_type_t * bufts; // [n_buffers] - ggml_backend_buffer_t * buffers; // [n_buffers] + struct vbuffer ** buffers; // [n_buffers] size_t *buffer_sizes; // [n_buffers] struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers] int n_buffers; @@ -371,7 +507,7 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs galloc->bufts = calloc(n_bufs, sizeof(ggml_backend_buffer_type_t)); GGML_ASSERT(galloc->bufts != NULL); - galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t)); + galloc->buffers = calloc(n_bufs, sizeof(struct vbuffer *)); GGML_ASSERT(galloc->buffers != NULL); galloc->buffer_sizes = calloc(n_bufs, sizeof(size_t)); @@ -394,7 +530,8 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs if (galloc->buf_tallocs[i] == NULL) { size_t alignment = ggml_backend_buft_get_alignment(bufts[i]); - galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment); + size_t max_size = ggml_backend_buft_get_max_size(bufts[i]); + galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment, max_size); } } galloc->n_buffers = n_bufs; @@ -422,7 +559,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { } } if (!freed) { - ggml_backend_buffer_free(galloc->buffers[i]); + ggml_vbuffer_free(galloc->buffers[i]); } } if (galloc->buf_tallocs != NULL) { @@ -472,7 +609,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) { hn->allocated = true; - assert(hn->offset == 0); + assert(hn->addr.offset == 0); // try to reuse a parent's buffer (inplace) if (ggml_op_can_inplace(node->op)) { @@ -506,9 +643,9 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name); - assert(view_src_hn->offset == p_hn->offset); + assert(view_src_hn->addr.chunk == p_hn->addr.chunk && view_src_hn->addr.offset == p_hn->addr.offset); hn->buffer_id = p_hn->buffer_id; - hn->offset = p_hn->offset; + hn->addr = p_hn->addr; p_hn->allocated = false; // avoid freeing the parent view_src_hn->allocated = false; return; @@ -516,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor } else { AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name); hn->buffer_id = p_hn->buffer_id; - hn->offset = p_hn->offset; + hn->addr = p_hn->addr; p_hn->allocated = false; // avoid freeing the parent return; } @@ -527,9 +664,8 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id]; ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id]; size_t size = ggml_backend_buft_get_alloc_size(buft, node); - size_t offset = ggml_dyn_tallocr_alloc(alloc, size, node); hn->buffer_id = buffer_id; - hn->offset = offset; + hn->addr = ggml_dyn_tallocr_alloc(alloc, size, node); } } @@ -541,12 +677,11 @@ static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * n } struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); - size_t offset = hn->offset; int buffer_id = hn->buffer_id; struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id]; ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id]; size_t size = ggml_backend_buft_get_alloc_size(buft, node); - ggml_dyn_tallocr_free_tensor(alloc, offset, size, node); + ggml_dyn_tallocr_free_tensor(alloc, hn->addr, size, node); hn->allocated = false; } @@ -697,24 +832,24 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c struct node_alloc * node_alloc = &galloc->node_allocs[i]; if (node->view_src || node->data) { node_alloc->dst.buffer_id = -1; - node_alloc->dst.offset = SIZE_MAX; + node_alloc->dst.addr = GGML_BUFFER_ADDRESS_INVALID; node_alloc->dst.size_max = 0; } else { struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); node_alloc->dst.buffer_id = hn->buffer_id; - node_alloc->dst.offset = hn->offset; + node_alloc->dst.addr = hn->addr; node_alloc->dst.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node); } for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (!src || src->view_src || src->data) { node_alloc->src[j].buffer_id = -1; - node_alloc->src[j].offset = SIZE_MAX; + node_alloc->src[j].addr = GGML_BUFFER_ADDRESS_INVALID; node_alloc->src[j].size_max = 0; } else { struct hash_node * hn = ggml_gallocr_hash_get(galloc, src); node_alloc->src[j].buffer_id = hn->buffer_id; - node_alloc->src[j].offset = hn->offset; + node_alloc->src[j].addr = hn->addr; node_alloc->src[j].size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src); } } @@ -730,11 +865,11 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf); if (leaf->view_src || leaf->data) { galloc->leaf_allocs[i].leaf.buffer_id = -1; - galloc->leaf_allocs[i].leaf.offset = SIZE_MAX; + galloc->leaf_allocs[i].leaf.addr = GGML_BUFFER_ADDRESS_INVALID; galloc->leaf_allocs[i].leaf.size_max = 0; } else { galloc->leaf_allocs[i].leaf.buffer_id = hn->buffer_id; - galloc->leaf_allocs[i].leaf.offset = hn->offset; + galloc->leaf_allocs[i].leaf.addr = hn->addr; galloc->leaf_allocs[i].leaf.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf); } } @@ -751,27 +886,34 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } } - size_t cur_size = galloc->buffers[i] ? ggml_backend_buffer_get_size(galloc->buffers[i]) : 0; - size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]); - // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views - if (new_size > cur_size || galloc->buffers[i] == NULL) { + bool realloc = galloc->buffers[i] == NULL; + size_t new_size = 0; + for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) { + size_t cur_chunk_size = galloc->buffers[i] ? ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0; + size_t new_chunk_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c); + new_size += new_chunk_size; + if (new_chunk_size > cur_chunk_size) { + realloc = true; + } + } + if (realloc) { #ifndef NDEBUG + size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif - ggml_backend_buffer_free(galloc->buffers[i]); - galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); + ggml_vbuffer_free(galloc->buffers[i]); + galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); if (galloc->buffers[i]) { - galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); - ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); + galloc->buffer_sizes[i] = ggml_vbuffer_size(galloc->buffers[i]); } else { GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); galloc->buffer_sizes[i] = new_size; success = false; } } else { - galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]); + galloc->buffer_sizes[i] = ggml_vbuffer_size(galloc->buffers[i]); } } @@ -784,11 +926,11 @@ bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, struct tensor_alloc * tensor_alloc) { int buffer_id = tensor_alloc->buffer_id; - assert(tensor->data || tensor->view_src || ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max); + assert(tensor->data || tensor->view_src || ggml_backend_buft_get_alloc_size(galloc->bufts[buffer_id], tensor) <= tensor_alloc->size_max); if (tensor->view_src != NULL) { if (tensor->buffer == NULL) { - assert(tensor_alloc->offset == SIZE_MAX); + assert(tensor_alloc->addr.offset == SIZE_MAX); if (tensor->view_src->buffer == NULL) { // this tensor was allocated without ggml-backend return; @@ -797,11 +939,9 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * } } else { if (tensor->data == NULL) { - assert(tensor_alloc->offset != SIZE_MAX); - assert(ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max); - void * base = ggml_backend_buffer_get_base(galloc->buffers[buffer_id]); - void * addr = (char *)base + tensor_alloc->offset; - ggml_backend_tensor_alloc(galloc->buffers[buffer_id], tensor, addr); + assert(tensor_alloc->addr.offset != SIZE_MAX); + assert(ggml_backend_buft_get_alloc_size(galloc->bufts[buffer_id], tensor) <= tensor_alloc->size_max); + ggml_vbuffer_tensor_alloc(galloc->buffers[buffer_id], tensor, tensor_alloc->addr); } else { if (tensor->buffer == NULL) { // this tensor was allocated without ggml-backend @@ -886,7 +1026,7 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph) // reset buffers for (int i = 0; i < galloc->n_buffers; i++) { if (galloc->buffers[i] != NULL) { - ggml_backend_buffer_reset(galloc->buffers[i]); + ggml_vbuffer_reset(galloc->buffers[i]); } } @@ -929,10 +1069,10 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { } } - return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); + return ggml_vbuffer_size(galloc->buffers[buffer_id]); } -struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { +size_t ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) { GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers); for (int i = 0; i < buffer_id; i++) { @@ -941,13 +1081,11 @@ struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gal // (See above.) However, we need a different check because multiple buffers might be NULL in our // case and we still want to know the attempted size. - struct ggml_allocr_buffer_status status = {0, true}; - return status; + return 0; } } - struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL}; - return status; + return galloc->buffer_sizes[buffer_id]; } // utils diff --git a/ml/backend/ggml/ggml/src/ggml-backend-impl.h b/ml/backend/ggml/ggml/src/ggml-backend-impl.h index 6f10c353..43c91d9f 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-backend-impl.h @@ -8,7 +8,7 @@ extern "C" { #endif - #define GGML_BACKEND_API_VERSION 1 + #define GGML_BACKEND_API_VERSION 2 // // Backend buffer type @@ -26,6 +26,10 @@ extern "C" { size_t (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false) bool (*is_host) (ggml_backend_buffer_type_t buft); + + // (optional) returns a dummy buffer that is equivalent to one created by alloc_buffer but without actually being backed + // by memory + ggml_backend_buffer_t (*noalloc_buffer)(ggml_backend_buffer_type_t buft, size_t size); }; struct ggml_backend_buffer_type { @@ -116,6 +120,19 @@ extern "C" { void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); // wait for an event on on a different stream void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + + // (optional) sort/optimize the nodes in the graph + void (*graph_optimize) (ggml_backend_t backend, struct ggml_cgraph * cgraph); + + // (optional) reserves intermediate buffers needed for the compution + // if alloc is true, memory is actually allocated, otherwise the required amount is just returned by buffer_size + enum ggml_status (*graph_reserve) (ggml_backend_t backend, struct ggml_cgraph * cgraph, bool alloc); + + // (optional) returns the memory needed after calling graph_reserve + size_t (*buffer_size) (ggml_backend_t backend); + + // (optional) frees memory from intermediate buffers that was allocated either by graph_compute or graph_reserve + void (*reset) (ggml_backend_t backend); }; struct ggml_backend { @@ -212,9 +229,6 @@ extern "C" { void * context; }; - // Internal backend registry API - GGML_API void ggml_backend_register(ggml_backend_reg_t reg); - // Add backend dynamic loading support to the backend // Initialize the backend diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp index 3040b2aa..3a855ab2 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp @@ -49,6 +49,10 @@ #include "ggml-webgpu.h" #endif +#ifdef GGML_USE_ZDNN +#include "ggml-zdnn.h" +#endif + #ifdef GGML_USE_OPENCL #include "ggml-opencl.h" #endif @@ -114,6 +118,18 @@ static dl_handle * dl_load_library(const fs::path & path) { SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); HMODULE handle = LoadLibraryW(path.wstring().c_str()); + if (!handle) { + DWORD error_code = GetLastError(); + std::string msg; + LPSTR lpMsgBuf = NULL; + DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL); + if (bufLen) { + msg = lpMsgBuf; + LocalFree(lpMsgBuf); + GGML_LOG_INFO("%s unable to load library %s: %s\n", __func__, path_str(path).c_str(), msg.c_str()); + } + } SetErrorMode(old_mode); @@ -131,6 +147,10 @@ static void * dl_get_sym(dl_handle * handle, const char * name) { return p; } +static const char * dl_error() { + return ""; +} + #else using dl_handle = void; @@ -151,6 +171,11 @@ static void * dl_get_sym(dl_handle * handle, const char * name) { return dlsym(handle, name); } +static const char * dl_error() { + const char *rslt = dlerror(); + return rslt != nullptr ? rslt : ""; +} + #endif using dl_handle_ptr = std::unique_ptr; @@ -180,6 +205,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_WEBGPU register_backend(ggml_backend_webgpu_reg()); #endif +#ifdef GGML_USE_ZDNN + register_backend(ggml_backend_zdnn_reg()); +#endif #ifdef GGML_USE_OPENCL register_backend(ggml_backend_opencl_reg()); #endif @@ -238,7 +266,7 @@ struct ggml_backend_registry { dl_handle_ptr handle { dl_load_library(path) }; if (!handle) { if (!silent) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str()); + GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(path).c_str(), dl_error()); } return nullptr; } @@ -398,9 +426,8 @@ ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const ggml_backend_t ggml_backend_init_best(void) { ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); - if (!dev) { - dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - } + dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU); + dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); if (!dev) { return nullptr; } @@ -529,7 +556,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, if (filename.native().find(file_prefix) == 0 && ext == file_extension) { dl_handle_ptr handle { dl_load_library(entry) }; if (!handle && !silent) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str()); + GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(entry.path()).c_str(), dl_error()); } if (handle) { auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp index 6556943b..0b757af5 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp @@ -19,9 +19,8 @@ #include #include #include -#include -#include #include +#include #ifdef __APPLE__ #include @@ -32,13 +31,10 @@ // backend buffer type const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft); return buft->iface.get_name(buft); } -void ggml_backend_buft_set_alloc(ggml_backend_buffer_type_t buft, bool alloc) { - buft->no_alloc = !alloc; -} - ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { if (size == 0) { // return a dummy buffer for zero-sized allocations @@ -46,19 +42,29 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t } if (buft->no_alloc) { - ggml_backend_buffer_t buf = ggml_backend_buffer_init(buft, {}, NULL, size); + ggml_backend_buffer_t buf; + + if (buft->iface.noalloc_buffer != NULL) { + buf = buft->iface.noalloc_buffer(buft, size); + } else { + buf = ggml_backend_buffer_init(buft, {}, NULL, size); + } + buf->no_alloc = true; return buf; } + GGML_ASSERT(buft); return buft->iface.alloc_buffer(buft, size); } size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft); return buft->iface.get_alignment(buft); } size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft); // get_max_size is optional, defaults to SIZE_MAX if (buft->iface.get_max_size) { return buft->iface.get_max_size(buft); @@ -67,6 +73,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) { } size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + GGML_ASSERT(buft); // get_alloc_size is optional, defaults to ggml_nbytes if (buft->iface.get_alloc_size) { size_t size = buft->iface.get_alloc_size(buft, tensor); @@ -77,6 +84,7 @@ size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const s } bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft); if (buft->iface.is_host) { return buft->iface.is_host(buft); } @@ -84,6 +92,7 @@ bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { } ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft); return buft->device; } @@ -121,10 +130,12 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { } size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); return buffer->size; } void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); // get_base is optional if the buffer is zero-sized if (buffer->size == 0) { return NULL; @@ -144,6 +155,7 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { } enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + GGML_ASSERT(buffer); // init_tensor is optional if (buffer->iface.init_tensor) { return buffer->iface.init_tensor(buffer, tensor); @@ -152,6 +164,7 @@ enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, s } void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); // clear is optional if the buffer is zero-sized if (buffer->size == 0) { return; @@ -177,6 +190,7 @@ bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) { } void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { + GGML_ASSERT(buffer); buffer->usage = usage; // FIXME: add a generic callback to the buffer interface @@ -186,14 +200,17 @@ void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backe } enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); return buffer->usage; } ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); return buffer->buft; } void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); if (buffer->iface.reset) { buffer->iface.reset(buffer); } @@ -232,6 +249,7 @@ void ggml_backend_free(ggml_backend_t backend) { } ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) { + GGML_ASSERT(backend); return ggml_backend_dev_buffer_type(backend->device); } @@ -248,6 +266,8 @@ size_t ggml_backend_get_max_size(ggml_backend_t backend) { } void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); @@ -259,6 +279,8 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * } void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); @@ -300,6 +322,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz } void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; if (size == 0) { @@ -315,6 +338,7 @@ void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size } void ggml_backend_synchronize(ggml_backend_t backend) { + GGML_ASSERT(backend); if (backend->iface.synchronize == NULL) { return; } @@ -323,18 +347,21 @@ void ggml_backend_synchronize(ggml_backend_t backend) { } ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend); GGML_ASSERT(backend->iface.graph_plan_create != NULL); return backend->iface.graph_plan_create(backend, cgraph); } void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(backend); GGML_ASSERT(backend->iface.graph_plan_free != NULL); backend->iface.graph_plan_free(backend, plan); } enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(backend); GGML_ASSERT(backend->iface.graph_plan_compute != NULL); return backend->iface.graph_plan_compute(backend, plan); @@ -347,22 +374,27 @@ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_ } enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend); return backend->iface.graph_compute(backend, cgraph); } bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + GGML_ASSERT(backend); return ggml_backend_dev_supports_op(backend->device, op); } bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(backend); return ggml_backend_dev_supports_buft(backend->device, buft); } bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) { + GGML_ASSERT(backend); return ggml_backend_dev_offload_op(backend->device, op); } ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { + GGML_ASSERT(backend); return backend->device; } @@ -398,6 +430,7 @@ void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t b return; } + GGML_ASSERT(backend_dst); if (backend_dst->iface.cpy_tensor_async != NULL) { if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) { return; @@ -429,38 +462,52 @@ void ggml_backend_event_free(ggml_backend_event_t event) { } void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) { + GGML_ASSERT(backend); GGML_ASSERT(backend->iface.event_record != NULL); backend->iface.event_record(backend, event); } void ggml_backend_event_synchronize(ggml_backend_event_t event) { + GGML_ASSERT(event); GGML_ASSERT(event->device->iface.event_synchronize); event->device->iface.event_synchronize(event->device, event); } void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_ASSERT(backend); GGML_ASSERT(backend->iface.event_wait != NULL); backend->iface.event_wait(backend, event); } +static void ggml_backend_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend); + if (backend->iface.graph_optimize != NULL) { + backend->iface.graph_optimize(backend, cgraph); + } +} + // Backend device const char * ggml_backend_dev_name(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->iface.get_name(device); } const char * ggml_backend_dev_description(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->iface.get_description(device); } void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + GGML_ASSERT(device); device->iface.get_memory(device, free, total); } enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->iface.get_type(device); } @@ -470,10 +517,12 @@ void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_d } ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->reg; } ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) { + GGML_ASSERT(device); return device->iface.init_backend(device, params); } @@ -486,10 +535,12 @@ void ggml_backend_dev_reset(ggml_backend_dev_t device) { } ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->iface.get_buffer_type(device); } ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) { + GGML_ASSERT(device); if (device->iface.get_host_buffer_type == NULL) { return NULL; } @@ -498,18 +549,22 @@ ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t } ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) { + GGML_ASSERT(device); return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size); } bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) { + GGML_ASSERT(device); return device->iface.supports_op(device, op); } bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(device); return device->iface.supports_buft(device, buft); } bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) { + GGML_ASSERT(device); if (device->iface.offload_op != NULL) { return device->iface.offload_op(device, op); } @@ -520,18 +575,22 @@ bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_te // Backend (reg) const char * ggml_backend_reg_name(ggml_backend_reg_t reg) { + GGML_ASSERT(reg); return reg->iface.get_name(reg); } size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) { + GGML_ASSERT(reg); return reg->iface.get_device_count(reg); } ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(reg); return reg->iface.get_device(reg, index); } void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { + GGML_ASSERT(reg); if (!reg->iface.get_proc_address) { return NULL; } @@ -546,6 +605,7 @@ struct ggml_backend_multi_buffer_context { }; static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { ggml_backend_buffer_free(ctx->buffers[i]); @@ -557,6 +617,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) } static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { ggml_backend_buffer_clear(ctx->buffers[i], value); @@ -592,10 +653,12 @@ ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer } bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer; } void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { + GGML_ASSERT(buffer); GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer)); ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { @@ -623,7 +686,7 @@ static bool ggml_is_view_op(enum ggml_op op) { #endif #ifndef GGML_SCHED_MAX_SPLIT_INPUTS -#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC +#define GGML_SCHED_MAX_SPLIT_INPUTS 30 #endif #ifndef GGML_SCHED_MAX_COPIES @@ -688,6 +751,12 @@ struct ggml_backend_sched { bool op_offload; int debug; + + // allocate buffers on attached ggml_backend_buffer_type_t's and during reservation + // if false, dummy buffers are used for faster memory sizing calculations + // the scheduler needs to be recreated with allocated buffers before it can be used + // for computation + bool alloc_buffers; }; #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) @@ -874,7 +943,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru } // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend -static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { +void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { // reset splits sched->n_splits = 0; sched->n_graph_inputs = 0; @@ -1270,6 +1339,10 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg struct ggml_backend_sched_split * split = &sched->splits[i]; split->graph = ggml_graph_view(graph, split->i_start, split->i_end); + // Optimize this split of the graph. This needs to happen before we make graph_copy, + // so they are in sync. + ggml_backend_graph_optimize(sched->backends[split->backend_id], &split->graph); + // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split for (int j = 0; j < split->n_inputs; j++) { assert(graph_copy->size > (graph_copy->n_nodes + 1)); @@ -1375,17 +1448,22 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { } static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); struct ggml_backend_sched_split * splits = sched->splits; - for (int i = 0; i < sched->n_splits; i++) { - struct ggml_backend_sched_split * split = &splits[i]; + ggml_tensor * prev_ids_tensor = nullptr; + std::vector ids; + std::vector used_ids; + + for (int split_id = 0; split_id < sched->n_splits; split_id++) { + struct ggml_backend_sched_split * split = &splits[split_id]; int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; // copy the input tensors to the split backend - for (int j = 0; j < split->n_inputs; j++) { - ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]); - struct ggml_tensor * input = split->inputs[j]; + for (int input_id = 0; input_id < split->n_inputs; input_id++) { + ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]); + struct ggml_tensor * input = split->inputs[input_id]; struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy); if (input->flags & GGML_TENSOR_FLAG_INPUT) { @@ -1403,16 +1481,104 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } else { ggml_backend_synchronize(split_backend); } - // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events - // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface - if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) { + + // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used + ggml_tensor * node = split->graph.nodes[0]; + if (split->graph.n_nodes > 0 && + ggml_backend_buffer_get_usage(input->buffer) == GGML_BACKEND_BUFFER_USAGE_WEIGHTS && + ggml_backend_buffer_is_host(input->buffer) && ( + (node->src[0] == input_cpy && node->op == GGML_OP_MUL_MAT_ID) + //|| (node->src[1] == input_cpy && node->op == GGML_OP_ADD_ID) /* GGML_OP_ADD_ID weights are small and not worth splitting */ + )) { + + const int64_t n_expert = node->op == GGML_OP_MUL_MAT_ID ? input->ne[2] : input->ne[1]; + const size_t expert_size = node->op == GGML_OP_MUL_MAT_ID ? input->nb[2] : input->nb[1]; + ggml_backend_synchronize(input_backend); - if (sched->events[split_backend_id][sched->cur_copy] != NULL) { - ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); - } else { - ggml_backend_synchronize(split_backend); + + // get the ids + ggml_tensor * ids_tensor = node->src[2]; + ggml_backend_t ids_backend = split_backend; + + // if the ids tensor is also an input of the split, it may not have been copied yet to the split backend + // in that case, we use the original ids tensor + for (int i = input_id + 1; i < split->n_inputs; i++) { + if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) { + ids_tensor = split->inputs[i]; + ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]); + break; + } + } + + if (ids_tensor != prev_ids_tensor) { + ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t)); + ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor)); + ggml_backend_synchronize(ids_backend); + + // find the used experts + used_ids.clear(); + used_ids.resize(ggml_bitset_size(n_expert)); + for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) { + for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) { + int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)]; + GGML_ASSERT(id >= 0 && id < n_expert); + ggml_bitset_set(used_ids.data(), id); + } + } + + prev_ids_tensor = ids_tensor; + } + + // group consecutive experts and copy them together + auto copy_experts = [&](int32_t first_id, int32_t last_id) { + const size_t expert_offset = first_id * expert_size; + const size_t expert_size_copy = (last_id - first_id + 1) * expert_size; + const size_t padding = std::min(expert_size, 512); + const size_t padding_end = last_id < n_expert - 1 ? padding : 0; + + ggml_backend_tensor_set_async(split_backend, + input_cpy, + (const uint8_t *)input->data + expert_offset, expert_offset, + // copy a bit extra at the to ensure there are no NaNs in the padding of the last expert + // this is necessary for MMQ in the CUDA backend + expert_size_copy + padding_end); + }; + + int id = 0; + while (!ggml_bitset_get(used_ids.data(), id)) { + id++; + } + int32_t first_id = id; + int32_t last_id = first_id; + + for (++id; id < n_expert; ++id) { + if (!ggml_bitset_get(used_ids.data(), id)) { + continue; + } + + if (id == last_id + 1) { + last_id = id; + continue; + } + + copy_experts(first_id, last_id); + + first_id = id; + last_id = id; + } + copy_experts(first_id, last_id); + } else { + // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events + // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface + if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) { + ggml_backend_synchronize(input_backend); + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); + } + ggml_backend_tensor_copy(input, input_cpy); } - ggml_backend_tensor_copy(input, input_cpy); } } } @@ -1474,6 +1640,17 @@ ggml_backend_sched_t ggml_backend_sched_new( size_t graph_size, bool parallel, bool op_offload) { + return ggml_backend_sched_new_ext(backends, bufts, n_backends, graph_size, parallel, op_offload, true); + } + +ggml_backend_sched_t ggml_backend_sched_new_ext( + ggml_backend_t * backends, + ggml_backend_buffer_type_t * bufts, + int n_backends, + size_t graph_size, + bool parallel, + bool op_offload, + bool alloc_buffers) { GGML_ASSERT(n_backends > 0); GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); @@ -1515,10 +1692,13 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->events[b][c] = ggml_backend_event_new(backends[b]->device); } } + + sched->bufts[b]->no_alloc = !alloc_buffers; } sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); sched->op_offload = op_offload; + sched->alloc_buffers = alloc_buffers; ggml_backend_sched_reset(sched); @@ -1533,6 +1713,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { for (int c = 0; c < sched->n_copies; c++) { ggml_backend_event_free(sched->events[b][c]); } + + if (sched->backends[b]->iface.reset != NULL) { + sched->backends[b]->iface.reset(sched->backends[b]); + } } ggml_gallocr_free(sched->galloc); ggml_free(sched->ctx); @@ -1551,6 +1735,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { } void ggml_backend_sched_reset(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); // reset state for the next run if (!sched->is_reset) { ggml_hash_set_reset(&sched->hash_set); @@ -1562,8 +1747,11 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) { } bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { + GGML_ASSERT(sched); GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs); + ggml_backend_sched_reset(sched); + ggml_backend_sched_synchronize(sched); ggml_backend_sched_split_graph(sched, measure_graph); @@ -1572,12 +1760,31 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * return false; } + if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { + return false; + } + + struct ggml_backend_sched_split * splits = sched->splits; + for (int i = 0; i < sched->n_splits; i++) { + struct ggml_backend_sched_split * split = &splits[i]; + int split_backend_id = split->backend_id; + ggml_backend_t split_backend = sched->backends[split_backend_id]; + + if (split_backend->iface.graph_reserve != NULL) { + enum ggml_status ec = split_backend->iface.graph_reserve(split_backend, &split->graph, sched->alloc_buffers); + if (ec != GGML_STATUS_SUCCESS) { + return false; + } + } + } + ggml_backend_sched_reset(sched); return true; } bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + GGML_ASSERT(sched); GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs); GGML_ASSERT(!sched->is_alloc); @@ -1602,6 +1809,7 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st } enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + GGML_ASSERT(sched); if (!sched->is_reset && !sched->is_alloc) { ggml_backend_sched_reset(sched); } @@ -1616,6 +1824,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch } void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); for (int i = 0; i < sched->n_backends; i++) { ggml_backend_synchronize(sched->backends[i]); } @@ -1628,45 +1837,63 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { } void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { + GGML_ASSERT(sched); sched->callback_eval = callback; sched->callback_eval_user_data = user_data; } int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); return sched->n_splits; } int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); return sched->n_copies; } int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); return sched->n_backends; } ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) { + GGML_ASSERT(sched); GGML_ASSERT(i >= 0 && i < sched->n_backends); return sched->backends[i]; } +ggml_backend_buffer_type_t ggml_backend_sched_get_buffer_type(ggml_backend_sched_t sched, ggml_backend_t backend) { + GGML_ASSERT(sched); + int backend_index = ggml_backend_sched_backend_id(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + + return sched->bufts[backend_index]; +} + size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { + GGML_ASSERT(sched); int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); } -struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { +size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); - struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); - struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated}; + size_t size = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index); - return status; + if (backend->iface.buffer_size != NULL) { + size += backend->iface.buffer_size(backend); + } + + return size; } void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { + GGML_ASSERT(sched); int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); tensor_backend_id(node) = backend_index; @@ -1675,6 +1902,7 @@ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct gg } ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) { + GGML_ASSERT(sched); int backend_index = tensor_backend_id(node); if (backend_index == -1) { return NULL; @@ -1685,6 +1913,7 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, // utils enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor); GGML_ASSERT(tensor->buffer == NULL); GGML_ASSERT(tensor->view_src != NULL); GGML_ASSERT(tensor->view_src->buffer != NULL); @@ -1696,6 +1925,7 @@ enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) { } enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) { + GGML_ASSERT(tensor); GGML_ASSERT(tensor->buffer == NULL); GGML_ASSERT(tensor->data == NULL); GGML_ASSERT(tensor->view_src == NULL); @@ -1769,6 +1999,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_ } struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { + GGML_ASSERT(graph); struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size); struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0])); @@ -1913,6 +2144,7 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t // CPU backend - buffer static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); uintptr_t data = (uintptr_t)buffer->context; // align the buffer @@ -1924,6 +2156,7 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { } static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); ggml_aligned_free(buffer->context, buffer->size); delete buffer; } @@ -1933,24 +2166,28 @@ static void ggml_backend_cpu_ptr_buffer_free_buffer(ggml_backend_buffer_t buffer } static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + GGML_ASSERT(tensor); memset((char *)tensor->data + offset, value, size); GGML_UNUSED(buffer); } static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor); memcpy((char *)tensor->data + offset, data, size); GGML_UNUSED(buffer); } static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor); memcpy(data, (const char *)tensor->data + offset, size); GGML_UNUSED(buffer); } static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + GGML_ASSERT(src); if (ggml_backend_buffer_is_host(src->buffer)) { memcpy(dst->data, src->data, ggml_nbytes(src)); return true; @@ -1961,6 +2198,7 @@ static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con } static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); memset(buffer->context, value, buffer->size); } diff --git a/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt index 76064c3f..60ce4b1e 100644 --- a/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt @@ -74,7 +74,7 @@ if (BLAS_FOUND) target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS}) - if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel")) + if ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel")) add_compile_definitions(GGML_BLAS_USE_MKL) endif() diff --git a/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp b/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp index 40738d5b..2a9ff7f6 100644 --- a/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp @@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = { /* .graph_compute = */ ggml_backend_blas_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_blas_guid(void) { diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt index f188d163..42041b71 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt @@ -224,7 +224,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos) if (NOT ${feature_pos} EQUAL -1) - message(STATUS "ARM feature ${feature} enabled") + # Special handling for MATMUL_INT8 when machine doesn't support i8mm + if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm) + message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm") + list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8) + else() + message(STATUS "ARM feature ${feature} enabled") + endif() endif() endforeach() endif() @@ -433,15 +439,31 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/quants.c ggml-cpu/arch/riscv/repack.cpp ) - if (GGML_RVV) - if (GGML_XTHEADVECTOR) - list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d) - elseif (GGML_RV_ZFH) - list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d) - else() - list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) + if (GGML_CPU_RISCV64_SPACEMIT) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC}) + list(APPEND GGML_CPU_SOURCES + ggml-cpu/spacemit/ime.cpp + ggml-cpu/spacemit/ime.h + ggml-cpu/spacemit/ime1_kernels.cpp + ggml-cpu/spacemit/ime_kernels.h + ) + endif() + set(MARCH_STR "rv64gc") + if (GGML_RV_ZFH) + string(APPEND MARCH_STR "_zfh") + endif() + if (GGML_XTHEADVECTOR) + string(APPEND MARCH_STR "_xtheadvector") + elseif (GGML_RVV) + string(APPEND MARCH_STR "_v") + if (GGML_RV_ZVFH) + string(APPEND MARCH_STR "_zvfh") endif() endif() + if (GGML_RV_ZICBOP) + string(APPEND MARCH_STR "_zicbop") + endif() + list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") message(STATUS "s390x detected") list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c) @@ -450,7 +472,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # TODO: Separation to determine activation of VX/VXE/VXE2 if (${S390X_M} MATCHES "8561|8562") - set(GGML_NNPA OFF) message(STATUS "z15 target") list(APPEND ARCH_FLAGS -march=z15) elseif (${S390X_M} MATCHES "3931") @@ -460,7 +481,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version. # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15. message(STATUS "z17 target") - list(APPEND ARCH_FLAGS -march=z17) + list(APPEND ARCH_FLAGS -march=arch15) else() message(STATUS "Unknown target") message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.") @@ -472,11 +493,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_FLAGS -mvx -mzvector) list(APPEND ARCH_DEFINITIONS GGML_VXE) endif() - - if (GGML_NNPA) - message(STATUS "NNPA enabled") - list(APPEND ARCH_DEFINITIONS GGML_NNPA) - endif() elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm") message(STATUS "Wasm detected") list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c) @@ -497,9 +513,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.11.0") + set(KLEIDIAI_COMMIT_TAG "v1.14.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2") + set(KLEIDIAI_ARCHIVE_MD5 "45e110675d93f99f82c23a1afcca76bc") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) @@ -555,6 +571,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) @@ -575,8 +592,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c - ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c) + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c + ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S) set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2") endif() diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp index 258857b0..895a5713 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp @@ -7,7 +7,7 @@ #include "ggml-cpu.h" #include "traits.h" -#if defined(__gnu_linux__) +#if defined(__linux__) #include #include #endif @@ -149,6 +149,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous is_contiguous_2d(op->src[1]) && // src1 must be contiguous op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && + op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { // src1 must be host buffer @@ -186,7 +187,7 @@ static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_ty #define XFEATURE_XTILEDATA 18 static bool ggml_amx_init() { -#if defined(__gnu_linux__) +#if defined(__linux__) if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) { fprintf(stderr, "AMX is not ready to be used!\n"); return false; @@ -194,6 +195,8 @@ static bool ggml_amx_init() { return true; #elif defined(_WIN32) return true; +#else + return false; #endif } diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h b/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h index b62e3158..edfd7913 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h @@ -40,18 +40,22 @@ #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // repack.cpp @@ -69,7 +73,6 @@ #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K -#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -80,12 +83,14 @@ #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #elif defined(__loongarch64) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -103,12 +108,14 @@ #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #elif defined(__riscv) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -133,16 +140,16 @@ #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #elif defined(__s390x__) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K -#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0 -#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -153,7 +160,6 @@ #define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K #define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K -#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -164,12 +170,14 @@ #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #elif defined(__wasm__) // quants.c #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 @@ -195,10 +203,12 @@ #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #endif diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/repack.cpp index 37933a4b..fe18225c 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -511,38 +511,34 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR #endif } -void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +// +// GEMV/GEMM templates +// + +#if defined(__AVX2__) || defined(__AVX512F__) + +// GEMV for 8x blocks of 32 4-bit quants with a single scale factor per block +template +static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) { + static_assert( + std::is_same_v || + std::is_same_v, + "Unsupported block type"); + const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); -#if defined(__AVX2__) - // Lookup table to convert signed nibbles to signed bytes - __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); - signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); // Permute mask used for easier vector processing at later stages const __m256i m4b = _mm256_set1_epi8(0x0F); - int64_t b_nb = n / QK4_0; + int64_t b_nb = n / 32; - const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; + const block_tx8 * b_ptr_start = (const block_tx8 *)vx; const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy; // Process Q8_0 blocks one by one @@ -551,17 +547,17 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo // Pointers to LHS blocks of block_q8_0 format const block_q8_0 * a_ptr = a_ptr_start + (y * nb); - // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation + // Take group of eight blocks at each pass of the loop and perform dot product operation for (int64_t x = 0; x < nc / 8; x++) { // Pointers to RHS blocks - const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + const block_tx8 * b_ptr = b_ptr_start + (x * b_nb); // Master FP accumulator __m256 acc_row = _mm256_setzero_ps(); for (int64_t b = 0; b < nb; b++) { - // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7) + // Load 8 blocks of 32 interleaved as 8 bytes (B0 - B7) const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1); const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2); @@ -578,8 +574,13 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31) const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31) - // Load the scale values for the 8 blocks interleaved in block_q4_0x8 - const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); + // Load the scale values for the 8 blocks interleaved in block_tx8 + __m256 col_scale_f32; + if constexpr ( + std::is_same_v || + std::is_same_v) { + col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); + } // Load and convert to FP32 scale from block_q8_0 const __m256 row_scale_f32 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(a_ptr[b].d)); @@ -620,9 +621,771 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo _mm256_storeu_ps(s + (y * nr + x * 8), acc_row); } } - return; +} +// GEMM for 8x blocks of 32 4-bit quants with a single scale factor per block +template +static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) { + static_assert( + std::is_same_v || + std::is_same_v, + "Unsupported block type"); + + const int qk = QK8_0; + const int nb = n / qk; + + const block_tx8 * b_ptr_start = (const block_tx8 *)vx; + const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy; + + int64_t b_nb = n / 32; + int64_t y = 0; + // Mask to mask out nibbles from packed bytes + const __m256i m4b = _mm256_set1_epi8(0x0F); + const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3); + // Permute mask used for easier vector processing at later stages + __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); + int64_t xstart = 0; + int anr = nr - nr%16; // Used to align nr with boundary of 16 +#ifdef __AVX512F__ + int anc = nc - nc%16; // Used to align nc with boundary of 16 + // Mask to mask out nibbles from packed bytes expanded to 512 bit length + const __m512i m4bexpanded = _mm512_set1_epi8(0x0F); + // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length + __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1); + + // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < anr / 4; y += 4) { + + const block_q8_0x4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of two block_tx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_tx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_tx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the sixteen blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + // 4-bit -> 8-bit - Sign is maintained + const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) + const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) + + const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) + const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) + + const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) + const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) + + const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) + const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) + const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) + + const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) + const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) + + const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) + const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) + + const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) + const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) + + // Shuffle pattern two - right side input + + const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) + const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) + + const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) + const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) + + const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) + const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) + + const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) + const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) + + // Scale values - Load the weight scale values of two block_tx8 + __m512 col_scale_f32; + if constexpr ( + std::is_same_v || + std::is_same_v) { + col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + } + + // Process LHS in pairs of rows + for (int rp = 0; rp < 4; rp++) { + + // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); + __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); + __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); + __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); + __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); + __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); + __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); + __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); + __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); + __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); + __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); + __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); + + __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); + __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); + __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); + __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); + __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); + __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); + __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); + __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); + + // Shuffle pattern one - left side input + + const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + const __m512i zero = _mm512_setzero_epi32(); + __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2); + __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68); + const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); + + // Multiply with appropiate scales and accumulate + acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + } + } + + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + + // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < nr / 4; y ++) { + const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); + + // Take group of two block_tx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_tx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_tx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the sixteen blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + // 4-bit -> 8-bit - Sign is maintained + const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) + const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) + + const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) + const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) + + const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) + const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) + + const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) + const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) + const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) + + const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) + const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) + + const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) + const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) + + const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) + const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) + + // Shuffle pattern two - right side input + + const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) + const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) + + const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) + const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) + + const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) + const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) + + const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) + const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) + + + // Scale values - Load the weight scale values of two block_tx8 + __m512 col_scale_f32; + if constexpr ( + std::is_same_v || + std::is_same_v) { + col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + } + + // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); + __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); + __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); + __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); + __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); + __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); + __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); + __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); + __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); + __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); + __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); + __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); + + __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); + __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); + __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); + __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); + __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); + __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); + __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); + __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); + + // Shuffle pattern one - left side input + + const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + const __m512i zero = _mm512_setzero_epi32(); + __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2); + __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68); + const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); + + // Multiply with appropiate scales and accumulate + acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + } + + // Store the accumulated values + for (int i = 0; i < 4; i++) { + _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + if (anc != nc) { + xstart = anc/8; + y = 0; + } +#endif // __AVX512F__ + + // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation + + for (; y < anr / 4; y += 4) { + const block_q8_0x4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of eight block_tx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = xstart; x < nc / 8; x++) { + + const block_tx8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the eight blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) + const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) + + const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) + const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) + + const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) + const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) + + const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) + const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) + const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) + + const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) + const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) + + const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) + const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) + + const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) + const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) + + // Shuffle pattern two - right side input + + const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) + const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) + + const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) + const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) + + const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) + const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) + + const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) + const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) + + // Scale values - Load the wight scale values of block_tx8 + __m256 col_scale_f32; + if constexpr ( + std::is_same_v || + std::is_same_v) { + col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + } + + // Process LHS in groups of four + for (int rp = 0; rp < 4; rp++) { + // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); + + // Shuffle pattern one - left side input + const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + const __m256i zero = _mm256_setzero_si256(); + __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2); + __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); + + // Multiply with appropiate scales and accumulate + acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + } + } + + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + + // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < nr / 4; y ++) { + const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); + + // Load the eight blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + for (int64_t x = xstart; x < nc / 8; x++) { + const block_tx8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) + const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) + + const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) + const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) + + const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) + const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) + + const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) + const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) + const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) + + const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) + const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) + + const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) + const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) + + const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) + const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) + + // Shuffle pattern two - right side input + + const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) + const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) + + const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) + const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) + + const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) + const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) + + const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) + const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) + + // Scale values - Load the wight scale values of block_tx8 + __m256 col_scale_f32; + if constexpr ( + std::is_same_v || + std::is_same_v) { + col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + } + + // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); + + // Shuffle pattern one - left side input + + const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + const __m256i zero = _mm256_setzero_si256(); + __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2); + __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask); + + // Multiply with appropiate scales and accumulate + acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + } + + // Store the accumulated values + for (int i = 0; i < 4; i++) { + _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } +} + +#endif // defined(__AVX2__) || defined(__AVX512F__) + +void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) || defined(__AVX512F__) + { + // Lookup table to convert signed nibbles to signed bytes + __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + + gemv_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut); + + return; + } #endif + ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } @@ -849,6 +1612,19 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) + __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_iq4nl)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + + gemv_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut); + + return; +#endif + + ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -1163,750 +1939,18 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - #if defined(__AVX2__) || defined(__AVX512F__) { - const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; - const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy; - int64_t b_nb = n / QK4_0; - int64_t y = 0; - // Mask to mask out nibbles from packed bytes - const __m256i m4b = _mm256_set1_epi8(0x0F); - const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3); // Lookup table to convert signed nibbles to signed bytes __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); - // Permute mask used for easier vector processing at later stages - __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); - int64_t xstart = 0; - int anr = nr - nr%16; // Used to align nr with boundary of 16 - #ifdef __AVX512F__ - int anc = nc - nc%16; // Used to align nc with boundary of 16 - // Mask to mask out nibbles from packed bytes expanded to 512 bit length - const __m512i m4bexpanded = _mm512_set1_epi8(0x0F); - // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length - __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1); - // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation - for (; y < anr / 4; y += 4) { + gemm_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut); - const block_q8_0x4 * a_ptrs[4]; - - a_ptrs[0] = a_ptr_start + (y * nb); - for (int i = 0; i < 3; ++i) { - a_ptrs[i + 1] = a_ptrs[i] + nb; - } - - // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation - for (int64_t x = 0; x < anc / 8; x += 2) { - - const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); - const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); - - // Master FP accumulators - __m512 acc_rows[16]; - for (int i = 0; i < 16; i++) { - acc_rows[i] = _mm512_setzero_ps(); - } - - for (int64_t b = 0; b < nb; b++) { - // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF - const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); - const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); - const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); - const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); - - const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); - const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); - const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); - const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); - - // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values - const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); - const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); - - const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); - const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); - const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); - const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); - - const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); - const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); - const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); - const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); - - // 4-bit -> 8-bit - Sign is maintained - const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) - const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) - - const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) - const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) - - const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) - const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) - - const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) - const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) - - // Shuffle pattern one - right side input - const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) - const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) - - const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) - const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) - - const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) - const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) - - const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) - const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) - - // Shuffle pattern two - right side input - - const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) - const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) - - const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) - const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) - - const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) - const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) - - const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) - const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) - - // Scale values - Load the weight scale values of two block_q4_0x8 - const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); - - // Process LHS in pairs of rows - for (int rp = 0; rp < 4; rp++) { - - // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 - // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector - __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); - __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); - __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); - __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); - __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); - __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); - __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); - __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); - __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); - __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); - __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); - __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); - - __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); - __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); - __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); - __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); - __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); - __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); - __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); - __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); - - // Shuffle pattern one - left side input - - const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - - const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - - const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - - const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) - - // Shuffle pattern two - left side input - - const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - - const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - - const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - - const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) - - // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane - // Resembles MMLAs into 2x2 matrices in ARM Version - const __m512i zero = _mm512_setzero_epi32(); - __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1); - __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1); - __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1); - __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1); - __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2); - __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2); - __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2); - __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2); - - // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block - __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); - __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); - __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); - __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); - - - // Straighten out to make 4 row vectors - __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78)); - __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01); - __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78)); - __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11); - - // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes - const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68); - const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); - - // Multiply with appropiate scales and accumulate - acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); - acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); - acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); - acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); - } - } - - // Store the accumulated values - for (int i = 0; i < 16; i++) { - _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); - } - } - } - // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation - for (; y < nr / 4; y ++) { - - const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); - - // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation - for (int64_t x = 0; x < anc / 8; x += 2) { - - const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); - const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); - - // Master FP accumulators - __m512 acc_rows[4]; - for (int i = 0; i < 4; i++) { - acc_rows[i] = _mm512_setzero_ps(); - } - - for (int64_t b = 0; b < nb; b++) { - // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF - const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); - const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); - const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); - const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); - - const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); - const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); - const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); - const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); - - // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess - const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); - const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); - - const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); - const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); - const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); - const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); - - const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); - const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); - const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); - const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); - - // 4-bit -> 8-bit - Sign is maintained - const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) - const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) - - const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) - const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) - - const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) - const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) - - const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) - const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) - - // Shuffle pattern one - right side input - const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) - const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) - - const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) - const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) - - const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) - const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) - - const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) - const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) - - // Shuffle pattern two - right side input - - const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) - const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) - - const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) - const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) - - const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) - const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) - - const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) - const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) - - - // Scale values - Load the weight scale values of two block_q4_0x8 - const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); - - // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 - // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector - __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); - __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); - __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); - __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); - __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); - __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); - __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); - __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); - __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); - __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); - __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); - __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); - - __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); - __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); - __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); - __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); - __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); - __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); - __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); - __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); - - // Shuffle pattern one - left side input - - const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - - const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - - const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - - const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) - - // Shuffle pattern two - left side input - - const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - - const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - - const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - - const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) - - // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane - // Resembles MMLAs into 2x2 matrices in ARM Version - const __m512i zero = _mm512_setzero_epi32(); - __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1); - __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1); - __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1); - __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1); - __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2); - __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2); - __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2); - __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2); - - // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block - __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); - __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); - __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); - __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); - - - // Straighten out to make 4 row vectors - __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78)); - __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01); - __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78)); - __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11); - - // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes - const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68); - const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); - - // Multiply with appropiate scales and accumulate - acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); - acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); - acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); - acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); - } - - // Store the accumulated values - for (int i = 0; i < 4; i++) { - _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); - } - } - } - if (anc != nc) { - xstart = anc/8; - y = 0; - } - #endif // __AVX512F__ - - // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation - - for (; y < anr / 4; y += 4) { - const block_q8_0x4 * a_ptrs[4]; - - a_ptrs[0] = a_ptr_start + (y * nb); - for (int i = 0; i < 3; ++i) { - a_ptrs[i + 1] = a_ptrs[i] + nb; - } - - // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation - for (int64_t x = xstart; x < nc / 8; x++) { - - const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); - - // Master FP accumulators - __m256 acc_rows[16]; - for (int i = 0; i < 16; i++) { - acc_rows[i] = _mm256_setzero_ps(); - } - - for (int64_t b = 0; b < nb; b++) { - // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 - const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); - const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); - const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); - const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); - - // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values - const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); - const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); - - // 4-bit -> 8-bit - Sign is maintained - const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) - const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) - - const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) - const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) - - const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) - const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) - - const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) - const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) - - // Shuffle pattern one - right side input - const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) - const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) - - const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) - const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) - - const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) - const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) - - const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) - const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) - - // Shuffle pattern two - right side input - - const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) - const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) - - const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) - const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) - - const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) - const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) - - const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) - const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) - - // Scale values - Load the wight scale values of block_q4_0x8 - const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); - - // Process LHS in groups of four - for (int rp = 0; rp < 4; rp++) { - // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 - // Loaded as set of 128 bit vectors and repeated into a 256 bit vector - __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); - __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); - __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); - __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); - __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); - __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); - __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); - __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); - __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); - __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); - __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); - __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); - - // Shuffle pattern one - left side input - const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - - const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - - const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - - const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) - - // Shuffle pattern two - left side input - const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - - const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - - const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - - const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) - - // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane - // Resembles MMLAs into 2x2 matrices in ARM Version - const __m256i zero = _mm256_setzero_si256(); - __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1); - __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1); - __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1); - __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1); - __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2); - __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2); - __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2); - __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2); - - // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block - __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); - __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); - __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); - __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); - - // Straighten out to make 4 row vectors - __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); - __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); - __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); - __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); - - // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes - const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); - - // Multiply with appropiate scales and accumulate - acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); - acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); - acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); - acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); - } - } - - // Store the accumulated values - for (int i = 0; i < 16; i++) { - _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); - } - } - } - - // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation - for (; y < nr / 4; y ++) { - - const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); - - // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 - for (int64_t x = xstart; x < nc / 8; x++) { - - const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); - - // Master FP accumulators - __m256 acc_rows[4]; - for (int i = 0; i < 4; i++) { - acc_rows[i] = _mm256_setzero_ps(); - } - - for (int64_t b = 0; b < nb; b++) { - // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 - const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); - const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); - const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); - const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); - - // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess - const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); - const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); - - // 4-bit -> 8-bit - Sign is maintained - const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) - const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) - - const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) - const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) - - const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) - const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) - - const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) - const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) - - // Shuffle pattern one - right side input - const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) - const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) - - const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) - const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) - - const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) - const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) - - const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) - const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) - - // Shuffle pattern two - right side input - - const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) - const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) - - const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) - const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) - - const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) - const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) - - const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) - const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) - - // Scale values - Load the wight scale values of block_q4_0x8 - const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); - - // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 - // Loaded as set of 128 bit vectors and repeated into a 256 bit vector - __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); - __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); - __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); - __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); - __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); - __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); - __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); - __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); - __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); - __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); - __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); - __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); - - // Shuffle pattern one - left side input - - const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - - const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - - const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - - const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) - - // Shuffle pattern two - left side input - - const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - - const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - - const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - - const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) - - // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane - // Resembles MMLAs into 2x2 matrices in ARM Version - const __m256i zero = _mm256_setzero_si256(); - __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1); - __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1); - __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1); - __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1); - __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2); - __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2); - __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2); - __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2); - - // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block - __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); - __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); - __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); - __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); - - - // Straighten out to make 4 row vectors - __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); - __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); - __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); - __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); - - // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes - const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask); - - // Multiply with appropiate scales and accumulate - acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); - acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); - acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); - acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); - } - - // Store the accumulated values - for (int i = 0; i < 4; i++) { - _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); - } - } - } return; } +#endif // defined(__AVX2__) || defined(__AVX512F__) -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } @@ -3364,6 +3408,21 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) || defined(__AVX512F__) + { + __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_iq4nl)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + + gemm_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut); + + return; + } +#endif // defined(__AVX2__) || defined(__AVX512F__) + + ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/common.h b/ml/backend/ggml/ggml/src/ggml-cpu/common.h index 353563dc..6adca543 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/common.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/common.h @@ -28,6 +28,14 @@ static inline float bf16_to_f32(ggml_bf16_t x) { return GGML_BF16_TO_FP32(x); } +static inline float i32_to_f32(int32_t x) { + return x; +} + +static inline int32_t f32_to_i32(float x) { + return x; +} + static inline float f32_to_f32(float x) { return x; } @@ -54,6 +62,12 @@ struct type_conversion_table { static constexpr ggml_bf16_t (*from_f32)(float) = f32_to_bf16; }; +template <> +struct type_conversion_table { + static constexpr float (*to_f32)(int32_t) = i32_to_f32; + static constexpr int32_t (*from_f32)(float) = f32_to_i32; +}; + static std::pair get_thread_range(const struct ggml_compute_params * params, const struct ggml_tensor * src0) { const int64_t ith = params->ith; const int64_t nth = params->nth; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h index d839cf5c..713bf85e 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -68,13 +68,7 @@ struct ggml_compute_params { #endif // __VXE2__ #endif // __s390x__ && __VEC__ -#if defined(__s390x__) && defined(GGML_NNPA) -#ifndef __NNPA__ -#define __NNPA__ -#endif // __NNPA__ -#endif // __s390x__ && GGML_NNPA - -#if defined(__ARM_FEATURE_SVE) +#if defined(__ARM_FEATURE_SVE) && defined(__linux__) #include #endif @@ -486,6 +480,19 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) { return v_abo + v_abe; } +/** + * @see https://github.com/ggml-org/llama.cpp/pull/14037 + */ +inline static float vec_hsum_f32x4(float32x4_t v) { + float32x4_t v_temp = v + vec_reve(v); + return v_temp[0] + v_temp[1]; +} + +inline static int32_t vec_hsum_i32x4(int32x4_t v) { + int32x4_t v_temp = v + vec_reve(v); + return v_temp[0] + v_temp[1]; +} + inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) { const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b); return acc + (vec_unpackh(p) + vec_unpackl(p)); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c index 85af19a3..b13a491d 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c @@ -375,6 +375,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_I32] = { + .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32, + }, }; const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { @@ -472,10 +475,10 @@ struct ggml_threadpool { struct ggml_compute_state { #ifndef GGML_USE_OPENMP ggml_thread_t thrd; - bool cpumask[GGML_MAX_N_THREADS]; int last_graph; bool pending; #endif + bool cpumask[GGML_MAX_N_THREADS]; struct ggml_threadpool * threadpool; int ith; }; @@ -688,8 +691,13 @@ bool ggml_is_numa(void) { #endif static void ggml_init_arm_arch_features(void) { -#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE) +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) +#if defined(__linux__) ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); +#else + // TODO: add support of SVE for non-linux systems +#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here." +#endif #endif } @@ -1878,10 +1886,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col_back_f32(params, tensor); } break; + case GGML_OP_IM2COL_3D: + { + ggml_compute_forward_im2col_3d(params, tensor); + } break; case GGML_OP_CONV_2D: { ggml_compute_forward_conv_2d(params, tensor); } break; + case GGML_OP_CONV_3D: + { + ggml_compute_forward_conv_3d(params, tensor); + } break; case GGML_OP_CONV_2D_DW: { ggml_compute_forward_conv_2d_dw(params, tensor); @@ -2024,6 +2040,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm ggml_compute_forward_opt_step_adamw(params, tensor); } break; + case GGML_OP_OPT_STEP_SGD: + { + ggml_compute_forward_opt_step_sgd(params, tensor); + } + break; case GGML_OP_NONE: { // nop @@ -2173,6 +2194,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_XIELU: { n_tasks = n_threads; } break; @@ -2248,7 +2270,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_IM2COL: case GGML_OP_IM2COL_BACK: + case GGML_OP_IM2COL_3D: case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: @@ -2327,6 +2351,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: { n_tasks = n_threads; } break; @@ -2682,7 +2707,10 @@ struct ggml_cplan ggml_graph_plan( if (ggml_is_quantized(node->type) || // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) || - (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) { + (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) || + // conversion between F32 and I32 + (node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) || + (node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) { cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } } break; @@ -2769,6 +2797,7 @@ struct ggml_cplan ggml_graph_plan( } } break; case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: { cur = GGML_IM2COL_WORK_SIZE; } break; @@ -3064,7 +3093,14 @@ static struct ggml_threadpool * ggml_threadpool_new_impl( threadpool->workers = workers; -#ifndef GGML_USE_OPENMP +#ifdef GGML_USE_OPENMP + int32_t cpumask_iter = 0; + + // Compute CPU masks for each thread + for (int j = 0; j < tpp->n_threads; j++) { + ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter); + } +#else // GGML_USE_OPENMP ggml_mutex_init(&threadpool->mutex); ggml_cond_init(&threadpool->cond); @@ -3137,7 +3173,14 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); } - ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]); + // Apply thread CPU mask and priority + int ith = omp_get_thread_num(); + + ggml_thread_apply_priority(threadpool->prio); + if (ggml_thread_cpumask_is_valid(threadpool->workers[ith].cpumask)) { + ggml_thread_apply_affinity(threadpool->workers[ith].cpumask); + } + ggml_graph_compute_thread(&threadpool->workers[ith]); } } else { atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed); @@ -3200,20 +3243,12 @@ void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) { __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); _mm_storel_epi64((__m128i *)(y + i), y_vec); } -#elif defined(__NNPA__) - for (; i + 7 < n; i += 8) { - float32x4_t v_xh = vec_xl(0, (const float *)(x + i + 0)); - float32x4_t v_xl = vec_xl(0, (const float *)(x + i + 4)); - uint16x8_t v_yd = vec_round_from_fp32(v_xh, v_xl, 0); - uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0); - vec_xst(v_y, 0, (ggml_fp16_t *)(y + i)); - } - for (; i + 3 < n; i += 4) { - float32x4_t v_x = vec_xl(0, (const float *)(x + i)); - float32x4_t v_zero = vec_splats(0.0f); - uint16x8_t v_yd = vec_round_from_fp32(v_x, v_zero, 0); - uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0); - vec_xst(v_y, 0, (ggml_fp16_t *)(y + i)); +#elif defined(__riscv_zvfh) + for (int vl; i < n; i += vl) { + vl = __riscv_vsetvl_e32m2(n - i); + vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl); + vfloat16m1_t vy = __riscv_vfncvt_f_f_w_f16m1(vx, vl); + __riscv_vse16_v_f16m1((_Float16 *)&y[i], vy, vl); } #endif for (; i < n; ++i) { @@ -3241,21 +3276,6 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) { __m128 y_vec = _mm_cvtph_ps(x_vec); _mm_storeu_ps(y + i, y_vec); } -#elif defined(__NNPA__) - for (; i + 7 < n; i += 8) { - uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i)); - uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0); - float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0); - float32x4_t v_yl = vec_extend_to_fp32_lo(v_yd, 0); - vec_xst(v_yh, 0, (float *)(y + i + 0)); - vec_xst(v_yl, 0, (float *)(y + i + 4)); - } - for (; i + 3 < n; i += 4) { - uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i)); - uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0); - float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0); - vec_xst(v_yh, 0, (float *)(y + i)); - } #endif for (; i < n; ++i) { @@ -3270,6 +3290,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) { } } +void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) { + int64_t i = 0; + for (; i < n; ++i) { + y[i] = x[i]; + } +} + void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) { int64_t i = 0; #if defined(__AVX2__) @@ -3459,14 +3486,6 @@ int ggml_cpu_has_vxe(void) { #endif } -int ggml_cpu_has_nnpa(void) { -#if defined(GGML_NNPA) - return 1; -#else - return 0; -#endif -} - int ggml_cpu_has_neon(void) { #if defined(__ARM_ARCH) && defined(__ARM_NEON) return 1; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp index 8dacd367..3191faaa 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -18,6 +18,10 @@ # include "kleidiai/kleidiai.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ime.h" +#endif + #if defined(_WIN32) # define WIN32_LEAN_AND_MEAN # ifndef NOMINMAX @@ -45,6 +49,12 @@ std::vector & ggml_backend_cpu_get_extra_buffer_type } #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) { + bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type()); + } +#endif + #ifdef GGML_USE_CPU_KLEIDIAI if (ggml_backend_cpu_kleidiai_buffer_type()) { bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type()); @@ -190,6 +200,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .graph_compute = */ ggml_backend_cpu_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_cpu_guid(void) { @@ -348,8 +359,10 @@ static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * long pages = sysconf(_SC_PHYS_PAGES); long page_size = sysconf(_SC_PAGE_SIZE); *total = pages * page_size; + + // "free" system memory is ill-defined, for practical purposes assume that all of it is free: *free = *total; -#endif +#endif // _WIN32 GGML_UNUSED(dev); } @@ -576,9 +589,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r if (ggml_cpu_has_vxe()) { features.push_back({ "VXE", "1" }); } - if (ggml_cpu_has_nnpa()) { - features.push_back({ "NNPA", "1" }); - } if (ggml_cpu_has_wasm_simd()) { features.push_back({ "WASM_SIMD", "1" }); } diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 2be54c31..2c4ad9d5 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -2169,94 +2169,117 @@ class tinyBLAS_Q0_PPC { class tinyBLAS_PPC { public: tinyBLAS_PPC(int64_t k, - const float *A, int64_t lda, - const float *B, int64_t ldb, - float *C, int64_t ldc, + const float * A, int64_t lda, + const float * B, int64_t ldb, + float * C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } void matmul(int64_t m, int64_t n) { - mnpack(0, m, 0, n); + int64_t mc = 256; int64_t nc = 256; int64_t kc = 256; + if (m % mc == 0 && n % nc == 0 && k % kc == 0) { + matmul_tiled(m, n, mc, nc, kc); + } else { + mnpack(0, m, 0, n); + } } private: - void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); - - inline void vector_permute_store_4(vector float *src, float *vecOffset) { - vector float t1, t2, t3, t4, t5, t6, t7, t8; - t1 = vec_mergeh(src[0], src[1]); - t2 = vec_mergeh(src[2], src[3]); - t3 = vec_mergel(src[0], src[1]); - t4 = vec_mergel(src[2], src[3]); - - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); - - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset + 4); - vec_xst(t7, 0, vecOffset + 8); - vec_xst(t8, 0, vecOffset + 12); - } - - inline void vector_permute_store_8(vector float *src, float *vecOffset) { - vector float t1, t2, t3, t4, t5, t6, t7, t8; - t1 = vec_mergeh(src[0], src[1]); - t2 = vec_mergeh(src[2], src[3]); - t3 = vec_mergeh(src[4], src[5]); - t4 = vec_mergeh(src[6], src[7]); - - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset + 4); - vec_xst(t7, 0, vecOffset + 8); - vec_xst(t8, 0, vecOffset + 12); - - t1 = vec_mergel(src[0], src[1]); - t2 = vec_mergel(src[2], src[3]); - t3 = vec_mergel(src[4], src[5]); - t4 = vec_mergel(src[6], src[7]); - - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - - vec_xst(t5, 0, vecOffset + 16); - vec_xst(t6, 0, vecOffset + 20); - vec_xst(t7, 0, vecOffset + 24); - vec_xst(t8, 0, vecOffset + 28); + inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J); + } + } } - void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) { + inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I); + *c_ptr += *((float *)&vec_C[I]+J); + } + } + } + + inline void vector_permute_store_4(vector float * src, float * vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergel(src[0], src[1]); + t4 = vec_mergel(src[2], src[3]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t1, t2, 3); + t7 = vec_xxpermdi(t3, t4, 0); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + } + + inline void vector_permute_store_8(vector float * src, float * vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergeh(src[4], src[5]); + t4 = vec_mergeh(src[6], src[7]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + + t1 = vec_mergel(src[0], src[1]); + t2 = vec_mergel(src[2], src[3]); + t3 = vec_mergel(src[4], src[5]); + t4 = vec_mergel(src[6], src[7]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset + 16); + vec_xst(t6, 0, vecOffset + 20); + vec_xst(t7, 0, vecOffset + 24); + vec_xst(t8, 0, vecOffset + 28); + } + + void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) { int64_t i, j; float * aoffsets[8]; - float *aoffset = NULL, *boffset = NULL; + float * aoffset = NULL, * boffset = NULL; __vector_pair arr[8]; vector float c[8][2] = {0}; vector float c1[8] = {0}; vector float c2[8] = {0}; - aoffset = const_cast(a); + aoffset = const_cast(a); boffset = vec; j = (rows >> 3); if (j > 0) { - do { aoffsets[0] = aoffset; - for (int it = 1; it< 8; it++) + for (int it = 1; it < 8; it++) aoffsets[it] = aoffsets[it-1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { - for (int it = 0; it< 8; it++) { + for (int it = 0; it < 8; it++) { arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]); __builtin_vsx_disassemble_pair(c[it], &arr[it]); c1[it] = c[it][0]; @@ -2264,11 +2287,14 @@ class tinyBLAS_PPC { } vector_permute_store_8(c1, boffset); - vector_permute_store_8(c2, boffset+32); - for (int it = 0; it < 4; it++) - aoffsets[it] = aoffsets[it] + 8*lda; + vector_permute_store_8(c2, boffset + 32); boffset += 64; i--; + if (i > 0) { + for (int it = 0; it < 8; it++) { + aoffsets[it] = aoffsets[it] + 8; + } + } } while(i > 0); } if (cols & 4) { @@ -2295,9 +2321,9 @@ class tinyBLAS_PPC { c2[it] = c[it][1]; } vector_permute_store_4(c1, boffset); - vector_permute_store_4(c2, boffset+16); + vector_permute_store_4(c2, boffset + 16); for (int it = 0; it < 4; it++) - aoffsets[it] += 8*lda; + aoffsets[it] += 8 * lda; boffset += 32; i--; } while(i > 0); @@ -2325,15 +2351,15 @@ class tinyBLAS_PPC { vec_t vec_A[4], vec_B[4], vec_C[4]; acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); - for (int l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); + for (int l = 0; l < k; l += 4) { + packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A); + packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]); } - SAVE_ACC(&acc_0, ii, jj); + save_acc(&acc_0, ii, jj); } void KERNEL_4x8(int64_t ii, int64_t jj) { @@ -2341,9 +2367,9 @@ class tinyBLAS_PPC { acc_t acc_0, acc_1; __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); - for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B); + for (int64_t l = 0; l < k; l += 4) { + packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A); + packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); @@ -2353,8 +2379,8 @@ class tinyBLAS_PPC { __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]); } - SAVE_ACC(&acc_0, ii, jj); - SAVE_ACC(&acc_1, ii, jj+4); + save_acc(&acc_0, ii, jj); + save_acc(&acc_1, ii, jj + 4); } void KERNEL_8x4(int64_t ii, int64_t jj) { @@ -2362,9 +2388,9 @@ class tinyBLAS_PPC { acc_t acc_0, acc_1; __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); - for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); + for (int64_t l = 0; l < k; l += 4) { + packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A); + packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); @@ -2374,8 +2400,8 @@ class tinyBLAS_PPC { __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]); } - SAVE_ACC(&acc_0, ii, jj); - SAVE_ACC(&acc_1, ii+4, jj); + save_acc(&acc_0, ii, jj); + save_acc(&acc_1, ii + 4, jj); } void KERNEL_8x8(int64_t ii, int64_t jj) { @@ -2386,19 +2412,96 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); for (int l = 0; l < k; l+=8) { - packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B); + packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A); + packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B); for(int x = 0; x < 16; x+=2) { __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); - __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); - __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]); - __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]); + __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]); + } + } + save_acc(&acc_0, ii, jj); + save_acc(&acc_1, ii, jj + 4); + save_acc(&acc_2, ii + 4, jj); + save_acc(&acc_3, ii + 4, jj + 4); + } + + inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) { + for (int x = 0; x < 16; x += 2) { + __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]); + __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]); + __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]); + __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]); + } + } + + void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) { + for (int64_t i = 0; i < mc; i += 16) { + int A_base_addr = (mc / 8) * (i / 8) * 16; + for (int64_t j = 0; j < nc; j += 8) { + int B_base_addr = (nc / 8) * (j / 8) * 16; + acc_t acc[8]; + vec_t A0_block[16]; vec_t A1_block[16]; + for (int x = 0; x < 8; x++) + __builtin_mma_xxsetaccz(&acc[x]); + for (int64_t l = 0; l < kc; l += 8) { + int A0_block_idx = A_base_addr + (l / 8) * 16; + int A1_block_idx = A0_block_idx + (mc / 8) * 16; + int B_block_idx = B_base_addr + (l / 8) * 16; + vec_t* A0_block = &vec_A[A0_block_idx]; + vec_t* A1_block = &vec_A[A1_block_idx]; + vec_t* B_block = &vec_B[B_block_idx]; + MMA_16x8(A0_block, A1_block, B_block, acc); + } + if (kk == 0) { + save_acc(&acc[0], ii + i, jj + j); + save_acc(&acc[1], ii + i, jj + j + 4); + save_acc(&acc[2], ii + i + 4, jj + j); + save_acc(&acc[3], ii + i + 4, jj + j + 4); + save_acc(&acc[4], ii + i + 8, jj + j); + save_acc(&acc[5], ii + i + 8, jj + j + 4); + save_acc(&acc[6], ii + i + 12, jj + j); + save_acc(&acc[7], ii + i + 12, jj + j + 4); + } else { + add_save_acc(&acc[0], ii + i, jj + j); + add_save_acc(&acc[1], ii + i, jj + j + 4); + add_save_acc(&acc[2], ii + i + 4, jj + j); + add_save_acc(&acc[3], ii + i + 4, jj + j + 4); + add_save_acc(&acc[4], ii + i + 8, jj + j); + add_save_acc(&acc[5], ii + i + 8, jj + j + 4); + add_save_acc(&acc[6], ii + i + 12, jj + j); + add_save_acc(&acc[7], ii + i + 12, jj + j + 4); + } + } + } + } + + void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) { + int64_t ytiles = m / mc; + int64_t xtiles = n / nc; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) { + end = tiles; + } + for (int64_t job = start; job < end; ++job) { + int64_t ii = (job / xtiles) * mc; + int64_t jj = (job % xtiles) * nc; + for (int64_t kk = 0; kk < k; kk += kc) { + vec_t A_pack[kc * mc / 4]; + vec_t B_pack[kc * nc / 4]; + packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack); + packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack); + KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk); } } - SAVE_ACC(&acc_0, ii, jj); - SAVE_ACC(&acc_1, ii, jj+4); - SAVE_ACC(&acc_2, ii+4, jj); - SAVE_ACC(&acc_3, ii+4, jj+4); } void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { @@ -2406,35 +2509,35 @@ class tinyBLAS_PPC { int n_rem = MIN(n - n0, 8); int mc = 0, nc = 0; if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8, 8>(m0, m, n0, n); + mc = 8; + nc = 8; + gemm<8, 8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { - mc = 4; - nc = 8; - gemm<4, 8>(m0, m, n0, n); + mc = 4; + nc = 8; + gemm<4, 8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { - mc = 8; - nc = 4; - gemm<8, 4>(m0, m, n0, n); + mc = 8; + nc = 4; + gemm<8, 4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { - mc = 4; - nc = 4; - gemm<4, 4>(m0, m, n0, n); + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); } else { mc = (m_rem >= 4) ? 4 : m_rem; nc = (n_rem >= 4) ? 4 : n_rem; if (mc == 0 || nc == 0) - return; + return; gemm_small(m0, m, n0, n, mc, nc); } int64_t mp = m0 + ((m - m0) / mc) * mc; int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); - } + } - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2449,30 +2552,30 @@ class tinyBLAS_PPC { vec_t vec_C[4]; acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); - vec_t vec_A[4] {0}, vec_B[4] = {0}; - for (int l=0; l(A+(ii)*lda+l); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); + float * a = const_cast(A + (ii) * lda + l); + packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B); vec_A[0] = (vec_t)vec_xl(0,a); - vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1)); - vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2)); - vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3)); + vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1)); + vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2)); + vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3)); } else if (RN == 1) { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); - float* b = const_cast(B+(jj)*ldb+l); + packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A); + float * b = const_cast(B + (jj) * ldb + l); vec_B[0] = (vec_t)vec_xl(0,b); - vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1)); - vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2)); - vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3)); + vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1)); + vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2)); + vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3)); } else { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); + packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A); + packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B); } __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); @@ -2482,12 +2585,27 @@ class tinyBLAS_PPC { __builtin_mma_disassemble_acc(vec_C, &acc_0); for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { - *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); + *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J); } } } } + template + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 4) { + KERNEL_4x4(ii, jj); + } else if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii, jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii, jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii, jj); + } else { + static_assert(false, "RN/RM values not supported"); + } + } + template NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; @@ -2496,27 +2614,18 @@ class tinyBLAS_PPC { int64_t duty = (tiles + nth - 1) / nth; int64_t start = duty * ith; int64_t end = start + duty; - if (RM == 4 && RN == 4) { - kernel = &tinyBLAS_PPC::KERNEL_4x4; - } else if (RM == 4 && RN == 8) { - kernel = &tinyBLAS_PPC::KERNEL_4x8; - } else if (RM == 8 && RN == 4) { - kernel = &tinyBLAS_PPC::KERNEL_8x4; - } else if (RM == 8 && RN == 8) { - kernel = &tinyBLAS_PPC::KERNEL_8x8; - } if (end > tiles) end = tiles; for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; - (this->*kernel)(ii, jj); + kernel(ii, jj); } } - const float *const A; - const float *const B; - float *C; + const float * const A; + const float * const B; + float * C; const int64_t k; const int64_t lda; const int64_t ldb; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp index a2924757..31478dd8 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp @@ -41,13 +41,15 @@ static void ggml_compute_forward_dup_same_cont( } } -static void ggml_compute_forward_dup_f16( +template +static void ggml_compute_forward_dup_flt( const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type)); GGML_TENSOR_UNARY_OP_LOCALS @@ -62,6 +64,7 @@ static void ggml_compute_forward_dup_f16( const int ir0 = dr * ith; const int ir1 = MIN(ir0 + dr, nr); + // case: type & row size equal if (src0->type == dst->type && ne00 == ne0 && nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { @@ -80,11 +83,11 @@ static void ggml_compute_forward_dup_f16( return; } - // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy - + // case: dst tensor is contiguous if (ggml_is_contiguous(dst)) { - if (nb00 == sizeof(ggml_fp16_t)) { - if (dst->type == GGML_TYPE_F16) { + if (nb00 == sizeof(src_t)) { + if constexpr (std::is_same_v) { + // same type size_t id = 0; const size_t rs = ne00 * nb00; char * dst_ptr = (char *) dst->data; @@ -100,91 +103,46 @@ static void ggml_compute_forward_dup_f16( id += rs * (ne01 - ir1); } } - } else if (dst->type == GGML_TYPE_F32) { + } else { + // casting between non-quantized types size_t id = 0; - float * dst_ptr = (float *) dst->data; + dst_t * dst_ptr = (dst_t *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { for (int i02 = 0; i02 < ne02; i02++) { id += ne00 * ir0; for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]); + float tmp = type_conversion_table::to_f32(src0_ptr[i00]); + dst_ptr[id] = type_conversion_table::from_f32(tmp); id++; } } id += ne00 * (ne01 - ir1); } } - } else if (ggml_get_type_traits_cpu(dst->type)->from_float) { - ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float; - float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - for (int i00 = 0; i00 < ne00; i00++) { - src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]); - } - - quantize_row_q(src0_f32, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement } } else { //printf("%s: this is not optimal - fix me\n", __func__); - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; + size_t id = 0; + dst_t * dst_ptr = (dst_t *) dst->data; - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr); - id++; - } + float tmp = type_conversion_table::to_f32(*src0_ptr); + dst_ptr[id] = type_conversion_table::from_f32(tmp); + id++; } - id += ne00 * (ne01 - ir1); } + id += ne00 * (ne01 - ir1); } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement } } return; @@ -196,7 +154,7 @@ static void ggml_compute_forward_dup_f16( int64_t i12 = 0; int64_t i13 = 0; - if (dst->type == GGML_TYPE_F16) { + if constexpr (std::is_same_v) { for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { i10 += ne00 * ir0; @@ -217,7 +175,7 @@ static void ggml_compute_forward_dup_f16( const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t)); + memcpy(dst_ptr, src0_ptr, sizeof(dst_t)); if (++i10 == ne00) { i10 = 0; @@ -248,7 +206,8 @@ static void ggml_compute_forward_dup_f16( } } } - } else if (dst->type == GGML_TYPE_F32) { + + } else { for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { i10 += ne00 * ir0; @@ -269,7 +228,8 @@ static void ggml_compute_forward_dup_f16( const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr); + float tmp = type_conversion_table::to_f32(*(const src_t *) src0_ptr); + *(dst_t *) dst_ptr = type_conversion_table::from_f32(tmp); if (++i10 == ne0) { i10 = 0; @@ -300,18 +260,19 @@ static void ggml_compute_forward_dup_f16( } } } - } else { - GGML_ABORT("fatal error"); // TODO: implement } } -static void ggml_compute_forward_dup_bf16( + +template +static void ggml_compute_forward_dup_to_q( const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(!ggml_is_quantized(src0->type)); GGML_TENSOR_UNARY_OP_LOCALS @@ -326,629 +287,36 @@ static void ggml_compute_forward_dup_bf16( const int ir0 = dr * ith; const int ir1 = MIN(ir0 + dr, nr); - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } + if (ggml_is_contiguous(dst) && + nb00 == sizeof(src_t) && + ggml_get_type_traits_cpu(dst->type)->from_float) { + // casting non-quantized types --> intermediate f32 --> quantized + ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; - if (ggml_is_contiguous(dst)) { - if (nb00 == sizeof(ggml_bf16_t)) { - if (dst->type == GGML_TYPE_BF16) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00])); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (ggml_get_type_traits_cpu(dst->type)->from_float) { - ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float; - float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - for (int i00 = 0; i00 < ne00; i00++) { - src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]); - } - - quantize_row_q(src0_f32, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_BF16) { - size_t id = 0; - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr)); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } - return; - } - - // dst counters - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == GGML_TYPE_BF16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t)); - - if (++i10 == ne00) { - i10 = 0; - if (++i11 == ne01) { - i11 = 0; - if (++i12 == ne02) { - i12 = 0; - if (++i13 == ne03) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr)); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = type_conversion_table::to_f32(src0_ptr[i00]); } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += rs; } + id += rs * (ne01 - ir1); } } } else { - GGML_ABORT("fatal error"); // TODO: implement - } -} - -static void ggml_compute_forward_dup_f32( - const ggml_compute_params * params, - ggml_tensor * dst) { - - const ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - - GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - if (ggml_is_contiguous(dst)) { - // TODO: simplify - if (nb00 == sizeof(float)) { - if (ggml_get_type_traits_cpu(dst->type)->from_float) { - ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - from_float(src0_ptr, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_BF16) { - size_t id = 0; - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } - - return; - } - - // dst counters - - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(float)); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_BF16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement + // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type)); + GGML_ABORT("not implemented"); } } @@ -1102,7 +470,7 @@ static void ggml_compute_forward_dup_bytes( } } -static void ggml_compute_forward_dup_q( +static void ggml_compute_forward_dup_from_q( const ggml_compute_params * params, ggml_tensor * dst) { @@ -1167,20 +535,35 @@ void ggml_compute_forward_dup( switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_dup_f16(params, dst); + /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt(params, dst); + else ggml_compute_forward_dup_to_q(params, dst); } break; case GGML_TYPE_BF16: { - ggml_compute_forward_dup_bf16(params, dst); + /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt(params, dst); + else ggml_compute_forward_dup_to_q(params, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_dup_f32(params, dst); + /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_I32) ggml_compute_forward_dup_flt(params, dst); + else ggml_compute_forward_dup_to_q(params, dst); + } break; + case GGML_TYPE_I32: + { + if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt(params, dst); + else GGML_ABORT("not implemented"); } break; default: { if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) { - ggml_compute_forward_dup_q(params, dst); + ggml_compute_forward_dup_from_q(params, dst); break; } GGML_ABORT("fatal error"); @@ -4084,31 +3467,27 @@ static void ggml_compute_forward_norm_f32( GGML_ASSERT(eps >= 0.0f); - // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)x[i00]; - } - + float sum = 0.0; + ggml_vec_sum_f32(ne00, &sum, x); float mean = sum/ne00; float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + float variance = 0; - ggml_float sum2 = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sum2 += (ggml_float)(v*v); - } +#ifdef GGML_USE_ACCELERATE + mean = -mean; + vDSP_vsadd(x, 1, &mean, y, 1, ne00); + vDSP_measqv(y, 1, &variance, ne00); +#else + variance = ggml_vec_cvar_f32(ne00, y, x, mean); +#endif //GGML_USE_ACCELERATE - float variance = sum2/ne00; const float scale = 1.0f/sqrtf(variance + eps); - ggml_vec_scale_f32(ne00, y, scale); } } @@ -5356,6 +4735,7 @@ void ggml_compute_forward_get_rows( //} } +template static void ggml_compute_forward_set_rows_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -5394,7 +4774,7 @@ static void ggml_compute_forward_set_rows_f32( const int64_t i11 = i02%ne11; const int64_t i10 = i; - const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); GGML_ASSERT(i1 >= 0 && i1 < ne1); @@ -5411,11 +4791,18 @@ void ggml_compute_forward_set_rows( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_set_rows_f32(params, dst); + if (src1->type == GGML_TYPE_I64) { + ggml_compute_forward_set_rows_f32(params, dst); + } else if (src1->type == GGML_TYPE_I32) { + ggml_compute_forward_set_rows_f32(params, dst); + } else { + GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type)); + } } break; default: { @@ -7027,6 +6414,209 @@ void ggml_compute_forward_im2col_back_f32( } } + +// ggml_compute_forward_im2col_3d_f16 +// src0: kernel [OC*IC, KD, KH, KW] +// src1: image [N*IC, ID, IH, IW] +// dst: result [N*OD, OH, OW, IC * KD * KH * KW] +static void ggml_compute_forward_im2col_3d_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + GGML_UNUSED(OC); + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t KH_KW = KH*KW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + + GGML_ASSERT(nb10 == sizeof(float)); + + // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iod = 0; iod < OD; iod++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW] + const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW] + + for (int64_t ikd = 0; ikd < KD; ikd++) { + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + const int64_t iid = iod*s2 + ikd*d2 - p2; + + if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; + } else { + const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s); + } + } + } + } + } + } + } + } + } + } +} + +// ggml_compute_forward_im2col_3d_f32 +// src0: kernel [OC*IC, KD, KH, KW] +// src1: image [N*IC, ID, IH, IW] +// dst: result [N*OD, OH, OW, IC * KD * KH * KW] +static void ggml_compute_forward_im2col_3d_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + GGML_UNUSED(OC); + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t KH_KW = KH*KW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + + GGML_ASSERT(nb10 == sizeof(float)); + + // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iod = 0; iod < OD; iod++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW] + const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW] + + for (int64_t ikd = 0; ikd < KD; ikd++) { + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + const int64_t iid = iod*s2 + ikd*d2 - p2; + + if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; + } else { + const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s; + } + } + } + } + } + } + } + } + } + } +} + + +void ggml_compute_forward_im2col_3d( + const ggml_compute_params * params, + ggml_tensor * dst) { + switch (dst->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_im2col_3d_f16(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_im2col_3d_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k, void * a, void * b, float * c) { const ggml_type_traits * traits = ggml_get_type_traits(type); @@ -7207,6 +6797,148 @@ void ggml_compute_forward_conv_2d( ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type); } +// ggml_compute_forward_conv_3d + +static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params, + const ggml_tensor * kernel, + const ggml_tensor * src, + ggml_tensor * dst, + ggml_type kernel_type) { + + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32); + GGML_ASSERT(kernel->type == kernel_type); + + const ggml_type_traits * traits = ggml_get_type_traits(kernel_type); + + const int32_t s0 = dst->op_params[0]; + const int32_t s1 = dst->op_params[1]; + const int32_t s2 = dst->op_params[2]; + const int32_t p0 = dst->op_params[3]; + const int32_t p1 = dst->op_params[4]; + const int32_t p2 = dst->op_params[5]; + const int32_t d0 = dst->op_params[6]; + const int32_t d1 = dst->op_params[7]; + const int32_t d2 = dst->op_params[8]; + const int32_t c = dst->op_params[9]; + const int32_t n = dst->op_params[10]; + const int32_t oc = dst->op_params[11]; + + const int64_t src_w = src->ne[0]; + const int64_t src_h = src->ne[1]; + const int64_t src_d = src->ne[2]; + const int64_t knl_w = kernel->ne[0]; + const int64_t knl_h = kernel->ne[1]; + const int64_t knl_d = kernel->ne[2]; + const int64_t dst_w = dst->ne[0]; + const int64_t dst_h = dst->ne[1]; + const int64_t dst_d = dst->ne[2]; + + const float * src_data = (float *) src->data; + void * knl_data = kernel->data; + float * dst_data = (float *) dst->data; + + const int64_t knl_n_per_channel = knl_w * knl_h * knl_d; + const int64_t knl_n_total = knl_n_per_channel * c; + const int64_t patch_total = n * dst_w * dst_h * dst_d; + + const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float); + const int64_t batch_size = params->wsize / space_per_patch; + const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size; + const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch; + + GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1); + + void * tmp = params->wdata; + + for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) { + const int64_t patch_start_batch = batch_i * patches_per_batch; + const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total); + const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch; + + const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth; + const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread; + const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch); + + for (int64_t p = patch_start; p < patch_end; ++p) { + const int64_t p_in_batch = p % (dst_w * dst_h * dst_d); + const int64_t p_in_depth = p_in_batch % (dst_w * dst_h); + const int64_t batch_idx = p / (dst_w * dst_h * dst_d); + const int64_t dst_z = p_in_batch / (dst_w * dst_h); + const int64_t dst_y = p_in_depth / dst_w; + const int64_t dst_x = p_in_depth % dst_w; + + char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size; + + for (int64_t ic = 0; ic < c; ++ic) { + for (int64_t kz = 0; kz < knl_d; ++kz) { + for (int64_t ky = 0; ky < knl_h; ++ky) { + for (int64_t kx = 0; kx < knl_w; ++kx) { + const int64_t sz = dst_z * s2 + kz * d2 - p2; + const int64_t sy = dst_y * s1 + ky * d1 - p1; + const int64_t sx = dst_x * s0 + kx * d0 - p0; + + int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx; + + float src_val; + if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) { + src_val = 0.0f; + } else { + const int64_t cn_idx = batch_idx * c + ic; + const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]); + src_val = *src_ptr; + } + + char * element_ptr = dst_row + dst_idx * traits->type_size; + if (kernel_type == GGML_TYPE_F32) { + *(float *)element_ptr = src_val; + } else if (kernel_type == GGML_TYPE_F16) { + *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val); + } + } + } + } + } + } + + ggml_barrier(params->threadpool); + + float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size); + ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output); + + ggml_barrier(params->threadpool); + + const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth; + const int64_t permute_start = params->ith * permute_per_thread; + const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch); + + for (int64_t i = permute_start; i < permute_end; ++i) { + const int64_t p = patch_start_batch + i; + const int64_t p_in_batch = p % (dst_w * dst_h * dst_d); + const int64_t p_in_depth = p_in_batch % (dst_w * dst_h); + const int64_t batch_idx = p / (dst_w * dst_h * dst_d); + const int64_t dst_z = p_in_batch / (dst_w * dst_h); + const int64_t dst_y = p_in_depth / dst_w; + const int64_t dst_x = p_in_depth % dst_w; + + for (int64_t ioc = 0; ioc < oc; ++ioc) { + const float value = gemm_output[i * oc + ioc]; + const int64_t ocn_idx = batch_idx * oc + ioc; + float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]); + *dst_ptr = value; + } + } + } +} + +void ggml_compute_forward_conv_3d( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type); +} + // ggml_compute_forward_conv_transpose_2d void ggml_compute_forward_conv_transpose_2d( @@ -7872,6 +7604,15 @@ static void ggml_compute_forward_pad_f32( GGML_TENSOR_UNARY_OP_LOCALS float * dst_ptr = (float *) dst->data; + const int32_t lp0 = ggml_get_op_params_i32(dst, 0); + const int32_t rp0 = ggml_get_op_params_i32(dst, 1); + const int32_t lp1 = ggml_get_op_params_i32(dst, 2); + const int32_t rp1 = ggml_get_op_params_i32(dst, 3); + const int32_t lp2 = ggml_get_op_params_i32(dst, 4); + const int32_t rp2 = ggml_get_op_params_i32(dst, 5); + const int32_t lp3 = ggml_get_op_params_i32(dst, 6); + const int32_t rp3 = ggml_get_op_params_i32(dst, 7); + // TODO: optimize @@ -7880,10 +7621,12 @@ static void ggml_compute_forward_pad_f32( for (int64_t i0 = 0; i0 < ne0; ++i0) { for (int64_t i3 = 0; i3 < ne3; ++i3) { const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; - - const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + if ((i0 >= lp0 && i0 < ne0 - rp0) \ + && (i1 >= lp1 && i1 < ne1 - rp1) \ + && (i2 >= lp2 && i2 < ne2 - rp2) \ + && (i3 >= lp3 && i3 < ne3 - rp3)) { + const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00; + const float * src_ptr = (const float *)((char *) src0->data + src_idx); dst_ptr[dst_idx] = *src_ptr; } else { dst_ptr[dst_idx] = 0; @@ -8082,7 +7825,7 @@ static void ggml_compute_forward_timestep_embedding_f32( embed_data[j + half] = sinf(arg); } if (dim % 2 != 0 && ith == 0) { - embed_data[dim] = 0.f; + embed_data[2 * half] = 0.f; } } } @@ -8431,7 +8174,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( } // V /= S - const float S_inv = 1.0f/S; + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; ggml_vec_scale_f32(DV, VKQ32, S_inv); // dst indices @@ -8904,8 +8647,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); - // allows optimizing the modulo since n_group should be a power of 2 - GGML_ASSERT((ng & -ng) == ng); + GGML_ASSERT(nh % ng == 0); // heads per thread const int dh = (nh + nth - 1)/nth; @@ -8934,8 +8676,9 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dt_soft_plus = ggml_softplus(dt[h]); const float dA = expf(dt_soft_plus * A[h]); + const int g = h / (nh / ng); // repeat_interleave // dim for (int i1 = 0; i1 < nr; ++i1) { @@ -8958,8 +8701,8 @@ static void ggml_compute_forward_ssm_scan_f32( // TODO: maybe unroll more? for (int j = 0; j < 1; j++) { GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc); - GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc); - GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc); + GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc); t0 = GGML_F32_VEC_MUL(t0, adA); t1 = GGML_F32_VEC_MUL(t1, axdt); @@ -8973,6 +8716,9 @@ static void ggml_compute_forward_ssm_scan_f32( } sumf = GGML_F32xt_REDUCE_ONE(sum); + #elif defined(__riscv_v_intrinsic) + // todo: RVV implementation + const int np = 0; #else const int np = (nc & ~(GGML_F32_STEP - 1)); @@ -8988,8 +8734,8 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i = 0; i < np; i += GGML_F32_STEP) { for (int j = 0; j < GGML_F32_ARR; j++) { ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); - ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); - az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc); ax[j] = GGML_F32_VEC_MUL(ax[j], adA); ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); @@ -9011,7 +8757,7 @@ static void ggml_compute_forward_ssm_scan_f32( // d_state for (int i0 = np; i0 < nc; ++i0) { const int i = i0 + ii*nc; - const int ig = i0 + (h & (ng - 1))*nc; + const int ig = i0 + g*nc; // state = prev_state * dA + dB * x const float state = (s0[i] * dA) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) @@ -9027,7 +8773,8 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dt_soft_plus = ggml_softplus(dt[h]); + const int g = h / (nh / ng); // repeat_interleave // dim for (int i1 = 0; i1 < nr; ++i1) { @@ -9042,8 +8789,8 @@ static void ggml_compute_forward_ssm_scan_f32( // TODO: what happens when (d_state % svcntw()) != 0? for (int64_t k = 0; k < nc; k += svcntw()) { svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]); - svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]); - svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]); + svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]); + svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]); svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]); svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); @@ -9063,7 +8810,7 @@ static void ggml_compute_forward_ssm_scan_f32( // d_state for (int i0 = 0; i0 < nc; ++i0) { const int i = i0 + ii*nc; - const int ig = i0 + (h & (ng - 1))*nc; + const int ig = i0 + g*nc; // state = prev_state * dA + dB * x const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) @@ -9289,6 +9036,10 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; + case GGML_UNARY_OP_XIELU: + { + ggml_compute_forward_xielu(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -9924,8 +9675,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_stride_2d = head_size * head_size; #if defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - // scalar Route to scalar implementation //TODO: Write SVE code + #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic) + // scalar Route to scalar implementation //TODO: Write SVE code and RVV code for (int64_t t = 0; t < T; t++) { int64_t t_offset = t * t_stride; int64_t state_offset = head_size * C * (t / (T / n_seqs)); @@ -10373,6 +10124,7 @@ static void ggml_compute_forward_opt_step_adamw_f32( const int ir1 = MIN(ir0 + dr, nr); const float * adamw_params_ptr = ggml_get_data_f32(adamw_params); + const float alpha = adamw_params_ptr[0]; const float beta1 = adamw_params_ptr[1]; const float beta2 = adamw_params_ptr[2]; @@ -10380,7 +10132,7 @@ static void ggml_compute_forward_opt_step_adamw_f32( const float wd = adamw_params_ptr[4]; const float beta1h = adamw_params_ptr[5]; const float beta2h = adamw_params_ptr[6]; - + const float keep = 1.f - alpha * wd; for (int ir = ir0; ir < ir1; ++ir) { const int64_t i03 = ir/(ne02*ne01); const int64_t i02 = (ir - i03*ne02*ne01)/ne01; @@ -10403,7 +10155,7 @@ static void ggml_compute_forward_opt_step_adamw_f32( // The weight decay is applied independently of the Adam momenta m and v. // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss. // See: https://arxiv.org/pdf/1711.05101v3.pdf - w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh; + w[i00] = w[i00] * keep - alpha * mh / vh; } } } @@ -10425,3 +10177,63 @@ void ggml_compute_forward_opt_step_adamw( } } } + +static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0_grad = dst->src[1]; + const ggml_tensor * sgd_params = dst->src[2]; + + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); + GGML_ASSERT(ggml_nelements(sgd_params) == 2); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1) / nth; + + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + // using adamw param subset we care about - alpha, wd - could have a separate struct + const float * sgd_params_ptr = ggml_get_data_f32(sgd_params); + const float alpha = sgd_params_ptr[0]; + const float keep = 1.f - alpha * sgd_params_ptr[1]; + + for (int ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01; + + float * w = (float *) ((char *) src0->data + offset); // weight + const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad + + for (int i00 = 0; i00 < ne00; ++i00) { + w[i00] = w[i00] * keep - alpha * g[i00]; + } + } +} + +void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_opt_step_sgd_f32(params, dst); + } + break; + default: + { + GGML_ABORT("fatal error - sgd is F32 only"); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ops.h b/ml/backend/ggml/ggml/src/ggml-cpu/ops.h index f154afb4..9824a03b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ops.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ops.h @@ -69,7 +69,9 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -107,7 +109,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); - +void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } #endif diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp index 2583aefa..f531d21e 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp @@ -206,8 +206,9 @@ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const int ncols_interleaved = 4; const int blocklen = 4; - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); UNUSED(s); UNUSED(bs); @@ -307,30 +308,28 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, UNUSED(ncols_interleaved); UNUSED(blocklen); - { - float sumf[8]; - int sumi; + float sumf[8]; + int sumi; - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); } } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } } @@ -494,43 +493,73 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs const int ncols_interleaved = 4; const int blocklen = 4; - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - { - float sumf[4]; - int sumi; + float sumf[4]; + int sumi; - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); } } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } } @@ -934,6 +963,50 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + } // extern "C" static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { @@ -1285,15 +1358,16 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); - //GGML_ASSERT(interleave_block == 4 || interleave_block == 8); GGML_ASSERT(interleave_block == 4); - block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data; - const block_iq4_nl * src = (const block_iq4_nl *)data; + const block_iq4_nl * src = (const block_iq4_nl *)data; + block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data; + block_iq4_nl dst_tmp[4]; + int nrow = ggml_nrows(t); int nrows_interleaved = 4; - int nblocks = t->ne[0] / QK4_0; + int nblocks = t->ne[0] / QK4_NL; GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); @@ -1315,6 +1389,63 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } +static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_size_interleave) { + block_iq4_nlx8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_NL * 4 / blck_size_interleave; + + if (blck_size_interleave == 8) { + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); + GGML_ASSERT(interleave_block == 8); + + const block_iq4_nl * src = (const block_iq4_nl *)data; + block_iq4_nlx8 * dst = ( block_iq4_nlx8 *)t->data; + + block_iq4_nl dst_tmp[8]; + + int nrow = ggml_nrows(t); + int nrows_interleaved = 8; + int nblocks = t->ne[0] / QK4_NL; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + namespace ggml::cpu::repack { // repack template @@ -1350,6 +1481,10 @@ template <> int repack(struct ggml_tensor * t, const void * // return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); //} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); +} + // gemv template void gemv(int, float *, size_t, const void *, const void *, int, int); @@ -1378,6 +1513,10 @@ template <> void gemv(int n, float * s, size ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + // gemm template void gemm(int, float *, size_t, const void *, const void *, int, int); @@ -1406,6 +1545,10 @@ template <> void gemm(int n, float * s, size ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + class tensor_traits_base : public ggml::cpu::tensor_traits { public: virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; @@ -1680,6 +1823,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons // instance for IQ4 static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0; if (cur->type == GGML_TYPE_Q4_0) { if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { @@ -1710,6 +1854,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } } } else if (cur->type == GGML_TYPE_IQ4_NL) { + if (ggml_cpu_has_avx2()) { + if (cur->ne[1] % 8 == 0) { + return &iq4_nl_8x8_q8_0; + } + } if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { if (cur->ne[1] % 4 == 0) { return &iq4_nl_4x4_q8_0; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/repack.h b/ml/backend/ggml/ggml/src/ggml-cpu/repack.h index cd322e74..cb32b503 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/repack.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/repack.h @@ -67,6 +67,13 @@ struct block_iq4_nlx4 { static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding"); +struct block_iq4_nlx8 { + ggml_half d[8]; // deltas for 8 iq4_nl blocks + uint8_t qs[QK4_NL * 4]; // nibbles / quants for 8 iq4_nl blocks +}; + +static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding"); + #if defined(__cplusplus) extern "C" { #endif @@ -80,12 +87,14 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); // Native implementations void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); @@ -97,12 +106,14 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined(__cplusplus) } // extern "C" diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h b/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h index b4ad68c9..8daec663 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h @@ -18,6 +18,10 @@ #include #endif +#if defined(__riscv_v_intrinsic) +#include +#endif + #ifdef __cplusplus extern "C" { #endif @@ -94,24 +98,15 @@ extern "C" { } #elif defined(__riscv) && defined(__riscv_zfhmin) static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) { - float f; - __asm__( - "fmv.h.x %[f], %[h]\n\t" - "fcvt.s.h %[f], %[f]" - : [f] "=&f" (f) - : [h] "r" (h) - ); - return f; + _Float16 hf; + memcpy(&hf, &h, sizeof(ggml_fp16_t)); + return hf; } static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) { ggml_fp16_t res; - __asm__( - "fcvt.h.s %[f], %[f]\n\t" - "fmv.x.h %[h], %[f]" - : [h] "=&r" (res) - : [f] "f" (f) - ); + _Float16 hf = (_Float16)f; + memcpy(&res, &hf, sizeof(ggml_fp16_t)); return res; } @@ -119,26 +114,6 @@ extern "C" { #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) riscv_compute_fp32_to_fp16(x) #define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x) #define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x) -#elif defined(__NNPA__) - #define GGML_CPU_COMPUTE_FP16_TO_FP32(x) nnpa_compute_fp16_to_fp32(x) - #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) nnpa_compute_fp32_to_fp16(x) - - #define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x) - #define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x) - - static inline float nnpa_compute_fp16_to_fp32(ggml_fp16_t h) { - uint16x8_t v_h = vec_splats(h); - uint16x8_t v_hd = vec_convert_from_fp16(v_h, 0); - return vec_extend_to_fp32_hi(v_hd, 0)[0]; - } - - static inline ggml_fp16_t nnpa_compute_fp32_to_fp16(float f) { - float32x4_t v_f = vec_splats(f); - float32x4_t v_zero = vec_splats(0.0f); - uint16x8_t v_hd = vec_round_from_fp32(v_f, v_zero, 0); - uint16x8_t v_h = vec_convert_to_fp16(v_hd, 0); - return vec_extract(v_h, 0); - } #endif // precomputed f32 table for f16 (256 KB) @@ -220,6 +195,47 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_F32_VEC_MUL GGML_F32xt_MUL #define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE +// F16 SVE +#define DEFAULT_PG32 svptrue_b32() +#define DEFAULT_PG16 svptrue_b16() + +#define GGML_F32Cxt svfloat16_t +#define GGML_F32Cxt_ZERO svdup_n_f16(0.0f) +#define GGML_F32Cxt_SET1(x) svdup_n_f16(x) +#define GGML_F32Cxt_LOAD(p) svld1_f16(DEFAULT_PG16, (const __fp16 *)(p)) +#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec)) + +#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a) +#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b) +#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b) +#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED + +#define GGML_F16x_VEC GGML_F32Cxt +#define GGML_F16x_VEC_ZERO GGML_F32Cxt_ZERO +#define GGML_F16x_VEC_SET1 GGML_F32Cxt_SET1 +#define GGML_F16x_VEC_LOAD(p, i) GGML_F32Cxt_LOAD(p) +#define GGML_F16x_VEC_STORE(p, r, i) GGML_F32Cxt_STORE((__fp16 *)(p), r) +#define GGML_F16x_VEC_FMA GGML_F32Cxt_FMA +#define GGML_F16x_VEC_ADD GGML_F32Cxt_ADD +#define GGML_F16x_VEC_MUL GGML_F32Cxt_MUL +#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE + +#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a) +#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__) + +#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \ +{ \ + sum1 = svadd_f16_x(pg16, sum1, sum2); \ + sum3 = svadd_f16_x(pg16, sum3, sum4); \ + sum1 = svadd_f16_x(pg16, sum1, sum3); \ + __fp16 sum_f16 = svaddv_f16(pg16, sum1); \ + (res) = (ggml_float) sum_f16; \ +} +#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__) + // F16 NEON #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) @@ -982,9 +998,9 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { #define GGML_F32_EPR 4 #define GGML_F32x4 __m128 -#define GGML_F32x4_ZERO __lsx_vldi(0) -#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) -#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0) +#define GGML_F32x4_ZERO (__m128)__lsx_vldi(0) +#define GGML_F32x4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) +#define GGML_F32x4_LOAD(x) (__m128)__lsx_vld((x), 0) #define GGML_F32x4_STORE(x, y) __lsx_vst(y, x, 0) #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a) #define GGML_F32x4_ADD __lsx_vfadd_s @@ -1006,7 +1022,7 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { __m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \ tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ - const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \ + const __m128 t0 = (__m128)__lsx_vshuf4i_w(tmp, 0x88); \ tmp = __lsx_vsrli_d((__m128i) t0, 32); \ tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ @@ -1036,7 +1052,7 @@ static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) { tmp[2] = GGML_CPU_FP16_TO_FP32(x[2]); tmp[3] = GGML_CPU_FP16_TO_FP32(x[3]); - return __lsx_vld(tmp, 0); + return (__m128)__lsx_vld(tmp, 0); } static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { @@ -1051,9 +1067,9 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { } #define GGML_F32Cx4 __m128 -#define GGML_F32Cx4_ZERO __lsx_vldi(0) -#define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) -#define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x) +#define GGML_F32Cx4_ZERO (__m128)__lsx_vldi(0) +#define GGML_F32Cx4_SET1(x) (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) +#define GGML_F32Cx4_LOAD(x) (__m128)__lsx_f16x4_load(x) #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y) #define GGML_F32Cx4_FMA GGML_F32x4_FMA #define GGML_F32Cx4_ADD __lsx_vfadd_s @@ -1120,11 +1136,6 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { #define GGML_F16_EPR GGML_F32_EPR static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) { -#if defined(__NNPA__) - uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)x); - uint16x8_t v_xd = vec_convert_from_fp16(v_x, 0); - return vec_extend_to_fp32_hi(v_xd, 0); -#else float tmp[4]; for (int i = 0; i < 4; i++) { @@ -1134,20 +1145,9 @@ static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) { // note: keep type-cast here to prevent compiler bugs // see: https://github.com/ggml-org/llama.cpp/issues/12846 return vec_xl(0, (const float *)(tmp)); -#endif } static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { -#if defined(__NNPA__) - float32x4_t v_zero = vec_splats(0.0f); - uint16x8_t v_xd = vec_round_from_fp32(v_y, v_zero, 0); - uint16x8_t v_x = vec_convert_to_fp16(v_xd, 0); - - x[0] = vec_extract(v_x, 0); - x[1] = vec_extract(v_x, 1); - x[2] = vec_extract(v_x, 2); - x[3] = vec_extract(v_x, 3); -#else float arr[4]; // note: keep type-cast here to prevent compiler bugs @@ -1157,7 +1157,6 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { for (int i = 0; i < 4; i++) { x[i] = GGML_CPU_FP32_TO_FP16(arr[i]); } -#endif } #define GGML_F16_VEC GGML_F32x4 @@ -1170,6 +1169,36 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { #define GGML_F16_VEC_MUL GGML_F32x4_MUL #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +#elif defined(__riscv_v_intrinsic) + +// compatible with vlen >= 128 + +#define GGML_SIMD + +// F32 + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 vfloat32m1_t +#define GGML_F32x4_ZERO __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR) +#define GGML_F32x4_SET1(x) __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR) +#define GGML_F32x4_LOAD(x) __riscv_vle32_v_f32m1(x, GGML_F32_EPR) +#define GGML_F32x4_STORE(b, v) __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR) +#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR) +#define GGML_F32x4_ADD(a, b) __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR) +#define GGML_F32x4_MUL(a, b) __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR) + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + #endif // GGML_F32_ARR / GGML_F16_ARR diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp index 4fce569b..cf1a4615 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp @@ -52,6 +52,15 @@ static inline float op_sqrt(float x) { return sqrtf(x); } +static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) { + if (x > 0.0f) { + return alpha_p * x * x + beta * x; + } else { + const float min_x_eps = fminf(x, eps); + return (expm1f(min_x_eps) - x) * alpha_n + beta * x; + } +} + static inline float op_sin(float x) { return sinf(x); } @@ -121,6 +130,86 @@ static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) { } } +template +static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + apply_unary_op(params, dst); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + } +} + +// Extend vec_unary_op to support functors +template +static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) { + constexpr auto src0_to_f32 = type_conversion_table::to_f32; + constexpr auto f32_to_dst = type_conversion_table::from_f32; + + for (int i = 0; i < n; i++) { + y[i] = f32_to_dst(op(src0_to_f32(x[i]))); + } +} + +// Extend apply_unary_op to support functors +template +static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(dst_t)); + GGML_ASSERT(nb00 == sizeof(src0_t)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op); + } +} + +// Generic dispatcher for functors +template +static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) { + const ggml_tensor * src0 = dst->src[0]; + + /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + apply_unary_op_functor(params, dst, op); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + } +} + void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) { unary_op(params, dst); } @@ -184,3 +273,17 @@ void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) { unary_op(params, dst); } + +void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + + const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) { + return op_xielu(f, alpha_n, alpha_p, beta, eps); + }; + + unary_op_functor(params, dst, xielu_op_params); +} + diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h index b1ade2c8..697c1e0d 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h @@ -22,6 +22,7 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp index 07b377bd..43dc7537 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp @@ -84,6 +84,22 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G } // reduce sum1,sum2 to sum1 GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8); + #elif defined(__riscv_v_intrinsic) + int vl = __riscv_vsetvlmax_e32m8(); + vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1); + vfloat32m8_t vsum; + vfloat32m8_t ax; + vfloat32m8_t ay; + vsum = __riscv_vfmv_v_f_f32m8_tu(vsum, 0.0f, vl); + for (int i = 0; i < n; i += vl) { + vl = __riscv_vsetvl_e32m8(n - i); + ax = __riscv_vle32_v_f32m8_tu(ax, &x[i], vl); + ay = __riscv_vle32_v_f32m8_tu(ay, &y[i], vl); + vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, ax, ay, vl); + } + vl = __riscv_vsetvlmax_e32m8(); + vs = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vs, vl); + sumf += __riscv_vfmv_f_s_f32m1_f32(vs); #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -197,38 +213,125 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G ggml_float sumf = 0.0; + #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; //get vector length + const int ggml_f16_epr = sve_register_length / 16; // running when 16 + const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; + const int np= (n & ~(ggml_f16_step - 1)); + svfloat16_t sum1 = svdup_n_f16(0.0f); + svfloat16_t sum2 = svdup_n_f16(0.0f); + svfloat16_t sum3 = svdup_n_f16(0.0f); + svfloat16_t sum4 = svdup_n_f16(0.0f); - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; + svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f16_step) { + ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); + sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1); - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); + sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2); - sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3); + + ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); + sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4); + + ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); + sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5); + + ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); + sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6); + + ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7); + + ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); + sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8); } - } - // reduce sum0..sum3 to sum0 - GGML_F16_VEC_REDUCE(sumf, sum); + const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8 + for (int k = np; k < np2; k += ggml_f16_epr) { + svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); + svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry); + } - // leftovers - for (int i = np; i < n; ++i) { - sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); - } + if (np2 < n) { + svbool_t pg = svwhilelt_b16(np2, n); + svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); + svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); - // if you hit this, you are likely running outside the FP range - assert(!isnan(sumf) && !isinf(sumf)); + sum1 = svmad_f16_x(pg, hx, hy, sum1); + } + GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4); + #elif defined(__riscv_v_intrinsic) + #if defined(__riscv_zvfh) + int vl = __riscv_vsetvlmax_e32m2(); + vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1); + vfloat32m2_t vsum; + vfloat16m1_t ax; + vfloat16m1_t ay; + vsum = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vmv_v_x_u32m2(0, vl)); + for (int i = 0; i < n; i += vl) { + vl = __riscv_vsetvl_e16m1(n - i); + ax = __riscv_vle16_v_f16m1_tu(ax, (const _Float16 *)&x[i], vl); + ay = __riscv_vle16_v_f16m1_tu(ay, (const _Float16 *)&y[i], vl); + vsum = __riscv_vfwmacc_vv_f32m2_tu(vsum, ax, ay, vl); + } + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t ac0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum, 0), __riscv_vget_v_f32m2_f32m1(vsum, 1), vl); + vs = __riscv_vfredusum_vs_f32m1_f32m1(ac0, vs, vl); + sumf += __riscv_vfmv_f_s_f32m1_f32(vs); + #else + for (int i = 0; i < n; ++i) { + sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); + } + #endif // __riscv_zvfh + #else + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F16_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); + } + // if you hit this, you are likely running outside the FP range + assert(!isnan(sumf) && !isinf(sumf)); + #endif #else for (int i = 0; i < n; ++i) { sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); } -#endif +#endif // GGML_SIMD *s = sumf; } @@ -247,6 +350,12 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) { for (; i + 3 < n; i += 4) { _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i))); } +#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + const int vlen = svcntw(); + for (; i < n; i += vlen) { + const svbool_t pg = svwhilelt_b32_s32(i, n); + svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i))); + } #elif defined(__ARM_NEON) && defined(__aarch64__) for (; i + 3 < n; i += 4) { vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i))); @@ -271,16 +380,96 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * for (; i + 3 < n; i += 4) { _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i))); } +#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + const int vlen = svcntw(); + for (; i < n; i += vlen) { + const svbool_t pg = svwhilelt_b32_s32(i, n); + svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i))); + } #elif defined(__ARM_NEON) && defined(__aarch64__) for (; i + 3 < n; i += 4) { vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i))); } +#elif defined(__riscv_v_intrinsic) + for (int vl; i < n; i += vl) { + vl = __riscv_vsetvl_e32m2(n - i); + vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl); + vfloat32m2_t vg = __riscv_vle32_v_f32m2(&g[i], vl); + vfloat32m2_t vy = __riscv_vfmul_vv_f32m2(ggml_v_silu_m2(vx, vl), vg, vl); + __riscv_vse32_v_f32m2(&y[i], vy, vl); + } #endif for (; i < n; ++i) { y[i] = ggml_silu_f32(x[i]) * g[i]; } } +ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) { + int i = 0; + ggml_float sum = 0; +// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE +// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344 +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i), + _mm512_set1_ps(mean)); + _mm512_storeu_ps(y + i, val); + sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val)); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i), + _mm256_set1_ps(mean)); + _mm256_storeu_ps(y + i, val); + val = _mm256_mul_ps(val,val); + __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), + _mm256_castps256_ps128(val)); + val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); + val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); + sum += (ggml_float)_mm_cvtss_f32(val2); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i), + _mm_set1_ps(mean)); + _mm_storeu_ps(y + i, val); + val = _mm_mul_ps(val, val); +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + val = _mm_add_ps(val, _mm_movehl_ps(val, val)); + val = _mm_add_ss(val, _mm_movehdup_ps(val)); +#else + __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); + val = _mm_add_ps(val, tmp); + tmp = _mm_movehl_ps(tmp, val); + val = _mm_add_ss(val, tmp); +#endif // __AVX__ || __AVX2__ || __AVX512F__ + sum += (ggml_float)_mm_cvtss_f32(val); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + float32x4_t val = vsubq_f32(vld1q_f32(x + i), + vdupq_n_f32(mean)); + vst1q_f32(y + i, val); + val = vmulq_f32(val, val); + sum += (ggml_float)vaddvq_f32(val); + } +#elif defined(__VXE__) || defined(__VXE2__) + for (; i + 3 < n; i += 4) { + float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean)); + vec_xst(val, 0, y + i); + val = vec_mul(val, val); + sum += (ggml_float)vec_hsum_f32x4(val); + } +#endif + for (; i < n; ++i) { + float val = x[i] - mean; + y[i] = val; + val *= val; + sum += (ggml_float)val; + } + return sum/n; +} + ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; @@ -318,6 +507,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float #endif sum += (ggml_float)_mm_cvtss_f32(val); } +#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + const int vlen = svcntw(); + for (; i < n; i += vlen) { + const svbool_t pg = svwhilelt_b32_s32(i, n); + svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i), + svdup_n_f32_x(pg, max))); + svst1_f32(pg, y + i, val); + sum += (ggml_float)svaddv_f32(pg, val); + } #elif defined(__ARM_NEON) && defined(__aarch64__) for (; i + 3 < n; i += 4) { float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i), @@ -325,6 +523,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float vst1q_f32(y + i, val); sum += (ggml_float)vaddvq_f32(val); } +#elif defined(__riscv_v_intrinsic) + vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1); + for (int avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m2(n - i); + vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl); + __riscv_vse32_v_f32m2(&y[i], val, avl); + vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl); + } + return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum); #endif for (; i < n; ++i) { float val = expf(x[i] - max); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h index 2250d93c..65c7dfb6 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h @@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); void ggml_vec_silu_f32(const int n, float * y, const float * x); +ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean ) ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); @@ -119,36 +120,149 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG } #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) - GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; // running when 16 + const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; + const int np = (n & ~(ggml_f16_step - 1)); - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + svfloat16_t sum_00 = svdup_n_f16(0.0f); + svfloat16_t sum_01 = svdup_n_f16(0.0f); + svfloat16_t sum_02 = svdup_n_f16(0.0f); + svfloat16_t sum_03 = svdup_n_f16(0.0f); - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + svfloat16_t sum_10 = svdup_n_f16(0.0f); + svfloat16_t sum_11 = svdup_n_f16(0.0f); + svfloat16_t sum_12 = svdup_n_f16(0.0f); + svfloat16_t sum_13 = svdup_n_f16(0.0f); - sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + + for (int i = 0; i < np; i += ggml_f16_step) { + ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements + + ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements + sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1 + ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements + sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1); + + ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements + + ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements + sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2); + ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1); + sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2); + + ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + + ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2); + sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3); + ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); + sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3); + + ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); + + ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3); + sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4); + ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3); + sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4); + + ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); + + ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4); + + sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5); + ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4); + sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5); + + ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); + + ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5); + + sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6); + ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5); + sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6); + + ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + + ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6); + + sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7); + ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6); + sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7); + + ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); + + ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7); + + sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8); + ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7); + sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8); + } + + const int np2 = (n & ~(ggml_f16_epr - 1)); + for (int k = np; k < np2; k += ggml_f16_epr) { + svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + + svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0); + sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry); + rx = GGML_F16x_VEC_LOAD(x[1] + k, 0); + sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry); + } + + if (np2 < n) { + svbool_t pg = svwhilelt_b16(np2, n); + svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); + svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); + svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + + sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00); + sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10); + } + GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); + GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); + #elif defined(__riscv_v_intrinsic) + // todo: RVV impl + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); + } + } + #else + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + + sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } } } - } - // reduce sum0..sum3 to sum0 - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - GGML_F16_VEC_REDUCE(sumf[k], sum[k]); - } - - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); + // reduce sum0..sum3 to sum0 + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + GGML_F16_VEC_REDUCE(sumf[k], sum[k]); } - } + + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); + } + } + #endif #else for (int i = 0; i < n; ++i) { for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { @@ -243,6 +357,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const svst1_f32(pg, y + np2, ay1); } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -276,27 +398,112 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) { #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; + const int ggml_f16_step = 8 * ggml_f16_epr; - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; + const int np= (n & ~(ggml_f16_step - 1)); - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f16_step) { + ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); + + ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); + + GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); + + ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); + + GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); + + ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); + + GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); + + ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); + + GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4); + + ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); + + GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); + + ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); + + GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); + + ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); + + GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); } - } + const int np2 = (n & ~(ggml_f16_epr - 1)); + for (int k = np; k < np2; k += ggml_f16_epr) { + svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); + svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + ry = GGML_F16x_VEC_FMA(ry, rx, vx); - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); - } + GGML_F16x_VEC_STORE(y + k, ry, 0); + } + + if (np2 < n) { + svbool_t pg = svwhilelt_b16(np2, n); + svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); + svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + hy = svmad_f16_x(pg, hx, vx, hy); + svst1_f16(pg, (__fp16 *)(y + np2), hy); + } + + #elif defined(__riscv_v_intrinsic) + // todo: RVV impl + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); + } + #else + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); + } + #endif #else // scalar for (int i = 0; i < n; ++i) { @@ -324,6 +531,16 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int y[i] += x[k][i]*v[k][0]; } } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) { + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl); + ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl); + } + __riscv_vse32_v_f32m8(&y[i], ay, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -375,6 +592,14 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co for (int i = 0; i < n; ++i) { y[i] = x[i]*s + b; } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl); + vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -386,7 +611,7 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co for (int i = 0; i < np; i += GGML_F32_STEP) { for (int j = 0; j < GGML_F32_ARR; j++) { ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); + ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs); GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); } @@ -430,11 +655,18 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { } // leftovers // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only - if (np < n) { - svbool_t pg = svwhilelt_b32(np, n); - ay1 = svld1_f32(pg, y + np); + for (int i = np; i < n; i += ggml_f32_epr) { + svbool_t pg = svwhilelt_b32(i, n); + ay1 = svld1_f32(pg, y + i); ay1 = svmul_f32_m(pg, ay1, vx); - svst1_f32(pg, y + np, ay1); + svst1_f32(pg, y + i, ay1); + } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -467,25 +699,59 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; + const int ggml_f16_step = 2 * ggml_f16_epr; - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); + const int np = (n & ~(ggml_f16_step - 1)); + svfloat16_t ay1, ay2; - GGML_F16_VEC ay[GGML_F16_ARR]; + for (int i = 0; i < np; i += ggml_f16_step) { + ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_MUL(ay1, vx); + GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0); - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_MUL(ay[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_MUL(ay2, vx); + GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1); } - } + // leftovers + // maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only + if (np < n) { + svbool_t pg = svwhilelt_b16(np, n); + svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np)); + svfloat16_t out = svmul_f16_m(pg, hy, vx); + svst1_f16(pg, (__fp16 *)(y + np), out); + } + #elif defined(__riscv_v_intrinsic) + // todo: RVV impl + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); + } + #else + const int np = (n & ~(GGML_F16_STEP - 1)); - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); - } + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); + } + #endif #else // scalar for (int i = 0; i < n; ++i) { @@ -554,7 +820,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); } inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i]))); + const float v = GGML_CPU_FP16_TO_FP32(x[i]); + y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v)); } } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } @@ -737,7 +1004,39 @@ https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/sr } #endif -#if defined(__ARM_NEON) && defined(__aarch64__) +#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + +inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) { + const svfloat32_t r = svdup_n_f32_x(pg, 0x1.8p23f); + const svfloat32_t z = svmla_n_f32_x(pg, r, x, 0x1.715476p+0f); + const svfloat32_t n = svsub_f32_x(pg, z, r); + const svfloat32_t b = svmls_n_f32_x(pg, svmls_n_f32_x(pg, x, n, 0x1.62e4p-1f), n, 0x1.7f7d1cp-20f); + const svuint32_t e = svlsl_n_u32_x(pg, svreinterpret_u32_f32(z), 23); + const svfloat32_t k = svreinterpret_f32_u32(svadd_u32_x(pg, e, svreinterpret_u32_f32(svdup_n_f32_x(pg, 1)))); + const svbool_t c = svacgt_n_f32(pg, n, 126); + const svfloat32_t u = svmul_f32_x(pg, b, b); + const svfloat32_t j = svmla_f32_x(pg, + svmul_n_f32_x(pg, b, 0x1.ffffecp-1f), + svmla_f32_x(pg, svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.fffdb6p-2f), svdup_n_f32_x(pg, 0x1.555e66p-3f), b), + svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.573e2ep-5f), svdup_n_f32_x(pg, 0x1.0e4020p-7f), b), u), u); + const svuint32_t d = svdup_n_u32_z(svcmple_n_f32(pg, n, 0.0), 0x82000000); + const svfloat32_t s1 = svreinterpret_f32_u32(svadd_n_u32_x(pg, d, 0x7f000000)); + const svfloat32_t s2 = svreinterpret_f32_u32(svsub_u32_x(pg, e, d)); + return svsel_f32(svacgt_f32(pg, n, svdup_n_f32_x(pg, 192)), svmul_f32_x(pg, s1, s1), + svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) { + const svfloat32_t one = svdup_n_f32_x(pg, 1.0f); + const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f); + const svfloat32_t neg_x = svsub_f32_x(pg, zero, x); + const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x); + const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x); + return svdiv_f32_x(pg, x, one_plus_exp_neg_x); +} + +#elif defined(__ARM_NEON) && defined(__aarch64__) // adapted from arm limited optimized routine // the maximum error is 1.45358 plus 0.5 ulps @@ -928,7 +1227,59 @@ inline static __m128 ggml_v_silu(__m128 x) { return _mm_div_ps(x, one_plus_exp_neg_x); } -#endif // __ARM_NEON / __AVX2__ / __SSE2__ +#elif defined(__riscv_v_intrinsic) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) { + const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); +#ifdef __riscv_xtheadvector + // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4') + vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl); + z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl); +#else + const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); +#endif + const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); + const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), + 0x1.7f7d1cp-20f, n, vl); + const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); + const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f + const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); + const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); + const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( + __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl), + __riscv_vfmacc_vv_f32m2( + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl), + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl), + u, vl), u, vl); + if (!__riscv_vcpop_m_b16(c, vl)) + return __riscv_vfmacc_vv_f32m2(k, j, k, vl); + const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); + const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); + const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); + const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); + const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2( + __riscv_vfmacc_vv_f32m2(k, k, j, vl), + __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl), + c, vl); + return __riscv_vmerge_vvm_f32m2( + r1, __riscv_vfmul_vv_f32m2(s1, s1, vl), + __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl), + vl); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) { + const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl); + const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl); + const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl); + return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl); +} + +#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt index 98ed29bc..30247751 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt @@ -24,17 +24,15 @@ if (CUDAToolkit_FOUND) # for best performance and to also build real architectures for the most commonly used GPUs. if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") set(CMAKE_CUDA_ARCHITECTURES "native") - elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") - set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real") - else() - set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real") - endif() else() + if (CUDAToolkit_VERSION VERSION_LESS "13") + list(APPEND CMAKE_CUDA_ARCHITECTURES 50-virtual 61-virtual 70-virtual) + endif () + + list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") - set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real") - else() - set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real") + list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) endif() endif() endif() @@ -46,10 +44,14 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_CUDA "*.cu") + file(GLOB SRCS "template-instances/fattn-tile*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/mmf*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "template-instances/fattn-vec*.cu") @@ -91,10 +93,6 @@ if (CUDAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_FA) endif() - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - add_compile_definitions(GGML_CUDA_F16) - endif() - if (GGML_CUDA_NO_PEER_COPY) add_compile_definitions(GGML_CUDA_NO_PEER_COPY) endif() @@ -104,7 +102,11 @@ if (CUDAToolkit_FOUND) # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas) else () - target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1") + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + else() + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static) + endif() endif() else() target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas) @@ -120,6 +122,10 @@ if (CUDAToolkit_FOUND) set(CUDA_FLAGS -use_fast_math -extended-lambda) + if (GGML_CUDA_DEBUG) + list(APPEND CUDA_FLAGS -lineinfo) + endif() + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") # Options are: # - none (not recommended) diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/add-id.cu b/ml/backend/ggml/ggml/src/ggml-cuda/add-id.cu index 8bed62ac..8d9cf692 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/add-id.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/add-id.cu @@ -11,14 +11,14 @@ static __global__ void add_id_kernel( const int64_t i1 = blockIdx.x; const int64_t i2 = blockIdx.y; - const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21); + const int i11 = *(const int32_t *) ((const char *) src2 + i1*sizeof(int32_t) + i2*nb21); const size_t nb1 = ne0 * sizeof(float); const size_t nb2 = ne1 * nb1; float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2); - const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02); - const float * src1_row = (const float *)((char *)src1 + i11*nb11); + const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02); + const float * src1_row = (const float *)((const char *)src1 + i11*nb11); for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { dst_row[i0] = src0_row[i0] + src1_row[i0]; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu index e1fbf0e1..60240102 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu @@ -1,5 +1,6 @@ #include "binbcast.cuh" #include +#include static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; @@ -22,73 +23,295 @@ static __device__ __forceinline__ float op_div(const float a, const float b) { return a / b; } -template -static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13) { - const int i0s = blockDim.x*blockIdx.x + threadIdx.x; - const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); - const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3; - const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3; +template +static __global__ void k_bin_bcast(const src0_t * src0, + const src1_t * src1, + dst_t * dst, + const int ne0, + const int ne1, + const int ne2, + const uint3 ne3, + const uint3 ne10, + const uint3 ne11, + const uint3 ne12, + const uint3 ne13, + /*int s0, */ const int s1, + const int s2, + const int s3, + /*int s00,*/ const int s01, + const int s02, + const int s03, + /*int s10,*/ const int s11, + const int s12, + const int s13, + src1_ptrs... src1s) { + const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); + const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); + const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z); - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) { return; } - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; + const uint32_t i11 = fastmodulo(i1, ne11); + const uint32_t i12 = fastmodulo(i2, ne12); + const uint32_t i13 = fastmodulo(i3, ne13); const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; + const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; - for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) { - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { + const uint32_t i10 = fastmodulo(i0, ne10); + + float result = src0_row ? (float) src0_row[i0] : 0.0f; + if constexpr (sizeof...(src1_ptrs) > 0) { + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + } else { + result = bin_op(result, (float)src1[i_src1 + i10]); + } + + dst_row[i0] = (dst_t) result; } } -template -static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13) { - +template +static __global__ void k_bin_bcast_unravel(const src0_t * src0, + const src1_t * src1, + dst_t * dst, + const uint3 ne0, + const uint3 ne1, + const uint3 ne2, + const uint32_t ne3, + const uint3 prod_012, + const uint3 prod_01, + const uint3 ne10, + const uint3 ne11, + const uint3 ne12, + const uint3 ne13, + /*int s0, */ const int s1, + const int s2, + const int s3, + /*int s00,*/ const int s01, + const int s02, + const int s03, + /*int s10,*/ const int s11, + const int s12, + const int s13, + src1_ptrs... src1s) { const int i = blockDim.x*blockIdx.x + threadIdx.x; - const int i3 = i/(ne2*ne1*ne0); - const int i2 = (i/(ne1*ne0)) % ne2; - const int i1 = (i/ne0) % ne1; - const int i0 = i % ne0; + const uint32_t i3 = fastdiv(i, prod_012); + const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01); + const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0); + const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z; - if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) { return; } - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; + const int i11 = fastmodulo(i1, ne11); + const int i12 = fastmodulo(i2, ne12); + const int i13 = fastmodulo(i3, ne13); const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; + const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + const int i10 = fastmodulo(i0, ne10); + + float result = src0_row ? (float) src0_row[i0] : 0.0f; + if constexpr (sizeof...(src1_ptrs) > 0) { + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + } else { + result = bin_op(result, (float)src1[i_src1 + i10]); + } + + dst_row[i0] = (dst_t) result; +} + +template +static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, + cudaStream_t stream, std::index_sequence) { + GGML_TENSOR_BINARY_OP_LOCALS + + int nr0 = ne10 / ne0; + int nr1 = ne11 / ne1; + int nr2 = ne12 / ne2; + int nr3 = ne13 / ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + int64_t cne[] = { ne0, ne1, ne2, ne3 }; + int64_t cne0[] = { ne00, ne01, ne02, ne03 }; + int64_t cne1[] = { ne10, ne11, ne12, ne13 }; + + size_t cnb[] = { nb0, nb1, nb2, nb3 }; + size_t cnb0[] = { nb00, nb01, nb02, nb03 }; + size_t cnb1[] = { nb10, nb11, nb12, nb13 }; + + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], const int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } + } + } + + { + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; + + //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00); + //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01); + //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); + //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); + + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; + + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; + + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s00 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0 / 2LL, 1LL); + + dim3 block_dims; + block_dims.x = std::min(hne0, block_size); + block_dims.y = std::min(ne1, block_size / block_dims.x); + block_dims.z = std::min(std::min(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U); + + dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y, + (ne2 * ne3 + block_dims.z - 1) / block_dims.z); + + const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]); + const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]); + const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]); + const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]); + + if (block_nums.z > 65535) { + int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); + const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); + const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0); + const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); + const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); + + if constexpr (sizeof...(I) > 0) { + k_bin_bcast_unravel<<>>( + src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, + ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + } else { + k_bin_bcast_unravel + <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, + ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13); + } + } else { + const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); + if constexpr (sizeof...(I) > 0) { + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + } else { + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13); + } + } + } } template @@ -120,160 +343,14 @@ static __global__ void k_repeat_back( dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; } -template +template struct bin_bcast_cuda { template void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, cudaStream_t stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - int nr0 = ne10/ne0; - int nr1 = ne11/ne1; - int nr2 = ne12/ne2; - int nr3 = ne13/ne3; - - int nr[4] = { nr0, nr1, nr2, nr3 }; - - // collapse dimensions until first broadcast dimension - int64_t cne[] = {ne0, ne1, ne2, ne3}; - int64_t cne0[] = {ne00, ne01, ne02, ne03}; - int64_t cne1[] = {ne10, ne11, ne12, ne13}; - - size_t cnb[] = {nb0, nb1, nb2, nb3}; - size_t cnb0[] = {nb00, nb01, nb02, nb03}; - size_t cnb1[] = {nb10, nb11, nb12, nb13}; - - auto collapse = [](int64_t cne[]) { - cne[0] *= cne[1]; - cne[1] = cne[2]; - cne[2] = cne[3]; - cne[3] = 1; - }; - - auto collapse_nb = [](size_t cnb[], const int64_t cne[]) { - cnb[1] *= cne[1]; - cnb[2] *= cne[2]; - cnb[3] *= cne[3]; - }; - - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb, cne); - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne); - collapse(cne0); - collapse(cne1); - } - } - } - - { - int64_t ne0 = cne[0]; - int64_t ne1 = cne[1]; - int64_t ne2 = cne[2]; - int64_t ne3 = cne[3]; - - //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00); - //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01); - //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); - //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); - - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - - size_t nb0 = cnb[0]; - size_t nb1 = cnb[1]; - size_t nb2 = cnb[2]; - size_t nb3 = cnb[3]; - - size_t nb00 = cnb0[0]; - size_t nb01 = cnb0[1]; - size_t nb02 = cnb0[2]; - size_t nb03 = cnb0[3]; - - size_t nb10 = cnb1[0]; - size_t nb11 = cnb1[1]; - size_t nb12 = cnb1[2]; - size_t nb13 = cnb1[3]; - - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); - - size_t s10 = nb10 / sizeof(src1_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); - - size_t s00 = nb00 / sizeof(src0_t); - size_t s01 = nb01 / sizeof(src0_t); - size_t s02 = nb02 / sizeof(src0_t); - size_t s03 = nb03 / sizeof(src0_t); - - GGML_ASSERT(nb0 % sizeof(dst_t) == 0); - GGML_ASSERT(nb1 % sizeof(dst_t) == 0); - GGML_ASSERT(nb2 % sizeof(dst_t) == 0); - GGML_ASSERT(nb3 % sizeof(dst_t) == 0); - - GGML_ASSERT(nb00 % sizeof(src0_t) == 0); - GGML_ASSERT(nb01 % sizeof(src0_t) == 0); - GGML_ASSERT(nb02 % sizeof(src0_t) == 0); - GGML_ASSERT(nb03 % sizeof(src0_t) == 0); - - GGML_ASSERT(nb10 % sizeof(src1_t) == 0); - GGML_ASSERT(nb11 % sizeof(src1_t) == 0); - GGML_ASSERT(nb12 % sizeof(src1_t) == 0); - GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s00 == 1); - GGML_ASSERT(s10 == 1); - - const int block_size = 128; - - int64_t hne0 = std::max(ne0/2LL, 1LL); - - dim3 block_dims; - block_dims.x = std::min(hne0, block_size); - block_dims.y = std::min(ne1, block_size / block_dims.x); - block_dims.z = std::min(std::min(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U); - - dim3 block_nums( - (hne0 + block_dims.x - 1) / block_dims.x, - (ne1 + block_dims.y - 1) / block_dims.y, - (ne2*ne3 + block_dims.z - 1) / block_dims.z - ); - - if (block_nums.z > 65535) { - // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel - int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; - k_bin_bcast_unravel<<>>( - src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00, */ s01, s02, s03, - /* s10, */ s11, s12, s13); - } else { - k_bin_bcast<<>>( - src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00, */ s01, s02, s03, - /* s10, */ s11, s12, s13); - } - } + launch_bin_bcast_pack( + src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence{}); } }; @@ -312,7 +389,7 @@ static void ggml_cuda_op_bin_bcast( } void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); + ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); } void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -331,6 +408,68 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } +template +static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + cudaStream_t stream = ctx.stream(); + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + launch_bin_bcast_pack(src0, src1, dst, + (const float *) src0->data, (const float *) src1->data, (float *) dst->data, + stream, std::make_index_sequence{}); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + launch_bin_bcast_pack(src0, src1, dst, + (const half *) src0->data, (const half *) src1->data, (half *) dst->data, + stream, std::make_index_sequence{}); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + launch_bin_bcast_pack(src0, src1, dst, + (const half *) src0->data, (const float *) src1->data, (half *) dst->data, + stream, std::make_index_sequence{}); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + launch_bin_bcast_pack(src0, src1, dst, + (const half *) src0->data, (const float *) src1->data, (float *) dst->data, + stream, std::make_index_sequence{}); + } else { + fprintf(stderr, + "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n", + __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} + + +void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) { + GGML_ASSERT(2 <= n_fuse && n_fuse <= 8); + + switch (n_fuse) { + case 2: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 3: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 4: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 5: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 6: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 7: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 8: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + default: + GGML_ASSERT(false && "Unsupported n_fuse value"); + } +} + void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cuh index 3ac1c9b0..62bc9501 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cuh @@ -7,3 +7,5 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh index 2e5d4879..28d6bcd7 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh @@ -35,6 +35,31 @@ #include "vendors/cuda.h" #endif // defined(GGML_USE_HIP) +extern bool reserving_graph; + +// If we are reserving the graph, pointers might be invalid and will fail if cudaMemcpyAsync tries to validate them. +// However, since we don't actually expect a result, we don't need to actually do the memcpy. +static cudaError_t cudaMemcpyAsyncReserve ( void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream = 0 ) { + if (!reserving_graph) { + return cudaMemcpyAsync(dst, src, count, kind, stream); + } else { + return cudaSuccess; + } +} + +static cudaError_t cudaMemcpy2DAsyncReserve ( void* dst, size_t dpitch, const void* src, size_t spitch, size_t width, size_t height, cudaMemcpyKind kind, cudaStream_t stream = 0 ) { + if (!reserving_graph) { + return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, kind, stream); + } else { + return cudaSuccess; + } +} + +#undef cudaMemcpyAsync +#define cudaMemcpyAsync cudaMemcpyAsyncReserve +#undef cudaMemcpy2DAsync +#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve + #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) @@ -75,9 +100,13 @@ #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) +#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) #define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads +#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons + #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 #define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD @@ -87,6 +116,10 @@ #define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) #define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +# define GGML_CUDA_USE_CUB +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 + #ifdef __CUDA_ARCH_LIST__ constexpr bool ggml_cuda_has_arch_impl(int) { return false; @@ -101,9 +134,9 @@ constexpr bool ggml_cuda_has_arch(const int arch) { return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__); } -constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) { +constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) { if (cur == 0) { - GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch); + return -1; } return cur; } @@ -200,14 +233,6 @@ static const char * cu_get_error_str(CUresult err) { #define GGML_CUDA_ASSUME(x) #endif // CUDART_VERSION >= 11010 -#ifdef GGML_CUDA_F16 -typedef half dfloat; // dequantize float -typedef half2 dfloat2; -#else -typedef float dfloat; // dequantize float -typedef float2 dfloat2; -#endif // GGML_CUDA_F16 - #if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) #define GGML_USE_VMM #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) @@ -220,14 +245,6 @@ typedef float2 dfloat2; #define FAST_FP16_AVAILABLE #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) -#define FP16_MMA_AVAILABLE -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) - -#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) -#define FP16_MMA_AVAILABLE -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) - #if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) @@ -253,7 +270,8 @@ static bool fp16_available(const int cc) { } static bool fast_fp16_available(const int cc) { - return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc); + return GGML_CUDA_CC_IS_AMD(cc) || + (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610); } // To be used for feature selection of external libraries, e.g. cuBLAS. @@ -262,27 +280,6 @@ static bool fast_fp16_hardware_available(const int cc) { (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); } -// Any FP16 tensor core instructions are available for ggml code. -static bool fp16_mma_available(const int cc) { -#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) - return false; -#else - if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || - GGML_CUDA_CC_IS_MTHREADS(cc)) { - return true; - } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { -#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) - return true; -#else - return false; -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) - } else { - return false; - } -#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) -} - // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || @@ -312,11 +309,11 @@ static bool turing_mma_available(const int cc) { } static bool ampere_mma_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; } static bool cp_async_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; } static constexpr __device__ int ggml_cuda_get_physical_warp_size() { @@ -327,6 +324,20 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() { #endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) } +// Maximum number of bytes that can be copied in a single instruction. +static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() { +#ifdef GGML_USE_HIP + return 16; +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + return 16; +#else + return 8; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // GGML_USE_HIP +} + + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { @@ -420,38 +431,30 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // FP16_AVAILABLE } -// Row reduction kernel template - compute sum (norm=false) or mean (norm=true) -template -static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) { - const int row = blockIdx.x; - const int col = threadIdx.x; - - float sum = 0.0f; - for (int i = col; i < ncols; i += blockDim.x) { - sum += x[row * ncols + i]; +template +static __device__ __forceinline__ int warp_reduce_all(int x) { + if (width == ggml_cuda_get_physical_warp_size()) { + return __all_sync(0xffffffff, x); + } else { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) && x; + } + return x; } - - sum = warp_reduce_sum(sum); - - if (col != 0) { - return; - } - - dst[row] = norm ? sum / ncols : sum; } template -static __device__ __forceinline__ int warp_reduce_all(int x) { -#ifdef GGML_USE_HIP +static __device__ __forceinline__ int warp_reduce_any(int x) { + if (width == ggml_cuda_get_physical_warp_size()) { + return __any_sync(0xffffffff, x); + } else { #pragma unroll - for (int offset = width/2; offset > 0; offset >>= 1) { - x = x && __shfl_xor_sync(0xffffffff, x, offset, width); + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) || x; + } + return x; } - return x; -#else - static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented"); - return __all_sync(0xffffffff, x); -#endif // GGML_USE_HIP } template @@ -480,25 +483,21 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b } static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) { -#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000 +#if defined(GGML_USE_HIP) return half2(__hmax(a.x, b.x), __hmax(a.y, b.y)); -#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX +#elif CUDART_VERSION >= CUDART_HMAX return __hmax2(a, b); -#elif !defined(GGML_USE_HIP) +#else half2 ret; reinterpret_cast(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b))); reinterpret_cast(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b))); return ret; -#else - GGML_UNUSED(a); - GGML_UNUSED(b); - NO_DEVICE_CODE; #endif } template static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP) #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width)); @@ -507,16 +506,17 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #else GGML_UNUSED(x); NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP) } -#if CUDART_VERSION < CUDART_HMASK +#if (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || \ + (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK) static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b))); const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); return mask_low | mask_high; } -#endif // CUDART_VERSION < CUDART_HMASK +#endif // (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK) static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { #if defined(GGML_USE_HIP) @@ -558,6 +558,74 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #endif // defined(GGML_USE_HIP) } +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) { + acc += v*u; +} + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) { + acc += v.x*u.x; + acc += v.y*u.y; +} + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { +#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) + asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u)); +#else +#ifdef FAST_FP16_AVAILABLE + const float2 tmp = __half22float2(v*u); + acc += tmp.x + tmp.y; +#else + const float2 tmpv = __half22float2(v); + const float2 tmpu = __half22float2(u); + acc += tmpv.x * tmpu.x; + acc += tmpv.y * tmpu.y; +#endif // FAST_FP16_AVAILABLE +#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA)) +} + +static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) { +#ifdef FAST_FP16_AVAILABLE + acc += v*u; +#else + const float2 tmpv = __half22float2(v); + const float2 tmpu = __half22float2(u); + float2 tmpacc = __half22float2(acc); + tmpacc.x += tmpv.x * tmpu.x; + tmpacc.y += tmpv.y * tmpu.y; + acc = make_half2(tmpacc.x, tmpacc.y); +#endif // FAST_FP16_AVAILABLE +} + +// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD. +// Important: do not use this function if dst and src both point at registers. +// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types. +// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions. +// If dst and src point at different address spaces then they are guaranteed to not be aliased. +template +static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) { + if constexpr (alignment != 0) { + static_assert(nbytes % alignment == 0, "bad alignment"); + } + constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment; + +#pragma unroll + for (int i = 0; i < nbytes/nb_per_cpy; ++i) { + if constexpr (nb_per_cpy == 1) { + ((char *) dst)[i] = ((const char *) src)[i]; + } else if constexpr (nb_per_cpy == 2) { + ((short *) dst)[i] = ((const short *) src)[i]; + } else if constexpr (nb_per_cpy == 4) { + ((int *) dst)[i] = ((const int *) src)[i]; + } else if constexpr (nb_per_cpy == 8) { + ((int2 *) dst)[i] = ((const int2 *) src)[i]; + } else if constexpr (nb_per_cpy == 16) { + ((int4 *) dst)[i] = ((const int4 *) src)[i]; + } else { + static_assert(nbytes == 0 && nbytes == -1, "bad nbytes"); + } + } +} + static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #if CUDART_VERSION >= 12080 const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x); @@ -576,7 +644,49 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #endif // CUDART_VERSION >= 12050 } -typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static const uint3 init_fastdiv_values(uint32_t d) { + GGML_ASSERT(d != 0); + + // compute L = ceil(log2(d)); + uint32_t L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); + // pack divisor as well to reduce error surface + return make_uint3(mp, L, d); +} + +static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain in + // fastdiv_values.z is unused and optimized away by the compiler. + // Compute high 32 bits of n * mp + const uint32_t hi = __umulhi(n, fastdiv_values.x); + // add n, apply bit shift + return (hi + n) >> fastdiv_values.y; +} + +static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain in (see init_fastdiv_values) + return n - fastdiv(n, fastdiv_values) * fastdiv_values.z; +} + +// Calculate both division and modulo at once, returns +static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain in (see init_fastdiv_values) + const uint32_t div_val = fastdiv(n, fastdiv_values); + const uint32_t mod_val = n - div_val * fastdiv_values.z; + return make_uint2(div_val, mod_val); +} + +typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v); static __device__ __forceinline__ float get_alibi_slope( const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1 @@ -771,6 +881,9 @@ struct ggml_cuda_pool { virtual void * alloc(size_t size, size_t * actual_size) = 0; virtual void free(void * ptr, size_t size) = 0; + + virtual bool alloc_memory() = 0; + virtual size_t alloc_size() = 0; }; template @@ -914,11 +1027,11 @@ struct ggml_backend_cuda_context { // pool std::unique_ptr pools[GGML_CUDA_MAX_DEVICES]; - static std::unique_ptr new_pool_for_device(int device); + static std::unique_ptr new_pool_for_device(int device, bool alloc); ggml_cuda_pool & pool(int device) { if (pools[device] == nullptr) { - pools[device] = new_pool_for_device(device); + pools[device] = new_pool_for_device(device, true); } return *pools[device]; } @@ -926,4 +1039,20 @@ struct ggml_backend_cuda_context { ggml_cuda_pool & pool() { return pool(device); } + + void pool_set_alloc(bool alloc) { + GGML_ASSERT(pools[device] == nullptr || pools[device]->alloc_memory() == alloc); + + if (pools[device] == nullptr) { + pools[device] = new_pool_for_device(device, alloc); + } + } + + size_t pool_get_alloc_size() { + if (pools[device] == nullptr) { + return 0; + } + + return pools[device]->alloc_size(); + } }; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/conv-transpose-1d.cu b/ml/backend/ggml/ggml/src/ggml-cuda/conv-transpose-1d.cu index fe4caf67..8418ba66 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/conv-transpose-1d.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/conv-transpose-1d.cu @@ -34,10 +34,7 @@ static __global__ void conv_transpose_1d_kernel( } } dst[global_index] = accumulator; - GGML_UNUSED(p0); GGML_UNUSED(d0); GGML_UNUSED(src0_ne3); - GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3); - GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1); - GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2); + GGML_UNUSED_VARS(p0, d0, src0_ne3, src1_ne3, dst_ne3, src1_ne1, dst_ne1, src1_ne2, dst_ne2); } static void conv_transpose_1d_f32_f32_cuda( diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/conv2d.cu b/ml/backend/ggml/ggml/src/ggml-cuda/conv2d.cu new file mode 100644 index 00000000..142dd669 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/conv2d.cu @@ -0,0 +1,166 @@ +#include "conv2d.cuh" +#include "convert.cuh" + +struct conv_params { + const int64_t IW, IH; + const int64_t OW, OH; + const int64_t KW, KH; + const int64_t ST_X, ST_Y; + const int64_t PD_X, PD_Y; + const int64_t DL_X, DL_Y; + const int64_t IC, OC; + const int64_t B; + const int64_t TOTAL; +}; + +struct kernel_bounds { + int64_t y_min, y_max; + int64_t x_min, x_max; +}; + +__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { + return (a > b) ? a : b; +} + +__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { + return (a < b) ? a : b; +} + +__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { + kernel_bounds bounds; + bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + return bounds; +} + +__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, + int64_t kern_coord, + int64_t stride, + int64_t dilation, + int64_t padding) { + return out_coord * stride + kern_coord * dilation - padding; +} + +struct whcn_layout { + __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; + } + + __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { + return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; + } + + __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; + } + + __device__ static void unpack_indices(int64_t global_idx, + const conv_params & P, + int64_t & n, + int64_t & c, + int64_t & out_y, + int64_t & out_x) { + out_x = global_idx % P.OW; + out_y = (global_idx / P.OW) % P.OH; + c = (global_idx / (P.OW * P.OH)) % P.OC; + n = global_idx / (P.OW * P.OH * P.OC); + } +}; + +template +static __global__ void conv2d_kernel(const float * __restrict__ input, + const T * __restrict__ kernel, + float * __restrict__ output, + const conv_params P) { + const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_idx >= P.TOTAL) { + return; + } + + int64_t n, c_out, out_y, out_x; + Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x); + + float acc = 0.0f; + + for (int64_t c_in = 0; c_in < P.IC; ++c_in) { + kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P); + + for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) { + const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y); + + for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) { + const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X); + + const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)]; + const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; + acc += (input_val * ggml_cuda_cast(kernel_val)); + } + } + } + + // [N, OC, OH, OW] + output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc; +} + +template +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; + conv2d_kernel<<>>(X_D, K_D, Y_D, P); +} + +static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); +} + +static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); +} + +void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * input = dst->src[1]; + float * K_D = (float *) kernel->data; + const float * X_D = (const float *) input->data; + float * Y_D = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + + // same number of input channels + GGML_ASSERT(input->ne[2] == kernel->ne[2]); + + cudaStream_t st = ctx.stream(); + + const int32_t * p = (const int32_t *) dst->op_params; + const int ST_X = p[0]; // stride_x + const int ST_Y = p[1]; // stride_y + const int PD_X = p[2]; // padding_x + const int PD_Y = p[3]; // padding_y + const int DL_X = p[4]; // dilation_x + const int DL_Y = p[5]; // dilation_y + + // No cwhn + GGML_ASSERT(p[6] == false); + + const int IW = input->ne[0]; // input_w + const int IH = input->ne[1]; // input_h + const int OW = dst->ne[0]; // output_w + const int OH = dst->ne[1]; // output_h + const int KW = kernel->ne[0]; // kernel_w + const int KH = kernel->ne[1]; // kernel_h + const int IC = input->ne[2]; // input_channels + const int OC = kernel->ne[3]; // ouptut_chanles + const int B = input->ne[3]; // n_batches + + const int64_t total = B * OC * OH * OW; + conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + + if (kernel->type == GGML_TYPE_F16) { + conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st); + } else { + conv2d_cuda_f32(X_D, K_D, Y_D, params, st); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/conv2d.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/conv2d.cuh new file mode 100644 index 00000000..ce4802c7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/conv2d.cuh @@ -0,0 +1,5 @@ +#pragma once +#include "common.cuh" + +#define CUDA_CONV2D_BLOCK_SIZE 256 +void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu index e3beddbc..ba3d4eeb 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu @@ -27,12 +27,12 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ const int64_t y_offset = qr == 1 ? 1 : qk/2; // dequantize - dfloat2 v; + float2 v; dequantize_kernel(vx, ib, iqs, v); const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; - y[iy0 + 0] = float(v.x); - y[iy0 + y_offset] = float(v.y); + y[iy0 + 0] = ggml_cuda_cast(v.x); + y[iy0 + y_offset] = ggml_cuda_cast(v.y); } template @@ -71,9 +71,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d)); } #else - GGML_UNUSED(vx); - GGML_UNUSED(y); - GGML_UNUSED(k); + GGML_UNUSED_VARS(vx, y, k); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL } @@ -630,7 +628,7 @@ static __global__ void convert_unary( const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; - y[iy] = float(x[ix]); + y[iy] = ggml_cuda_cast(x[ix]); } template diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cuh index f04214be..ef9e1299 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cuh @@ -29,3 +29,18 @@ typedef to_t_nc_cuda_t to_bf16_nc_cuda_t; to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type); to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type); to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type); + +template + __host__ __device__ inline dst_t ggml_cuda_cast(src_t x) { + if constexpr (std::is_same_v) { + return x; + } else if constexpr(std::is_same_v) { + return __float2bfloat16(float(x)); + } else if constexpr(std::is_same_v) { + return __bfloat162float(x); + } else if constexpr(std::is_same_v) { + return int32_t(x); + } else { + return float(x); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy-utils.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/cpy-utils.cuh index b8e9e107..597c0c8b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/cpy-utils.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy-utils.cuh @@ -1,15 +1,7 @@ #pragma once #include "ggml-common.h" - -template -static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) { - if constexpr (std::is_same_v) { - *dst = *src; - } else { - *dst = float(*src); - } -} +#include "convert.cuh" static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { if (x <= val[0]) return 0; @@ -221,7 +213,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { template static __device__ void cpy_1_flt(const char * cxi, char * cdsti) { - convert_flt((const src_t *)cxi, (dst_t *)cdsti); + *(dst_t *) cdsti = ggml_cuda_cast(*(const src_t *) cxi); } static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu index 9c3774e5..911220e9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu @@ -42,7 +42,7 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { #pragma unroll for (int j = 0; j < QK8_0; j += 2) { - dfloat2 dq; + float2 dq; dequantize_q8_0(cxi, 0, j, dq); *(cdstf + j) = dq.x; *(cdstf + j + 1) = dq.y; @@ -55,7 +55,7 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { #pragma unroll for (int j = 0; j < qk/2; j++) { - dfloat2 dq; + float2 dq; dequant(cxi, 0, j, dq); *(cdstf + j) = dq.x; *(cdstf + j + qk/2) = dq.y; @@ -134,8 +134,7 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream)); cuda_graph->graph_cpynode_index = 0; // reset index #else - GGML_UNUSED(cuda_graph); GGML_UNUSED(host_dest_ptrs); - GGML_UNUSED(host_dest_ptrs_size); GGML_UNUSED(stream); + GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream); #endif } @@ -371,7 +370,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + if (src0->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else { + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + } } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); @@ -418,6 +421,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); @@ -440,7 +447,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - return nullptr; + // Prioritize CUDA graph compatibility over direct memory copy optimization. + // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs. + if (src0->type == GGML_TYPE_F32) { + return (void*) cpy_flt>; + } else { + return nullptr; + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { @@ -481,6 +494,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_flt>; } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/dequantize.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/dequantize.cuh index bd3c2d9d..e060fb29 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/dequantize.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/dequantize.cuh @@ -1,48 +1,37 @@ #include "common.cuh" -static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; - const dfloat d = x[ib].d; + const float d = x[ib].d; const int vui = x[ib].qs[iqs]; v.x = vui & 0xF; v.y = vui >> 4; -#ifdef GGML_CUDA_F16 - v = __hsub2(v, {8.0f, 8.0f}); - v = __hmul2(v, {d, d}); -#else v.x = (v.x - 8.0f) * d; v.y = (v.y - 8.0f) * d; -#endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; - const dfloat d = __low2half(x[ib].dm); - const dfloat m = __high2half(x[ib].dm); + const float2 dm = __half22float2(x[ib].dm); const int vui = x[ib].qs[iqs]; v.x = vui & 0xF; v.y = vui >> 4; -#ifdef GGML_CUDA_F16 - v = __hmul2(v, {d, d}); - v = __hadd2(v, {m, m}); -#else - v.x = (v.x * d) + m; - v.y = (v.y * d) + m; -#endif // GGML_CUDA_F16 + v.x = (v.x * dm.x) + dm.y; + v.y = (v.y * dm.x) + dm.y; } -static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q5_0 * x = (const block_q5_0 *) vx; - const dfloat d = x[ib].d; + const float d = x[ib].d; uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); @@ -53,20 +42,14 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.y = ((x[ib].qs[iqs] >> 4) | xh_1); -#ifdef GGML_CUDA_F16 - v = __hsub2(v, {16.0f, 16.0f}); - v = __hmul2(v, {d, d}); -#else v.x = (v.x - 16.0f) * d; v.y = (v.y - 16.0f) * d; -#endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; - const dfloat d = __low2half(x[ib].dm); - const dfloat m = __high2half(x[ib].dm); + const float2 dm = __half22float2(x[ib].dm); uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); @@ -77,27 +60,18 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.y = ((x[ib].qs[iqs] >> 4) | xh_1); -#ifdef GGML_CUDA_F16 - v = __hmul2(v, {d, d}); - v = __hadd2(v, {m, m}); -#else - v.x = (v.x * d) + m; - v.y = (v.y * d) + m; -#endif // GGML_CUDA_F16 + v.x = (v.x * dm.x) + dm.y; + v.y = (v.y * dm.x) + dm.y; } -static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q8_0 * x = (const block_q8_0 *) vx; - const dfloat d = x[ib].d; + const float d = x[ib].d; v.x = x[ib].qs[iqs + 0]; v.y = x[ib].qs[iqs + 1]; -#ifdef GGML_CUDA_F16 - v = __hmul2(v, {d, d}); -#else v.x *= d; v.y *= d; -#endif // GGML_CUDA_F16 } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh index e46f0e20..bc0c2523 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh @@ -33,276 +33,230 @@ typedef void (* fattn_kernel_t)( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33); -typedef half (*vec_dot_KQ_f16_t)( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); -typedef float (*vec_dot_KQ_f32_t)( +typedef float (*vec_dot_KQ_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { - - const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; - GGML_UNUSED(Q_v); - - T sum = 0.0f; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const int ib = k_KQ / QI8_1; - const int iqs4 = k_KQ % QI4_0; - const int shift = k_KQ & (QI8_1/2); - - const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int u = Q_q8[k_KQ_0/warp_size]; - - const int sumi = ggml_cuda_dp4a(v, u, 0); - -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - - const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size]; - sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */); - } else -#endif // FP16_AVAILABLE - { - const float2 * Q_ds = (const float2 *) Q_ds_v; - - sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y)); - } - } - - return sum; -} - -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { - - const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; - GGML_UNUSED(Q_v); - - T sum = 0.0f; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const int ib = k_KQ / QI8_1; - const int iqs4 = k_KQ % QI4_1; - const int shift = k_KQ & (QI8_1/2); - - const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int u = Q_q8[k_KQ_0/warp_size]; - - const int sumi = ggml_cuda_dp4a(v, u, 0); - -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - - const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size]; - const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1); - sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled)); - } else -#endif // FP16_AVAILABLE - { - const float2 * Q_ds = (const float2 *) Q_ds_v; - - const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi; - const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1; - - sum += (T) (sumid4d8 + m4s8scaled); - } - } - - return sum; -} - -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { - - const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; - GGML_UNUSED(Q_v); - - T sum = 0.0f; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const int ib = k_KQ / QI8_1; - const int iqs4 = k_KQ % QI5_0; - const int iqs8 = k_KQ % QI8_1; - const int shift = k_KQ & (QI8_1/2); - - int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0); - v |= (vh << 4) & 0x00000010; // 0 -> 4 - v |= (vh << 11) & 0x00001000; // 1 -> 12 - v |= (vh << 18) & 0x00100000; // 2 -> 20 - v |= (vh << 25) & 0x10000000; // 3 -> 28 - - const int u = Q_q8[k_KQ_0/warp_size]; - - const int sumi = ggml_cuda_dp4a(v, u, 0); - -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - - const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size]; - sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */; - } else -#endif // FP16_AVAILABLE - { - const float2 * Q_ds = (const float2 *) Q_ds_v; - - sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y)); - } - } - - return sum; -} - -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { - - const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; - GGML_UNUSED(Q_v); - - T sum = 0.0f; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const int ib = k_KQ / QI8_1; - const int iqs4 = k_KQ % QI5_1; - const int iqs8 = k_KQ % QI8_1; - const int shift = k_KQ & (QI8_1/2); - - int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1); - v |= (vh << 4) & 0x00000010; // 0 -> 4 - v |= (vh << 11) & 0x00001000; // 1 -> 12 - v |= (vh << 18) & 0x00100000; // 2 -> 20 - v |= (vh << 25) & 0x10000000; // 3 -> 28 - - const int u = Q_q8[k_KQ_0/warp_size]; - - const int sumi = ggml_cuda_dp4a(v, u, 0); - -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - - const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size]; - const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1); - sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled)); - } else -#endif // FP16_AVAILABLE - { - const float2 * Q_ds = (const float2 *) Q_ds_v; - - const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi; - const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1; - - sum += (T) (sumid5d8 + m5s8scaled); - } - } - - return sum; -} - -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { - - const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; - GGML_UNUSED(Q_v); - - T sum = 0.0f; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const int ib = k_KQ / QI8_0; - const int iqs = k_KQ % QI8_0; - - const int v = get_int_b2(K_q8_0[ib].qs, iqs); - - T Q_d; - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - Q_d = __low2half(Q_ds[k_KQ_0/warp_size]); - } else { - const float2 * Q_ds = (const float2 *) Q_ds_v; - Q_d = Q_ds[k_KQ_0/warp_size].x; - } - - sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d); - } - - return sum; -} - -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { const half2 * K_h2 = (const half2 *) K_c; GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - const half2 * Q_h2 = (const half2 *) Q_v; - - half2 sum2 = make_half2(0.0f, 0.0f); - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const half2 K_ik = K_h2[k_KQ]; - sum2 += K_ik * Q_h2[k_KQ_0/warp_size]; - } - - return __low2half(sum2) + __high2half(sum2); - } -#endif // FP16_AVAILABLE - - const float2 * Q_f2 = (const float2 *) Q_v; + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; float sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const half2 K_ik = K_h2[k_KQ]; - sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x; - sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y; + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + half2 tmp[cpy_ne]; + ggml_cuda_memcpy_1(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#else + ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // FP16_AVAILABLE + } } return sum; } -template +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_0; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1(&v, K_q4_0[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y); + } + + return sum; +} + +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1(&v, K_q4_1[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 K_dm = __half22float2(K_q4_1[ib].dm); + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + + sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1; + } + + return sum; +} + +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_0; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1(&v, K_q5_0[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + + { + int vh; + ggml_cuda_memcpy_1(&vh, K_q5_0[ib].qh); + vh >>= iqs8 * QI5_0; + + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + } + + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + + sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y); + } + + return sum; +} + +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_1; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1(&v, K_q5_1[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + + { + int vh; + ggml_cuda_memcpy_1(&vh, K_q5_1[ib].qh); + vh >>= iqs8 * QI5_0; + + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + } + + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 K_dm = __half22float2(K_q5_1[ib].dm); + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + + sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1; + } + + return sum; +} + +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_0; + const int iqs = k_KQ % QI8_0; + + int v; + ggml_cuda_memcpy_1(&v, K_q8_0[ib].qs + 4*iqs); + + const float2 * Q_ds = (const float2 *) Q_ds_v; + const float Q_d = Q_ds[k_KQ_0/nthreads].x; + + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d); + } + + return sum; +} + +template static __device__ __forceinline__ void quantize_q8_1_to_shared( const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) { float vals[sizeof(int)] = {0.0f}; #pragma unroll for (int l = 0; l < int(sizeof(int)); ++l) { - vals[l] = scale * x[4*threadIdx.x + l]; + vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f; } float amax = fabsf(vals[0]); @@ -330,7 +284,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( } yq32[threadIdx.x] = q32; - if (threadIdx.x % QI8_1 == 0) { + if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) { if (std::is_same::value) { ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum); } else { @@ -339,167 +293,276 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( } } -typedef half (*dequantize_1_f16_t)(const void *, const int64_t); -typedef float (*dequantize_1_f32_t)(const void *, const int64_t); +typedef void (*dequantize_V_t)(const void *, void *, const int64_t); -template -static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) { +template +static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + if constexpr (std::is_same_v) { + ggml_cuda_memcpy_1(dst, (const half *) vx + i0); + } else if constexpr (std::is_same_v) { + static_assert(ne % 2 == 0, "bad ne"); + half2 tmp[ne/2]; + ggml_cuda_memcpy_1(tmp, (const half *) vx + i0); + float2 * dst_f2 = (float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = __half22float2(tmp[l]); + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +} + +template +static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_0 * x = (const block_q4_0 *) vx; - const int64_t ib = i / QK4_0; - const int iqs = i % (QK4_0/2); - const int shift = (i % QK4_0) / (QK4_0/2); + const int64_t ib = i0 / QK4_0; + const int iqs = i0 % (QK4_0/2); + const int shift = (i0 % QK4_0) / (QK4_0/2); - const T d = x[ib].d; - const int q0 = x[ib].qs[iqs]; - const int q = ((q0 >> (4*shift)) & 0x0F) - 8; + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + q = __vsubss4(q, 0x08080808); + + const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if (std::is_same::value) { - return ((half) d)*((half) q); - } -#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const half2 d = __half2half2(x[ib].d); - return ((float) d)*((float) q); +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]); + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * q8[l]; + } + } else { + static_assert(std::is_same_v, "bad type"); + } } -template -static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) { +template +static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_1 * x = (const block_q4_1 *) vx; - const int64_t ib = i / QK4_1; - const int iqs = i % (QK4_1/2); - const int shift = (i % QK4_1) / (QK4_1/2); + const int64_t ib = i0 / QK4_1; + const int iqs = i0 % (QK4_1/2); + const int shift = (i0 % QK4_1) / (QK4_1/2); - const half2 dm = x[ib].dm; - const int q0 = x[ib].qs[iqs]; - const int q = ((q0 >> (4*shift)) & 0x0F); + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if (std::is_same::value) { - return __low2half(dm)*((half) q) + __high2half(dm); - } -#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const half2 dm = x[ib].dm; + const half2 d = __half2half2( __low2half(dm)); + const half2 m = __half2half2(__high2half(dm)); - return __low2float(dm)*((float) q) + __high2float(dm); +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m; + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const float2 dm = __half22float2(x[ib].dm); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = dm.x * q8[l] + dm.y; + } + } else { + static_assert(std::is_same_v, "bad type"); + } } -template -static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) { +template +static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q5_0 * x = (const block_q5_0 *) vx; - const int64_t ib = i / QK5_0; - const int idq = i % QK5_0; - const int iqs = i % (QK5_0/2); - const int shift = (i % QK5_0) / (QK5_0/2); + const int64_t ib = i0 / QK5_0; + const int idq = i0 % QK5_0; + const int iqs = i0 % (QK5_0/2); + const int shift = (i0 % QK5_0) / (QK5_0/2); - const T d = x[ib].d; - const int ql0 = x[ib].qs[iqs]; - const int qh0 = get_int_b2(x[ib].qh, 0); - const int ql = ((ql0 >> (4*shift)) & 0x0F); - const int qh = ((qh0 >> idq) << 4) & 0x10; - const int q = (ql | qh) - 16; + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + { + int qh; + ggml_cuda_memcpy_1(&qh, x[ib].qh); +#pragma unroll + for (int l = 0; l < ne; ++l) { + q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4); + } + } + + q = __vsubss4(q, 0x10101010); + + const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if (std::is_same::value) { - return ((half) d)*((half) q); - } -#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const half2 d = __half2half2(x[ib].d); - return ((float) d)*((float) q); +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]); + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * q8[l]; + } + } else { + static_assert(std::is_same_v, "bad type"); + } } -template -static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) { +template +static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q5_1 * x = (const block_q5_1 *) vx; - const int64_t ib = i / QK5_1; - const int idq = i % QK5_1; - const int iqs = i % (QK5_1/2); - const int shift = (i % QK5_1) / (QK5_1/2); + const int64_t ib = i0 / QK5_1; + const int idq = i0 % QK5_1; + const int iqs = i0 % (QK5_1/2); + const int shift = (i0 % QK5_1) / (QK5_1/2); - const half2 dm = x[ib].dm; - const int ql0 = x[ib].qs[iqs]; - const int qh0 = get_int_b4(x[ib].qh, 0); - const int ql = ((ql0 >> (4*shift)) & 0x0F); - const int qh = ((qh0 >> idq) << 4) & 0x10; - const int q = (ql | qh); + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + { + int qh; + ggml_cuda_memcpy_1(&qh, x[ib].qh); +#pragma unroll + for (int l = 0; l < ne; ++l) { + q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4); + } + } + + const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if (std::is_same::value) { - return __low2half(dm)*((half) q) + __high2half(dm); - } -#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const half2 dm = x[ib].dm; + const half2 d = __half2half2( __low2half(dm)); + const half2 m = __half2half2(__high2half(dm)); - return __low2float(dm)*((float) q) + __high2float(dm); +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m; + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const float2 dm = __half22float2(x[ib].dm); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = dm.x * q8[l] + dm.y; + } + } else { + static_assert(std::is_same_v, "bad type"); + } } -template -static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) { +template +static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q8_0 * x = (const block_q8_0 *) vx; - const int64_t ib = i / QK8_0; - const int iqs = i % QK8_0; + const int64_t ib = i0 / QK8_0; + const int iqs = i0 % QK8_0; - const T d = x[ib].d; - const int q = x[ib].qs[iqs]; + static_assert(ne % 2 == 0, "bad ne"); + int8_t qs[ne]; + ggml_cuda_memcpy_1(qs, x[ib].qs + iqs); #ifdef FP16_AVAILABLE - if (std::is_same::value) { - return ((half) d)*((half) q); - } + if constexpr (std::is_same::value) { + const half2 d = __half2half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]); + } + } else #endif // FP16_AVAILABLE + if constexpr (std::is_same::value) { + const float d = x[ib].d; - return ((float) d)*((float) q); +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * qs[l]; + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } } -template -static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { - const half * x = (const half *) vx; - - return x[i]; +template +constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { + if constexpr (type_K == GGML_TYPE_F16) { + return vec_dot_fattn_vec_KQ_f16; + } else if constexpr (type_K == GGML_TYPE_Q4_0) { + return vec_dot_fattn_vec_KQ_q4_0; + } else if constexpr (type_K == GGML_TYPE_Q4_1) { + return vec_dot_fattn_vec_KQ_q4_1; + } else if constexpr (type_K == GGML_TYPE_Q5_0) { + return vec_dot_fattn_vec_KQ_q5_0; + } else if constexpr (type_K == GGML_TYPE_Q5_1) { + return vec_dot_fattn_vec_KQ_q5_1; + } else if constexpr (type_K == GGML_TYPE_Q8_0) { + return vec_dot_fattn_vec_KQ_q8_0; + } else { + static_assert(type_K == -1, "bad type"); + return nullptr; + } } -template -constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : - nullptr; -} - -template -constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : - nullptr; -} - -constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) { - return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 : - type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 : - type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 : - type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 : - type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 : - type_V == GGML_TYPE_F16 ? dequantize_1_f16 : - nullptr; -} - -constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { - return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 : - type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 : - type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 : - type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 : - type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 : - type_V == GGML_TYPE_F16 ? dequantize_1_f16 : - nullptr; +template +constexpr __device__ dequantize_V_t get_dequantize_V() { + if constexpr (type_V == GGML_TYPE_F16) { + return dequantize_V_f16; + } else if constexpr (type_V == GGML_TYPE_Q4_0) { + return dequantize_V_q4_0; + } else if constexpr (type_V == GGML_TYPE_Q4_1) { + return dequantize_V_q4_1; + } else if constexpr (type_V == GGML_TYPE_Q5_0) { + return dequantize_V_q5_0; + } else if constexpr (type_V == GGML_TYPE_Q5_1) { + return dequantize_V_q5_1; + } else if constexpr (type_V == GGML_TYPE_Q8_0) { + return dequantize_V_q8_0; + } else { + static_assert(type_V == -1, "bad type"); + return nullptr; + } } template @@ -539,11 +602,15 @@ static __global__ void flash_attn_mask_to_KV_max( all_inf = warp_reduce_all(all_inf); if (!all_inf) { - KV_max_sj += FATTN_KQ_STRIDE; break; } } + // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE. + // If the break was triggered it's the lower edge of the tile with the first non-masked values. + // In either case, walk back the decrementation by FATTN_KQ_STRIDE. + KV_max_sj += FATTN_KQ_STRIDE; + if (threadIdx.x != 0) { return; } @@ -643,9 +710,7 @@ static __global__ void flash_attn_stream_k_fixup( } template // D == head size -#if !defined(GGML_USE_HIP) __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIP) static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const float2 * __restrict__ VKQ_meta, @@ -688,10 +753,7 @@ static __global__ void flash_attn_combine_results( float VKQ_numerator = 0.0f; float VKQ_denominator = 0.0f; for (int l = 0; l < parallel_blocks; ++l) { - const float diff = meta[l].x - kqmax; - float KQ_max_scale = expf(diff); - const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); - *((uint32_t *) &KQ_max_scale) &= ftz_mask; + const float KQ_max_scale = expf(meta[l].x - kqmax); VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; @@ -700,28 +762,6 @@ static __global__ void flash_attn_combine_results( dst[tid] = VKQ_numerator / VKQ_denominator; } -[[noreturn]] -static void on_no_fattn_vec_case(const int D) { - if (D == 64) { - fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); - fprintf(stderr, "By default only f16 KV cache is supported.\n"); - fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); - GGML_ABORT("fatal error"); - } else if (D == 128) { - fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); - fprintf(stderr, "Supported combinations:\n"); - fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n"); - fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n"); - fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n"); - fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); - GGML_ABORT("fatal error"); - } else { - fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D); - fprintf(stderr, "Only f16 is supported.\n"); - GGML_ABORT("fatal error"); - } -} - template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, @@ -753,8 +793,6 @@ void launch_fattn( GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); - GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); - ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -838,7 +876,7 @@ void launch_fattn( // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or // multiple sequences of possibly different lengths. - if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { + if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { const int s31 = mask->nb[1] / sizeof(half2); const int s33 = mask->nb[3] / sizeof(half2); @@ -854,11 +892,10 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); } - int parallel_blocks = 1; - const dim3 block_dim(warp_size, nwarps, 1); int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + int parallel_blocks = max_blocks_per_sm; dim3 blocks_num; if (stream_k) { @@ -877,11 +914,7 @@ void launch_fattn( dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { - GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); - const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. - - // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: - parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); + const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. // parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); @@ -897,7 +930,7 @@ void launch_fattn( const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. - if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { + if (efficiency_percent_best >= 95 && nwaves > nwaves_best) { break; } @@ -910,7 +943,7 @@ void launch_fattn( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = Q->ne[2]*Q->ne[3]; + blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 39731baa..57defb0c 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -767,14 +767,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } #else - GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); - GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); - GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); - GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K); - GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); - GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); - GGML_UNUSED(kb0); GGML_UNUSED(tile_Q); + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, + scale, slope, logit_softcap, ne01, ne02, + stride_K, stride_V, stride_mask, + tile_Q, tile_K, tile_V, tile_mask, + Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -1236,12 +1233,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #else - GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); - GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); - GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); - GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); - GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup, + scale, slope, logit_softcap, ne01, ne02, + stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, + jt, kb0_start, kb0_stop); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -1395,17 +1390,15 @@ static __global__ void flash_attn_ext_f16( (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); - GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu deleted file mode 100644 index 0fcfaa32..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ /dev/null @@ -1,347 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" -#include "fattn-tile-f16.cuh" - -#define FATTN_KQ_STRIDE_TILE_F16 64 - -template // D == head size -#if !defined(GGML_USE_HIP) -__launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !defined(GGML_USE_HIP) -static __global__ void flash_attn_tile_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - if (use_logit_softcap && !(D == 128 || D == 256)) { - NO_DEVICE_CODE; - return; - } - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - - __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16]; - half2 * KQ2 = (half2 *) KQ; - - __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts. - - half kqmax[ncols/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -HALF_MAX_HALF; - } - half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}}; - - half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; - - // Convert Q to half2 and store in registers: - __shared__ half2 Q_h2[ncols][D/2]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); - Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } - } - - __syncthreads(); - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) { - // Calculate KQ tile and keep track of new maximum KQ values: - - half kqmax_new[ncols/nwarps]; -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; - } - } - - __syncthreads(); - - half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}}; - -#pragma unroll - for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) { - half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE]; - half2 Q_k[ncols/nwarps]; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - - K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ]; - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps]; - } - } - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - half sum; - if (use_logit_softcap) { - const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - sum = logit_softcap * tanhf(tmp.x + tmp.y); - } else { - sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - } - sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); - - kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum); - - KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum; - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps])); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]); - const half2 val = h2exp(diff); - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val; - KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) { - const int k = k0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) { - half2 V_k[(D/2)/WARP_SIZE][2]; - half2 KQ_k[ncols/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i]; - V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i]; - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]); - VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]); - } - } - } - - __syncthreads(); - } - - float2 * dst2 = (float2 *) dst; - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); - kqsum_j = warp_reduce_sum((float)kqsum_j); - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { - const int i0 = i00 + threadIdx.x; - - half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; - if (gridDim.y == 1) { - dst_val /= __half2half2(kqsum_j); - } - dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val); - } - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); - } - } -#else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); - NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) -} - -template -void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); - } break; - default: { - GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } -} - -void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - const int32_t precision = KQV->op_params[3]; - GGML_ASSERT(precision == GGML_PREC_DEFAULT); - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (Q->ne[1] <= 16) { - constexpr int cols_per_block = 16; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 32; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } -} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cuh deleted file mode 100644 index ffc58784..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu deleted file mode 100644 index 23550cbb..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ /dev/null @@ -1,355 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" -#include "fattn-tile-f32.cuh" - -#define FATTN_KQ_STRIDE_TILE_F32 32 - -template // D == head size -#if !defined(GGML_USE_HIP) -__launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !defined(GGML_USE_HIP) -static __global__ void flash_attn_tile_ext_f32( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); - GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); - NO_DEVICE_CODE; - return; - } - - // In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - - __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32]; - - __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts. - float2 * KV_tmp2 = (float2 *) KV_tmp; - - float kqmax[ncols/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -FLT_MAX/2.0f; - } - float kqsum[ncols/nwarps] = {0.0f}; - - float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; - - // Convert Q to half2 and store in registers: - __shared__ float Q_f[ncols][D]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { - float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); - Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; - Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; - } - } - - __syncthreads(); - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) { - // Calculate KQ tile and keep track of new maximum KQ values: - - float kqmax_new[ncols/nwarps]; -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; - KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); - KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); - } - } - - __syncthreads(); - - float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}}; - -#pragma unroll - for (int k_KQ = 0; k_KQ < D; ++k_KQ) { - float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE]; - float Q_k[ncols/nwarps]; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - - K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ]; - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps]; - } - } - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - if (use_logit_softcap) { - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - } - - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - - kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - - KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - - float kqsum_add = 0.0f; -#pragma unroll - for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps]; - const float val = expf(diff); - kqsum_add += val; - KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val; - } - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) { - const int k = k0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; - KV_tmp2[k*(D/2) + i].x = __low2float(tmp); - KV_tmp2[k*(D/2) + i].y = __high2float(tmp); - } - } - - __syncthreads(); - -#pragma unroll - for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) { - float2 V_k[(D/2)/WARP_SIZE]; - float KQ_k[ncols/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i]; - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps]; - VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps]; - } - } - } - - __syncthreads(); - } - - float2 * dst2 = (float2 *) dst; - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - float kqsum_j = kqsum[j_VKQ_0/nwarps]; - kqsum_j = warp_reduce_sum(kqsum_j); - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { - const int i0 = i00 + threadIdx.x; - - float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; - if (gridDim.y == 1) { - dst_val.x /= kqsum_j; - dst_val.y /= kqsum_j; - } - dst2[j_dst_unrolled*(D/2) + i0] = dst_val; - } - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); - } - } -#else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); - GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} - -template -void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); - } break; - default: { - GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } -} - -void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (Q->ne[1] <= 16) { - constexpr int cols_per_block = 16; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 32; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } -} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cuh deleted file mode 100644 index b1c546c8..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu new file mode 100644 index 00000000..3a5806d9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu @@ -0,0 +1,45 @@ +#include "common.cuh" +#include "fattn-tile.cuh" +#include "fattn-wmma-f16.cuh" + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + switch (K->ne[0]) { + case 40: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst); + } break; + case 64: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst); + } break; + case 80: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst); + } break; + case 96: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst); + } break; + case 112: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst); + } break; + case 128: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); + } break; + case 256: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); + } break; + case 576: { + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); + } break; + default: { + GGML_ABORT("Unsupported head size"); + } break; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh new file mode 100644 index 00000000..2b60b3bb --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh @@ -0,0 +1,1206 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-wmma-f16.cuh" + +// nbatch_fa == number of KQ rows to process per iteration +// nbatch_K == number of K columns to load in parallel for KQ calculation + +// TODO optimize kernel parameters for FP16 NVIDIA (P100) +// TODO optimize kernel parameters for head sizes 40, 80, 96, 112 + +// The ROCm compiler cannot handle templating in __launch_bounds__. +// As a workaround, define a macro to package the kernel parameters as uint32_t: +#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads) <= 512, "bad nthreads"); \ + static_assert((occupancy) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K) <= 256, "bad nbatch_K"); \ + return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \ + } \ + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) + + return 0; +} + +static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA(cc)) { + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); + } + if (fast_fp16_available(cc)) { + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +} + +static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) { +#ifdef GGML_USE_HIP +#ifdef RDNA + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); +#endif // RDNA +#else +#ifdef FAST_FP16_AVAILABLE + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP +} + +static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1); +} + +// TODO: deduplicate with mma-f16 +template +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j); + const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; + + const half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + ggml_cuda_memcpy_1( + tile_KV + i*(J/2 + J_padding) + j, + !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + } + } + } + }; + // 1: max 64*16=512 bytes, 512 half + // 2: max 32*16=512 bytes, 256 half + // 3: max 16*16=256 bytes, 128 half + // 4: max 8*16=128 bytes, 64 half + // 5: max 4*16= 64 bytes, 32 half + // 6: max 2*16= 32 bytes, 16 half + // 7: max 1*16= 16 bytes, 8 half + static_assert(J % 8 == 0, "bad J"); + static_assert((J/2) % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<7>{}(load); +} + +template +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j); + const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); + + const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; + half2 tmp_h2[cpy_ne/2]; + ggml_cuda_memcpy_1( + tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + + float2 tmp_f2[cpy_ne/2]; +#pragma unroll + for (int l = 0; l < cpy_ne/2; ++l) { + tmp_f2[l] = __half22float2(tmp_h2[l]); + } + ggml_cuda_memcpy_1(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); + } + } + } + }; + // 1: max 32*16=512 bytes, 128 float + // 2: max 16*16=256 bytes, 64 float + // 3: max 8*16=128 bytes, 32 float + // 4: max 4*16= 64 bytes, 16 float + // 5: max 2*16= 32 bytes, 8 float + static_assert(J % 8 == 0, "bad J"); + static_assert(J % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<5>{}(load); +} + +// Function that performs a single iteration in for the KQ matrix multiplication: +template +static __device__ __forceinline__ void flash_attn_tile_iter_KQ( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int k_VKQ_0, + const int k_VKQ_sup, + const int k_KQ_0, + float * KQ_acc) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + flash_attn_tile_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE + static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { + half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + half2 Q_k[cpw][cpy_ne]; +#else + static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { + float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + float Q_k[cpw][cpy_ne]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#pragma unroll + for (int k = 0; k < cpy_ne; ++k) { + ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]); + } + } + } + } + + if (k_KQ_0 + nbatch_K < DKQ) { + __syncthreads(); // Sync not needed on last iteration. + } +} + +// Function that performs a single iteration of the main loop over up to nbatch_fa tokens. +template +static __device__ __forceinline__ void flash_attn_tile_iter( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ mask, + const float logit_softcap, + const float slope, + T_KQ * const KQ, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int stride_V2, + const int stride_mask, + float * const KQ_max, + float * const KQ_sum, + T_acc * const VKQ, + const int k_VKQ_0, + const int k_VKQ_max) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory. + // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][DVp][KQ_cs]. +#ifdef FAST_FP16_AVAILABLE + constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; +#else + constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; +#endif // FAST_FP16_AVAILABLE + static_assert(cpw % KQ_cs == 0, "bad KQ_cs"); + const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data + + float KQ_max_new[cpw]; +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_max_new[jc0] = KQ_max[jc0]; + } + + float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication. + + // KQ = K @ Q matrix multiplication: + constexpr int nbatch_K_last = DKQ % nbatch_K; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + if (nbatch_K_last > 0) { + constexpr int k_KQ_0 = DKQ - nbatch_K_last; + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + + // Apply logit softcap + mask, update KQ_max: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2; + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + if (use_logit_softcap) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + + if (!oob_check || i_KQ < k_VKQ_sup) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ? + slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; + + KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + } + + KQ_max_new[jc0] = warp_reduce_max(KQ_max_new[jc0]); + } + + if constexpr (np == 1) { + __syncthreads(); + } else { + static_assert(cpw == 1, "bad cpw"); + __shared__ float KQ_max_new_shared[nwarps]; + if (threadIdx.x == 0) { + KQ_max_new_shared[threadIdx.y] = KQ_max_new[0]; + } + __syncthreads(); + KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np]; + KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]); + } + + // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { +#ifdef FAST_FP16_AVAILABLE + half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#else + float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int jc1 = 0; jc1 < KQ_cs; ++jc1) { + const int jc = jc0 + jc1; + + const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]); + KQ_max[jc] = KQ_max_new[jc]; + + float KQ_sum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ? + expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; + KQ_sum_add += val; + tmp[i0/(np*warp_size)][jc1] = val; + } + KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + ggml_cuda_memcpy_1( + KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs, + tmp[i0/(np*warp_size)]); + } + } + + // VKQ = V @ KQ matrix multiplication: + static_assert(DV <= DKQ, "bad DV"); + static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K"); + constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K. + static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V"); + static_assert(nbatch_V % np == 0, "bad nbatch_V"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { + flash_attn_tile_load_tile + (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + half2 V_k[(DVp/2)/warp_size]; + half2 KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + half tmp[KQ_cs]; + ggml_cuda_memcpy_1( + &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); +#pragma unroll + for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) { + KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]); + } + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0]; + } + } + } +#else +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + float2 V_k[(DVp/2)/warp_size]; + float KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + ggml_cuda_memcpy_1( + &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0]; + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0]; + } + } + } +#endif // FAST_FP16_AVAILABLE + + __syncthreads(); + } +} + +template // D == head size +__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) +static __global__ void flash_attn_tile( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: + + if ( +#ifdef GGML_USE_WMMA_FATTN + (ncols2 != 1 && DV != 40 && DV != 512) || +#endif // GGML_USE_WMMA_FATTN + (use_logit_softcap && !(DV == 128 || DV == 256)) + ) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined"); + + constexpr int ncols = ncols1*ncols2; + constexpr int warp_size = 32; + constexpr int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size; + constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2); + constexpr int nbatch_K = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on. + + const int sequence = blockIdx.z / (ne02/ncols2); + const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape + + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr; + + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb21 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp. + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column. + static_assert(cpw == 1 || np == 1, "bad cpw / np"); + static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0"); + + constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size. + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel. + // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11. + // KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV). + // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications. + // VKQ == Accumulators in registers for the final VKQ result. +#ifdef FAST_FP16_AVAILABLE + __shared__ half2 Q_tmp[ncols * DKQ/2]; + __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV]; + __shared__ half KQ[ncols * nbatch_fa]; + half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#else + __shared__ float Q_tmp[ncols * DKQ]; + __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV]; + __shared__ float KQ[ncols * nbatch_fa]; + float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#endif // FAST_FP16_AVAILABLE + + float KQ_max[cpw]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + KQ_max[j0/nwarps] = -FLT_MAX/2.0f; + } + float KQ_sum[cpw] = {0.0f}; + + // Load Q data, convert to FP16 if fast: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { + if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { + float tmp_f[cpy_ne_D] = {0.0f}; + if (ncols1 == 1 || col_Q_0 + j < ne01) { + ggml_cuda_memcpy_1 + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); + } + +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp_f[i1] *= scale; + } + +#ifdef FAST_FP16_AVAILABLE + half2 tmp_h2[cpy_ne_D/2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { + tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); + } + ggml_cuda_memcpy_1( + &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)], + tmp_h2); +#else + ggml_cuda_memcpy_1( + &Q_tmp[jc* DKQ + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x* cpy_ne_D], + tmp_f); +#endif // FAST_FP16_AVAILABLE + } + } + } + + __syncthreads(); + + // Main loop over KV cache: + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + if (ncols2 == 1) { + // Branch with out-of-bounds checks. + int k_VKQ_0 = blockIdx.y*nbatch_fa; + while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + k_VKQ_0 += gridDim.y*nbatch_fa; + } + if (k_VKQ_0 < k_VKQ_max) { + constexpr bool oob_check = true; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + } + } else { + // Branch without out-of-bounds checks. + for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + } + } + +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_sum[jc0] = warp_reduce_sum(KQ_sum[jc0]); + } + + if constexpr (np > 1) { + static_assert(cpw == 1, "bad cpw"); + static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small"); + +#ifdef FAST_FP16_AVAILABLE + half2 * VKQ_combine = (half2 *) KV_tmp; +#else + float * VKQ_combine = (float *) KV_tmp; +#endif // FAST_FP16_AVAILABLE + float * KQ_sum_combine = (float *) Q_tmp; + + if (threadIdx.y % np != 0) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]); + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1( + &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size); + } +#endif // FAST_FP16_AVAILABLE + + if (threadIdx.x == 0) { + KQ_sum_combine[threadIdx.y] = KQ_sum[0]; + } + + return; + } + + __syncthreads(); + +#pragma unroll + for (int ip = 1; ip < np; ++ip) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + half2 tmp[cpy_ne_D]; + ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + VKQ[i0/warp_size + i1] += tmp[i1]; + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + float tmp[cpy_ne_D]; + ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + ((float *)VKQ)[i0/warp_size + i1] += tmp[i1]; + } + } +#endif // FAST_FP16_AVAILABLE + + KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip]; + } + } + + // Attention sink: adjust KQ max and sum only for the first of all parallel blocks: + if (sinks && blockIdx.y == 0) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + const float sink = ((const float *) sinks)[head0 + jc % ncols2]; + + float KQ_max_new_j = fmaxf(KQ_max[jc0], sink); + const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j); + KQ_max[jc0] = KQ_max_new_j; + + const float val = expf(sink - KQ_max[jc0]); + KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + // Write back results: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (ncols1 > 1 && col_Q_0 + j >= ne01) { + return; + } + + const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; + + const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + float2 tmp[cpy_ne_D]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); + tmp[i1].x *= scale; + tmp[i1].y *= scale; + } + if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) { + ggml_cuda_memcpy_1(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) { +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale; + } + ggml_cuda_memcpy_1( + &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D], + &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); + } + } +#endif // FAST_FP16_AVAILABLE + + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} + +template +static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; + + constexpr size_t nbytes_shared = 0; + +#ifdef GGML_USE_HIP + if constexpr (DV <= 128) { + if (Q->ne[1] > 32/ncols2) { + constexpr int cols_per_block = 64; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } +#endif // GGML_USE_HIP + +#ifndef GGML_USE_HIP + if constexpr (DV <= 256) +#endif // GGML_USE_HIP + { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 4) { + if (Q->ne[1] > 2/ncols2) { + constexpr int cols_per_block = 4; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 2) { + constexpr int cols_per_block = 2; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + GGML_ABORT("fatal error"); +} + +template +static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc); + const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; + const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + + if constexpr (DV == 512) { + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + } + + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("fatal error"); +} + +template +void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_ncols2(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_ncols2(ctx, dst); + } +} + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +#define DECL_FATTN_TILE_CASE(DKQ, DV) \ + template void ggml_cuda_flash_attn_ext_tile_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_TILE_CASE( 40, 40); +extern DECL_FATTN_TILE_CASE( 64, 64); +extern DECL_FATTN_TILE_CASE( 80, 80); +extern DECL_FATTN_TILE_CASE( 96, 96); +extern DECL_FATTN_TILE_CASE(112, 112); +extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh deleted file mode 100644 index b05f682c..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ /dev/null @@ -1,497 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" - -// Currenlty llvm with the amdgcn target dose not support unrolling loops -// that contain a break that can not be resolved at compile time. -#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wpass-failed" -#endif // __clang__ -template // D == head size -#ifndef GGML_USE_HIP -__launch_bounds__(D, 1) -#endif // GGML_USE_HIP -static __global__ void flash_attn_vec_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) - - // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { - NO_DEVICE_CODE; - return; - } -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - if (ncols > 1) { - NO_DEVICE_CODE; - return; - } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; - constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V); - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb03*sequence + nb02* head + nb01*ic0; - K += nb13*sequence + nb12*(head / gqa_ratio); - V += nb23*sequence + nb22*(head / gqa_ratio); - - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); - - __shared__ half KQ[ncols*D]; - half2 * KQ2 = (half2 *) KQ; - - half kqmax[ncols]; - half kqsum[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax[j] = -HALF_MAX_HALF; - kqsum[j] = 0.0f; - } - - __shared__ half kqmax_shared[ncols][WARP_SIZE]; - __shared__ half kqsum_shared[ncols][WARP_SIZE]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (threadIdx.y == 0) { - kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF; - kqsum_shared[j][threadIdx.x] = 0.0f; - } - } - - __shared__ half maskh_shared[ncols*D]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - maskh_shared[j*D + tid] = 0.0f; - } - - __syncthreads(); - - // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: - half2 Q_h2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; - half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; - if (Q_q8_1) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j0 + nwarps > ncols && j >= ncols) { - break; - } - - // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); - - // Set memory to zero if out of bounds: - if (ncols > 2 && ic0 + j >= ne01) { -#pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - tmp_q_i32[i] = 0; - } - if (threadIdx.x < D/QK8_1) { - tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); - } - continue; - } - - const float * Q_f = (const float *) (Q + j*nb01); -#pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { - quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); - -#pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; - Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; - } - } - - __syncthreads(); - } else { -#pragma unroll - for (int j = 0; j < ncols; ++j) { - const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); - Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } - } - } - - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -HALF_MAX_HALF; - } - __syncthreads(); - - half2 VKQ[ncols] = {{0.0f, 0.0f}}; - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - K += blockIdx.y*D * nb11; - V += blockIdx.y*D * nb21; - maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D, - // Increment pointers after each loop: - K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { - - // Calculate KQ tile and keep track of new maximum KQ values: - - if (mask) { -#pragma unroll - for (int j = 0; j < ncols; ++j) { - maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid]; - } - __syncthreads(); - } - - // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, - // see https://github.com/ggerganov/llama.cpp/pull/7061 . - // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). - half kqmax_new = kqmax[0]; - half kqmax_new_arr[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax_new_arr[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { - break; - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); - sum = warp_reduce_sum((float)sum); - - if (use_logit_softcap) { - sum = logit_softcap*tanhf(sum); - } - - sum += maskh_shared[j*D + i_KQ]; - - if (ncols == 1) { - kqmax_new = ggml_cuda_hmax(kqmax_new, sum); - } else { - kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum); - } - - if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; - } - } - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; - - if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = kqmax_new_j; - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = kqmax_shared[j][threadIdx.x]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); - kqmax[j] = kqmax_new_j; - - const half val = hexp(KQ[j*D + tid] - kqmax[j]); - kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; - - VKQ[j] *= __half2half2(KQ_max_scale); - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } - - half2 V_k; - reinterpret_cast(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid); - reinterpret_cast(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid); -#pragma unroll - for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; - } - } - - __syncthreads(); - } - - if (sinksf && blockIdx.y == 0) { - const half sink = __float2half(sinksf[head]); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = kqmax_shared[j][threadIdx.x]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); - kqmax[j] = kqmax_new_j; - - const half val = hexp(sink - kqmax[j]); - kqsum[j] = kqsum[j]*KQ_max_scale; - - if (tid == 0) { - kqsum[j] += val; - } - - VKQ[j] *= __half2half2(KQ_max_scale); - } - - __syncthreads(); - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqsum[j] = warp_reduce_sum((float)kqsum[j]); - if (threadIdx.x == 0) { - kqsum_shared[j][threadIdx.y] = kqsum[j]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ncols > 2 && ic0 + j_VKQ >= ne01) { - break; - } - - kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; - kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]); - - half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); - if (gridDim.y == 1) { - dst_val /= kqsum[j_VKQ]; - } - dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; - } - - if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); - } -#else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); - GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); - NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) -} -#ifdef __clang__ -#pragma clang diagnostic pop -#endif // __clang__ - -template -void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; - constexpr bool need_f16_K = D != 128; - constexpr bool need_f16_V = D != 128 && D != 64; - constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); -} - -template -void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - - const int32_t precision = KQV->op_params[3]; - GGML_ASSERT(precision == GGML_PREC_DEFAULT); - - GGML_ASSERT(K->type == type_K); - GGML_ASSERT(V->type == type_V); - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - - if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { - constexpr int cols_per_block = 1; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } - return; - } - - if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } - return; - } - - if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 8; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } -} - -#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \ - template void ggml_cuda_flash_attn_ext_vec_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ - -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); - -extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh deleted file mode 100644 index d6d0bfb7..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ /dev/null @@ -1,490 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" - -// Currenlty llvm with the amdgcn target dose not support unrolling loops -// that contain a break that can not be resolved at compile time. -#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wpass-failed" -#endif // __clang__ -template // D == head size -#ifndef GGML_USE_HIP -__launch_bounds__(D, 1) -#endif // GGML_USE_HIP -static __global__ void flash_attn_vec_ext_f32( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); - NO_DEVICE_CODE; - return; - } -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - if (ncols > 1) { - NO_DEVICE_CODE; - return; - } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; - constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V); - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb03*sequence + nb02* head + nb01*ic0; - K += nb13*sequence + nb12*(head / gqa_ratio); - V += nb23*sequence + nb22*(head / gqa_ratio); - - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); - - __shared__ float KQ[ncols*D]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -FLT_MAX/2.0f; - } - - float kqmax[ncols]; - float kqsum[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax[j] = -FLT_MAX/2.0f; - kqsum[j] = 0.0f; - } - - __shared__ float kqmax_shared[ncols][WARP_SIZE]; - __shared__ float kqsum_shared[ncols][WARP_SIZE]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (threadIdx.y == 0) { - kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f; - kqsum_shared[j][threadIdx.x] = 0.0f; - } - } - - __shared__ float maskf_shared[ncols*D]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - maskf_shared[j*D + tid] = 0.0f; - } - - __syncthreads(); - - // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: - float2 Q_f2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)]; - float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; - if (Q_q8_1) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j0 + nwarps > ncols && j >= ncols) { - break; - } - - // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); - - // Set memory to zero if out of bounds: - if (ncols > 2 && ic0 + j >= ne01) { -#pragma unroll - for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - tmp_q_i32[i] = 0; - } - if (threadIdx.x < D/QK8_1) { - tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); - } - continue; - } - - const float * Q_f = (const float *) (Q + j*nb01); -#pragma unroll - for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { - quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); - -#pragma unroll - for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; - Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; - } - } - - __syncthreads(); - } else { -#pragma unroll - for (int j = 0; j < ncols; ++j) { - const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); - Q_f2[j][i0/WARP_SIZE].x *= scale; - Q_f2[j][i0/WARP_SIZE].y *= scale; - } - } - } - - float VKQ[ncols] = {0.0f}; - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - K += blockIdx.y*D * nb11; - V += blockIdx.y*D * nb21; - maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D, - // Increment pointers after each loop: - K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { - - // Calculate KQ tile and keep track of new maximum KQ values: - - if (mask) { -#pragma unroll - for (int j = 0; j < ncols; ++j) { - maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]); - } - __syncthreads(); - } - - float kqmax_new_arr[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax_new_arr[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { - break; - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); - sum = warp_reduce_sum(sum); - - if (use_logit_softcap) { - sum = logit_softcap*tanhf(sum); - } - - sum += maskf_shared[j*D + i_KQ]; - - kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); - - if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; - } - } - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - float kqmax_new_j = kqmax_new_arr[j]; - - if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = kqmax_new_j; - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - float kqmax_new_j = kqmax_shared[j][threadIdx.x]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); - kqmax[j] = kqmax_new_j; - - const float val = expf(KQ[j*D + tid] - kqmax[j]); - kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; - - VKQ[j] *= KQ_max_scale; - } - - __syncthreads(); - -#pragma unroll - for (int k = 0; k < D; ++k) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) { - break; - } - - const float V_ki = dequantize_1_v(V + k*nb21, tid); -#pragma unroll - for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_ki*KQ[j*D + k]; - } - } - - __syncthreads(); - } - - if (sinksf && blockIdx.y == 0) { - const float sink = sinksf[head]; - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - float kqmax_new_j = kqmax_shared[j][threadIdx.x]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); - kqmax[j] = kqmax_new_j; - - const float val = expf(sink - kqmax[j]); - kqsum[j] = kqsum[j]*KQ_max_scale; - - if (tid == 0) { - kqsum[j] += val; - } - - VKQ[j] *= KQ_max_scale; - } - - __syncthreads(); - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqsum[j] = warp_reduce_sum(kqsum[j]); - if (threadIdx.x == 0) { - kqsum_shared[j][threadIdx.y] = kqsum[j]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ncols > 2 && ic0 + j_VKQ >= ne01) { - break; - } - - kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; - kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); - - float dst_val = VKQ[j_VKQ]; - if (gridDim.y == 1) { - dst_val /= kqsum[j_VKQ]; - } - dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; - } - - if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); - } -#else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} -#ifdef __clang__ -#pragma clang diagnostic pop -#endif // __clang__ - -template -void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; - constexpr bool need_f16_K = D != 128; - constexpr bool need_f16_V = D != 128 && D != 64; - constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); -} - -template -void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - - GGML_ASSERT(K->type == type_K); - GGML_ASSERT(V->type == type_V); - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - - if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { - constexpr int cols_per_block = 1; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } - return; - } - - if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } - return; - } - - if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 8; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } -} - -#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \ - template void ggml_cuda_flash_attn_ext_vec_f32_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ - -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); - -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); - -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); - -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); - -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); - -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); - -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); - -extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh new file mode 100644 index 00000000..89ab0f16 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh @@ -0,0 +1,591 @@ +#include "common.cuh" +#include "fattn-common.cuh" + +static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) { + return 128; + GGML_UNUSED(cc); +} + +static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { + return 128; +} + +// Currenlty llvm with the amdgcn target dose not support unrolling loops +// that contain a break that can not be resolved at compile time. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template // D == head size +__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1) +static __global__ void flash_attn_ext_vec( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + +#ifdef GGML_USE_HIP +#ifdef RDNA + constexpr int nthreads_KQ_q = 2; +#else + constexpr int nthreads_KQ_q = 4; +#endif // RDNA + constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32); +#else + constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32); + constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32); +#endif // GGML_USE_HIP + + constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); + constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + + static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); + static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); + + constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; + + constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; +#ifdef FAST_FP16_AVAILABLE + constexpr dequantize_V_t dequantize_V = get_dequantize_V(); +#else + constexpr dequantize_V_t dequantize_V = get_dequantize_V(); +#endif // FAST_FP16_AVAILABLE + + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); + + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = nthreads / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < nthreads); + + constexpr int ne_KQ = ncols*D; + constexpr int ne_combine = nwarps*V_cols_per_iter*D; +#ifdef FAST_FP16_AVAILABLE + half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; + __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; +#else + float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; + __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; +#endif // FAST_FP16_AVAILABLE + + float KQ_max[ncols]; + float KQ_sum[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_max[j] = -FLT_MAX/2.0f; + KQ_sum[j] = 0.0f; + } + + // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: +#ifdef FAST_FP16_AVAILABLE + half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. +#else + float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. +#endif // FAST_FP16_AVAILABLE + int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + if constexpr (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 1 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) { + tmp_q_i32[i] = 0; + } + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); + } + } else { + const float * Q_f = (const float *) (Q + j*nb01); + constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE; +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) { + quantize_q8_1_to_shared + (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1); + } + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) { + const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ); + + Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i]; + Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#ifdef FAST_FP16_AVAILABLE + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_j = (const float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { + const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; + + float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; + if (ncols == 1 || ic0 + j < ne01) { + ggml_cuda_memcpy_1(tmp, &Q_j[i]); + ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); + } +#pragma unroll + for (int i1 = 0; i1 < cpy_ne; ++i1) { + Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y); + } + } +#pragma unroll + for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { + Q_reg[j][k] *= scale_h2; + } + } +#else +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_j = (const float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { + const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; + if (ncols == 1 || ic0 + j < ne01) { + ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); + ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); + } + } +#pragma unroll + for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { + Q_reg[j][k].x *= scale; + Q_reg[j][k].y *= scale; + } + } +#endif // FAST_FP16_AVAILABLE + } + + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + K += blockIdx.y*nthreads * nb11; + V += blockIdx.y*nthreads * nb21; + maskh += blockIdx.y*nthreads; + for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads, + // Increment pointers after each loop: + K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) { + + // Calculate KQ tile and keep track of new maximum KQ values: + float KQ_reg[ncols]; // KQ in registers. + + float KQ_max_new[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_max_new[j] = KQ_max[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) { + const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); + } + + if (mask) { + sum += slope*__half2float(maskh[j*ne11 + i_KQ]); + } + + KQ_max_new[j] = fmaxf(KQ_max_new[j], sum); + + if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) { + KQ_reg[j] = sum; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) { + KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE)); + } + const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]); + KQ_max[j] = KQ_max_new[j]; + + KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]); + KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; + KQ[j*nthreads + tid] = KQ_reg[j]; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; + VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + +#ifndef GGML_USE_HIP + __syncwarp(); +#endif // GGML_USE_HIP + +#pragma unroll + for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) { + const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V); + +#ifdef FAST_FP16_AVAILABLE + half2 KQ_k[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_k[j] = __half2half2(KQ[j*nthreads + k]); + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + half2 tmp[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j]; + } + } + } +#else + float KQ_k[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_k[j] = KQ[j*nthreads + k]; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + float2 tmp[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j]; + } + } + } +#endif // FAST_FP16_AVAILABLE + } + } + + if (sinks && blockIdx.y == 0) { + const float sink = ((const float *) sinks)[head]; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + const float kqmax_new_j = fmaxf(sink, KQ_max[j]); + const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j); + KQ_max[j] = kqmax_new_j; + + KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f); + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; + VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + __shared__ float KQ_max_shared[ncols][WARP_SIZE]; + __shared__ float KQ_sum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f; + KQ_sum_shared[j][threadIdx.x] = 0.0f; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + KQ_max_shared[j][threadIdx.y] = KQ_max[j]; + } + } + __syncthreads(); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ncols > 1 && ic0 + j_VKQ >= ne01) { + break; + } + + float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x]; + kqmax_new = warp_reduce_max(kqmax_new); + const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new); + KQ_max[j_VKQ] = kqmax_new; + +#ifdef FAST_FP16_AVAILABLE + half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) + + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); + + const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2); + + ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); + } +#else + float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) + + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); + +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale; + VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2); + + ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); + ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]); + } +#endif // FAST_FP16_AVAILABLE + + KQ_sum[j_VKQ] *= kqmax_scale; + KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); + if (threadIdx.x == 0) { + KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ]; + } + + __syncthreads(); + + if (nthreads <= D || tid < D) { + KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x]; + KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += nthreads) { + float dst_val = 0; +#pragma unroll + for (int w = 0; w < nwarps; ++w) { +#pragma unroll + for (int v = 0; v < V_cols_per_iter; ++v) { + dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]); + } + } + if (gridDim.y == 1) { + dst_val /= KQ_sum[j_VKQ]; + } + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; + } + } + + if (j_VKQ < ncols-1) { + __syncthreads(); + } + + } + + if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) { + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + +template +void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); + const int nwarps = nthreads / WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_ext_vec; + constexpr bool need_f16_K = false; + constexpr bool need_f16_V = false; + constexpr size_t nbytes_shared = 0; + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); +} + +template +void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K->type == type_K); + GGML_ASSERT(V->type == type_V); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_case_impl(ctx, dst); + } + return; + } + + constexpr int cols_per_block = 2; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_case_impl(ctx, dst); + } +} + +#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) + +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) + +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 93d4d810..6c90d6d5 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -6,20 +6,19 @@ #include "fattn-common.cuh" #include "fattn-wmma-f16.cuh" -#ifdef FP16_MMA_AVAILABLE +#ifdef GGML_USE_WMMA_FATTN #if !defined(GGML_USE_HIP) #include -#ifdef GGML_USE_MUSA +#if defined(GGML_USE_MUSA) namespace wmma = mtmusa::wmma; #else // GGML_USE_MUSA namespace wmma = nvcuda::wmma; #endif // GGML_USE_MUSA -#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) -#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers +#elif defined(GGML_USE_HIP) #include namespace wmma = rocwmma; #endif // !defined(GGML_USE_HIP) -#endif // FP16_MMA_AVAILABLE +#endif // GGML_USE_WMMA_FATTN // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template @@ -46,7 +45,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) +#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -82,11 +81,12 @@ static __global__ void flash_attn_ext_f16( const int sequence = blockIdx.z / ne02; const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const half2 * mask2 = (const half2 *) maskh; + const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const half2 * mask2 = (const half2 *) maskh; + const float * sinksf = (const float *) sinks; const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -381,6 +381,53 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); } + // Apply attention sinks + if (sinksf && blockIdx.y == 0) { + const float sinkf = sinksf[head]; + const half sinkh = __float2half(sinkf); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (std::is_same::value) { + float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf); + + const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new); + KQ_max_f[j0/nwarps] = kqmax_new; + + KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]); + + const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + if (i0 + warp_size > D/2 && i >= D/2) break; + VKQ2[j*(D_padded/2) + i] *= scale_h2; + } + } else { + half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]); + half kqmax_new = fmaxf(kqmax_old, sinkh); + KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new); + + const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new); + const half2 KQ_max_scale = __half2half2(KQ_max_scale_h); + + KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale; + const half val = hexp(sinkh - kqmax_new); + KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val); + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + if (i0 + warp_size > D/2 && i >= D/2) break; + VKQ2[j*(D_padded/2) + i] *= KQ_max_scale; + } + } + } + + __syncthreads(); + } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j_VKQ = j0 + threadIdx.y; @@ -424,18 +471,17 @@ static __global__ void flash_attn_ext_f16( dst_meta[j_dst_unrolled] = dst_meta_val; } #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31); - GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) } constexpr int get_max_power_of_2(int x) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index beeea95e..7235f1b7 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,3 +1,51 @@ +#pragma once + #include "common.cuh" +#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#define GGML_USE_WMMA_FATTN +#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) + +#if defined(GGML_HIP_ROCWMMA_FATTN) +#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) +#define GGML_USE_WMMA_FATTN +#elif defined(CDNA) +#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance" +#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) +#if defined(RDNA3) +#define GGML_USE_WMMA_FATTN +#endif // defined(RDNA3) +#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 +#define GGML_USE_WMMA_FATTN +#elif defined(RDNA4) +#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" +#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + +// WMMA flash attention requires FP16 matrix instructions to be available for ggml code. +static bool ggml_cuda_should_use_wmma_fattn(const int cc) { +#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) + return false; +#else + if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) || + GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) { + return true; + } else if (GGML_CUDA_CC_IS_CDNA(cc)){ +#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) + } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { +#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 + } else { + return false; + } +#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) +} + void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu index 6c1185de..fe970ada 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu @@ -1,10 +1,8 @@ #include "common.cuh" #include "fattn-common.cuh" #include "fattn-mma-f16.cuh" -#include "fattn-tile-f16.cuh" -#include "fattn-tile-f32.cuh" -#include "fattn-vec-f16.cuh" -#include "fattn-vec-f32.cuh" +#include "fattn-tile.cuh" +#include "fattn-vec.cuh" #include "fattn-wmma-f16.cuh" #include "fattn.cuh" @@ -118,232 +116,230 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg } } -#define FATTN_VEC_F16_CASE(D, type_K, type_V) \ - if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ - return; \ - } \ +#define FATTN_VEC_CASE(D, type_K, type_V) \ + if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ + ggml_cuda_flash_attn_ext_vec_case(ctx, dst); \ + return; \ + } \ -static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \ + FATTN_VEC_CASE( 64, type_K, type_V) \ + FATTN_VEC_CASE(128, type_K, type_V) \ + FATTN_VEC_CASE(256, type_K, type_V) \ + +static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_tensor * Q = dst->src[0]; ggml_tensor * K = dst->src[1]; ggml_tensor * V = dst->src[2]; #ifdef GGML_CUDA_FA_ALL_QUANTS - FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0) - FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1) - FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0) - FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1) - FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0) - FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 ) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) - FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0) - - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16) - FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) - - FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #else - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) - - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) - - FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) - FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) - FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + GGML_ABORT("fatal error"); } -#define FATTN_VEC_F32_CASE(D, type_K, type_V) \ - if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_f32_case(ctx, dst); \ - return; \ - } \ +// Best FlashAttention kernel for a specific GPU: +enum best_fattn_kernel { + BEST_FATTN_KERNEL_NONE = 0, + BEST_FATTN_KERNEL_TILE = 200, + BEST_FATTN_KERNEL_VEC = 100, + BEST_FATTN_KERNEL_WMMA_F16 = 300, + BEST_FATTN_KERNEL_MMA_F16 = 400, +}; -static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * Q = dst->src[0]; - ggml_tensor * K = dst->src[1]; - ggml_tensor * V = dst->src[2]; +static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) { +#ifndef FLASH_ATTN_AVAILABLE + GGML_UNUSED(device); GGML_UNUSED(dst); + return BEST_FATTN_KERNEL_NONE; +#endif// FLASH_ATTN_AVAILABLE -#ifdef GGML_CUDA_FA_ALL_QUANTS - FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0) - FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1) - FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0) - FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1) - FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0) - FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) - - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0) - - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) - FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1) - - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0) - - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) - FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1) - - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) - FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0) - - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16) - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16) - FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) - - FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) -#else - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) - - FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) - - FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) - FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) - FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) -#endif // GGML_CUDA_FA_ALL_QUANTS - - on_no_fattn_vec_case(Q->ne[0]); -} - -void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; const ggml_tensor * mask = dst->src[3]; - const ggml_tensor * sinks = dst->src[4]; - ggml_cuda_set_device(ctx.device); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; - const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); - // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS] - if (sinks && !fp16_mma_available(cc)) { - if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - } - return; - } + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); -#if defined(GGML_HIP_ROCWMMA_FATTN) - if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) { - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); - return; - } -#endif // defined(GGML_HIP_ROCWMMA_FATTN) + // The effective batch size for the kernel can be increased by gqa_ratio. + // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, + const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; - if (!fast_fp16_available(cc)) { - if (Q->ne[1] <= 8 || Q->ne[0] == 256) { - ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); - } - return; - } + const int cc = ggml_cuda_info().devices[device].cc; - if (!fp16_mma_available(cc)) { - if (prec == GGML_PREC_DEFAULT) { - if (Q->ne[1] <= 8 || Q->ne[0] == 256) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + switch (K->ne[0]) { + case 40: + case 64: + case 80: + case 96: + case 128: + case 112: + case 256: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; } - } else { - if (Q->ne[1] <= 8 || Q->ne[0] == 256) { - ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + break; + case 576: + if (V->ne[0] != 512) { + return BEST_FATTN_KERNEL_NONE; + } + if (!gqa_opt_applies || gqa_ratio % 16 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; + default: + return BEST_FATTN_KERNEL_NONE; + } + +#ifndef GGML_CUDA_FA_ALL_QUANTS + if (K->type != V->type) { + return BEST_FATTN_KERNEL_NONE; + } +#endif // GGML_CUDA_FA_ALL_QUANTS + + switch (K->type) { + case GGML_TYPE_F16: + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: +#ifndef GGML_CUDA_FA_ALL_QUANTS + return BEST_FATTN_KERNEL_NONE; +#endif // GGML_CUDA_FA_ALL_QUANTS + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + break; + default: + return BEST_FATTN_KERNEL_NONE; + } + + if (mask && mask->ne[2] != 1) { + return BEST_FATTN_KERNEL_NONE; + } + + // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + + // If Turing tensor cores available, use them: + if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { + if (can_use_vector_kernel) { + if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { + return BEST_FATTN_KERNEL_VEC; + } } else { - ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } else { + if (Q->ne[1] == 1) { + return BEST_FATTN_KERNEL_VEC; + } + } + } + if (!gqa_opt_applies && Q->ne[1] == 1) { + return BEST_FATTN_KERNEL_VEC; } } - return; + + return BEST_FATTN_KERNEL_MMA_F16; } - const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations - const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; - const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192); - const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion && - (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000); - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; - if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { - if (prec == GGML_PREC_DEFAULT) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + // Use the WMMA kernel if possible: + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) { + if (can_use_vector_kernel && Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; } - return; + return BEST_FATTN_KERNEL_WMMA_F16; } - // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: - if (fp16_mma_available(cc) && !turing_mma_available(cc)) { - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); - return; + // If there are no tensor cores available, use the generic tile kernel: + if (can_use_vector_kernel) { + if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } } - - ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); + return BEST_FATTN_KERNEL_TILE; +} + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_set_device(ctx.device); + switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { + case BEST_FATTN_KERNEL_NONE: + GGML_ABORT("fatal error"); + case BEST_FATTN_KERNEL_TILE: + ggml_cuda_flash_attn_ext_tile(ctx, dst); + break; + case BEST_FATTN_KERNEL_VEC: + ggml_cuda_flash_attn_ext_vec(ctx, dst); + break; + case BEST_FATTN_KERNEL_WMMA_F16: + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + break; + case BEST_FATTN_KERNEL_MMA_F16: + ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); + break; + } +} + +bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) { + return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cuh index ad3ca7a8..78705d59 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cuh @@ -1,3 +1,5 @@ #include "common.cuh" void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu b/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu index f77b2629..2fab3324 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu @@ -1,68 +1,71 @@ #include "getrows.cuh" #include "dequantize.cuh" +#include "convert.cuh" template static __global__ void k_get_rows( const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2; - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = z / ne12; // TODO fastdiv + const int i12 = z % ne12; - if (i00 >= ne00) { - return; + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; + + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + float2 v; + dequantize_kernel(src0_row, ib, iqs, v); + + dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); + dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); + } } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; - - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index - const int iybs = i00 - i00%qk; // dst block start index - const int y_offset = qr == 1 ? 1 : qk/2; - - // dequantize - dfloat2 v; - dequantize_kernel(src0_row, ib, iqs, v); - - dst_row[iybs + iqs + 0] = float(v.x); - dst_row[iybs + iqs + y_offset] = float(v.y); } template static __global__ void k_get_rows_float( const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i00 = blockIdx.y * blockDim.x + threadIdx.x; - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = z / ne12; // TODO fastdiv + const int i12 = z % ne12; - if (i00 >= ne00) { - return; + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); + + dst_row[i00] = ggml_cuda_cast(src0_row[i00]); + } } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); - - dst_row[i00] = float(src0_row[i00]); } template @@ -97,7 +100,7 @@ static void get_rows_cuda_q( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(ne10, block_num_y, ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX)); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -115,7 +118,7 @@ static void get_rows_cuda_q( k_get_rows<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); @@ -130,7 +133,7 @@ static void get_rows_cuda_float( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; - const dim3 block_nums(ne10, block_num_y, ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX)); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -146,7 +149,7 @@ static void get_rows_cuda_float( k_get_rows_float<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index e43fde52..87941f87 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -12,6 +12,7 @@ #include "ggml-cuda/clamp.cuh" #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" +#include "ggml-cuda/conv2d.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/conv2d-transpose.cuh" #include "ggml-cuda/convert.cuh" @@ -28,6 +29,7 @@ #include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" +#include "ggml-cuda/opt-step-sgd.cuh" #include "ggml-cuda/out-prod.cuh" #include "ggml-cuda/pad.cuh" #include "ggml-cuda/pool2d.cuh" @@ -43,11 +45,13 @@ #include "ggml-cuda/sumrows.cuh" #include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" +#include "ggml-cuda/topk-moe.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" #include "ggml-cuda/set-rows.cuh" +#include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml.h" #include @@ -230,30 +234,6 @@ static std::string ggml_cuda_parse_uuid(cudaDeviceProp prop, int device_num) { } static ggml_cuda_device_info ggml_cuda_init() { -#if defined(GGML_USE_HIP) - // Workaround for a rocBLAS bug when using multiple graphics cards: - // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346 - { - int major_version = 0; - size_t version_length = 0; - if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) { - std::vector version(version_length+1, '\0'); - if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) { - version.resize(::strlen(version.data())); - int parsed_value = 0; - if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) { - major_version = parsed_value; - } - } - } - if (major_version < 4) { - GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n"); - rocblas_initialize(); - CUDA_CHECK(cudaDeviceSynchronize()); - } - } -#endif - ggml_cuda_device_info info = {}; cudaError_t err = cudaGetDeviceCount(&info.device_count); @@ -276,9 +256,21 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__); #endif // GGML_CUDA_FORCE_CUBLAS GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); + + std::vector> turing_devices_without_mma; for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; +#if defined(GGML_USE_HIP) + if (std::getenv("GGML_CUDA_INIT") != NULL) { + GGML_LOG_INFO("%s: initializing rocBLAS on device %d\n", __func__, id); + CUDA_CHECK(cudaSetDevice(id)); + // rocblas_initialize will SIGABRT if the GPU isn't supported + rocblas_initialize(); + GGML_LOG_INFO("%s: rocBLAS initialized on device %d\n", __func__, id); + } +#endif + #if defined(GGML_USE_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); @@ -299,7 +291,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; - info.devices[id].integrated = prop.integrated; + info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034) info.devices[id].nsm = prop.multiProcessorCount; info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].warp_size = prop.warpSize; @@ -332,10 +324,42 @@ static ggml_cuda_device_info ggml_cuda_init() { #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; +#ifdef __CUDA_ARCH_LIST__ + if (std::getenv("GGML_CUDA_INIT") != NULL) { + GGML_ASSERT(ggml_cuda_has_arch(info.devices[id].cc) && "ggml was not compiled with support for this arch"); + } +#endif // defined(__CUDA_ARCH_LIST__) GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", ggml_cuda_parse_uuid(prop, id).c_str()); -#endif // defined(GGML_USE_HIP) + std::string device_name(prop.name); + if (device_name == "NVIDIA GeForce MX450") { + turing_devices_without_mma.push_back({ id, device_name }); + } else if (device_name == "NVIDIA GeForce MX550") { + turing_devices_without_mma.push_back({ id, device_name }); + } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") { + turing_devices_without_mma.push_back({ id, device_name }); + } + + // Temporary performance fix: + // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls. + // TODO: Check for future drivers the default scheduling strategy and + // remove this call again when cudaDeviceScheduleSpin is default. + if (prop.major == 12 && prop.minor == 1) { + CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin)); + } + +#endif // defined(GGML_USE_HIP) + } + + if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) { + GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n"); + for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) { + GGML_LOG_INFO( + " Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str()); + } + GGML_LOG_INFO( + "Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n"); } for (int id = 0; id < info.device_count; ++id) { @@ -355,6 +379,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { // #define DEBUG_CUDA_MALLOC +#define CUDA_ALIGNMENT 128 + // buffer pool for cuda (legacy) struct ggml_cuda_pool_leg : public ggml_cuda_pool { static const int MAX_BUFFERS = 256; @@ -367,9 +393,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {}; size_t pool_size = 0; + bool allocate = true; + size_t last_alloc = 0; - explicit ggml_cuda_pool_leg(int device) : - device(device) { + explicit ggml_cuda_pool_leg(int device, bool alloc) : + device(device), + allocate(alloc) { } ~ggml_cuda_pool_leg() { @@ -377,7 +406,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { - CUDA_CHECK(cudaFree(b.ptr)); + if (allocate) { + CUDA_CHECK(cudaFree(b.ptr)); + } pool_size -= b.size; } } @@ -425,8 +456,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { void * ptr; size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); - ggml_cuda_set_device(device); - CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); + if (allocate) { + ggml_cuda_set_device(device); + if (ggml_cuda_device_malloc(&ptr, look_ahead_size, device) != cudaSuccess) { + last_alloc = look_ahead_size; + throw std::bad_alloc(); + } + } else { + ptr = (void *)CUDA_ALIGNMENT; + } *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC @@ -446,10 +484,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } } GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n"); - ggml_cuda_set_device(device); - CUDA_CHECK(cudaFree(ptr)); + if (allocate) { + ggml_cuda_set_device(device); + CUDA_CHECK(cudaFree(ptr)); + } pool_size -= size; } + + bool alloc_memory() override { + return allocate; + } + + size_t alloc_size() override { + return pool_size + last_alloc; + } }; // pool with virtual memory @@ -461,18 +509,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { CUdeviceptr pool_addr = 0; size_t pool_used = 0; size_t pool_size = 0; + bool allocate = true; + size_t last_alloc = 0; size_t granularity; #if defined(GGML_USE_HIP) std::vector> mappings; #endif - explicit ggml_cuda_pool_vmm(int device) : + explicit ggml_cuda_pool_vmm(int device, bool alloc) : device(device), - granularity(ggml_cuda_info().devices[device].vmm_granularity) { + granularity(ggml_cuda_info().devices[device].vmm_granularity), + allocate(alloc) { + if (!allocate) { + pool_addr = (CUdeviceptr)CUDA_ALIGNMENT; + } } ~ggml_cuda_pool_vmm() { - if (pool_addr != 0) { + if (pool_addr != 0 && allocate) { #if defined(GGML_USE_HIP) // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 for (std::pair & mapping : mappings) { @@ -499,36 +553,50 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE); - // allocate more physical memory - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = device; - CUmemGenericAllocationHandle handle; - CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0)); + if (allocate) { + // allocate more physical memory + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + CUmemGenericAllocationHandle handle; + if (cuMemCreate(&handle, reserve_size, &prop, 0) != CUDA_SUCCESS) { + last_alloc = reserve_size; + throw std::bad_alloc(); + } - // reserve virtual address space (if not already reserved) - if (pool_addr == 0) { - CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); + // reserve virtual address space (if not already reserved) + if (pool_addr == 0) { + CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); + } + + // map at the end of the pool + CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); + if (cuMemMap(start_ptr, reserve_size, 0, handle, 0) != CUDA_SUCCESS) { + last_alloc = reserve_size; + CU_CHECK(cuMemRelease(handle)); + throw std::bad_alloc(); + } + + // the memory allocation handle is no longer needed after mapping + CU_CHECK(cuMemRelease(handle)); + + // set access + CUmemAccessDesc access = {}; + access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access.location.id = device; + access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + if (cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1) != CUDA_SUCCESS) { + CU_CHECK(cuMemUnmap(start_ptr, reserve_size)); + last_alloc = reserve_size; + throw std::bad_alloc(); + } + + #if defined(GGML_USE_HIP) + mappings.push_back({start_ptr, reserve_size}); + #endif } - // map at the end of the pool - CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); - CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0)); -#if defined(GGML_USE_HIP) - mappings.push_back({start_ptr, reserve_size}); -#endif - - // the memory allocation handle is no longer needed after mapping - CU_CHECK(cuMemRelease(handle)); - - // set access - CUmemAccessDesc access = {}; - access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - access.location.id = device; - access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); - // add to the pool pool_size += reserve_size; @@ -560,16 +628,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { // all deallocations must be in reverse order of the allocations GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); } + + bool alloc_memory() override { + return allocate; + } + + size_t alloc_size() override { + return pool_size + last_alloc; + } }; #endif // defined(GGML_USE_VMM) -std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { +std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device, bool alloc) { #if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { - return std::unique_ptr(new ggml_cuda_pool_vmm(device)); + return std::unique_ptr(new ggml_cuda_pool_vmm(device, alloc)); } #endif // defined(GGML_USE_VMM) - return std::unique_ptr(new ggml_cuda_pool_leg(device)); + return std::unique_ptr(new ggml_cuda_pool_leg(device, alloc)); } // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error @@ -753,11 +829,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac } static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; + return CUDA_ALIGNMENT; GGML_UNUSED(buft); } +static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_noalloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; + + void * dev_ptr = (void *)ggml_backend_cuda_buffer_type_get_alignment(buft); + ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); + + return ggml_backend_buffer_init(buft, {}, ctx, size); +} + static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { size_t size = ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; @@ -781,6 +866,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, /* .is_host = */ NULL, + /* .noalloc_buffer = */ ggml_backend_cuda_buffer_type_noalloc_buffer, }; ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { @@ -1406,9 +1492,7 @@ static void ggml_cuda_op_mul_mat_cublas( &beta, dst_dd_i, ldc)); } - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddq_i); - GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size); } static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { @@ -2088,7 +2172,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor const int cc = ggml_cuda_info().devices[id].cc; const int warp_size = ggml_cuda_info().devices[id].warp_size; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]); + use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } @@ -2096,7 +2180,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor const int cc = ggml_cuda_info().devices[ctx.device].cc; const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]); + use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } @@ -2167,6 +2251,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst); return; } + + if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) { + ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst); + return; + } } cudaStream_t stream = ctx.stream(); @@ -2386,6 +2475,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_ELU: ggml_cuda_op_elu(ctx, dst); break; + case GGML_UNARY_OP_XIELU: + ggml_cuda_op_xielu(ctx, dst); + break; default: return false; } @@ -2432,6 +2524,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_PAD: ggml_cuda_op_pad(ctx, dst); break; + case GGML_OP_PAD_REFLECT_1D: + ggml_cuda_op_pad_reflect_1d(ctx, dst); + break; case GGML_OP_ARANGE: ggml_cuda_op_arange(ctx, dst); break; @@ -2507,6 +2602,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; + case GGML_OP_IM2COL_3D: + ggml_cuda_op_im2col_3d(ctx, dst); + break; + case GGML_OP_CONV_2D: + ggml_cuda_op_conv2d(ctx, dst); + break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); break; @@ -2558,6 +2659,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_OPT_STEP_ADAMW: ggml_cuda_opt_step_adamw(ctx, dst); break; + case GGML_OP_OPT_STEP_SGD: + ggml_cuda_opt_step_sgd(ctx, dst); + break; default: return false; } @@ -2692,6 +2796,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; + const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out"; + const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d"; for (int i = 0; i < cgraph->n_nodes; i++) { @@ -2727,7 +2833,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) { + strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && + strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && + strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation // by means of matching node names. See // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and @@ -2884,13 +2992,56 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, GGML_ASSERT(unary_ops.size() == num_unary); #endif + //TODO: remove special case once ggml_can_fuse can handle empty nodes + std::initializer_list topk_moe_ops = ggml_cuda_topk_moe_ops(false); + std::initializer_list topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true); + + if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) { + + if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) { + return false; + } + + for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) { + if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false; + } + ggml_tensor * softmax = cgraph->nodes[node_idx]; + ggml_tensor * weights = cgraph->nodes[node_idx+8]; + + if (ggml_cuda_should_use_topk_moe(softmax, weights)) { + return true; + } + } + + if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) { + + if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) { + return false; + } + + for (size_t i = 0; i < topk_moe_ops.size(); i++) { + if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false; + } + + ggml_tensor * softmax = cgraph->nodes[node_idx]; + ggml_tensor * weights = cgraph->nodes[node_idx+4]; + if (ggml_cuda_should_use_topk_moe(softmax, weights)) { + return true; + } + } + if (!ggml_can_fuse(cgraph, node_idx, ops)) { return false; } - if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + const ggml_tensor *add = nullptr; + + if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) { + add = cgraph->nodes[node_idx+2]; + } GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); @@ -2902,6 +3053,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } + if (add && (add->src[0]->type != GGML_TYPE_F32 || + add->src[1]->type != GGML_TYPE_F32 || + add->type != GGML_TYPE_F32) ) { + return false; + } + //if rms norm is the B operand, then we don't handle broadcast if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { return false; @@ -2912,6 +3069,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } + if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) { + return false; + } + return true; } @@ -2941,6 +3102,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { + // flag used to determine whether it is an integrated_gpu const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; @@ -2956,9 +3118,69 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + // When reserving, we are forcing CUDA graphs but this operation is not graph-safe so we need to skip it + if (reserving_graph && node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { + continue; + } + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { + + if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { + ggml_tensor * weights = cgraph->nodes[i+8]; + ggml_tensor * selected_experts = cgraph->nodes[i+3]; + ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true); + i += 8; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { + ggml_tensor * weights = cgraph->nodes[i+4]; + ggml_tensor * selected_experts = cgraph->nodes[i+3]; + ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false); + i += 4; + continue; + } + + if (node->op == GGML_OP_ADD) { + int n_fuse = 0; + ggml_op ops[8]; + std::fill(ops, ops + 8, GGML_OP_ADD); + + for (; n_fuse <= 6; ++n_fuse){ + if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { + break; + } + if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { + break; + } + if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { + break; + } + } + + n_fuse++; + + if (n_fuse > 1) { + for (int j = 0; j < n_fuse - 1; ++j) { + node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + cgraph->nodes[i + n_fuse - 1]->data = node->data; + ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); + i += n_fuse - 1; + + continue; + } + } + + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); i++; continue; @@ -3027,6 +3249,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + cuda_ctx->pool_set_alloc(true); ggml_cuda_set_device(cuda_ctx->device); @@ -3106,6 +3329,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, return GGML_STATUS_SUCCESS; } +// This is used to skip operations that are not graph safe during the reservation process. +bool reserving_graph = false; + +static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, ggml_cgraph * cgraph, bool alloc) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + cuda_ctx->pool_set_alloc(alloc); + + #ifdef USE_CUDA_GRAPH + if (cuda_ctx->cuda_graph == nullptr) { + cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); + } + #endif + + ggml_cuda_set_device(cuda_ctx->device); + + { + std::lock_guard lock(ggml_cuda_lock); + ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed); + } + + reserving_graph = true; + + // Create CuBLAS handles early to avoid synchronous allocations during graph capture. + cuda_ctx->cublas_handle(); + + CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); + + enum ggml_status result = GGML_STATUS_SUCCESS; + + try { + bool use_cuda_graph = false; + bool cuda_graph_update_required = false; + bool graph_evaluated_or_captured = false; + + evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + } catch (const std::exception &e) { + result = GGML_STATUS_FAILED; + } + + cudaGraph_t graph; + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph)); + CUDA_CHECK(cudaGraphDestroy(graph)); + + reserving_graph = false; + + { + std::lock_guard lock(ggml_cuda_lock); + if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) { + ggml_cuda_lock_cv.notify_all(); + } + } + + return result; +} + +static size_t ggml_backend_cuda_buffer_size(ggml_backend_t backend) { + ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; + return ctx->pool_get_alloc_size(); +} + +static void ggml_backend_cuda_reset(ggml_backend_t backend) { + ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context; + ctx->pools[ctx->device] = NULL; +} + static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; @@ -3145,6 +3433,10 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, + /* .graph_optimize = */ NULL, + /* .graph_reserve = */ ggml_backend_cuda_graph_reserve, + /* .buffer_size = */ ggml_backend_cuda_buffer_size, + /* .reset = */ ggml_backend_cuda_reset, }; static ggml_guid_t ggml_backend_cuda_guid() { @@ -3177,7 +3469,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { return false; } -#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP) cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error @@ -3214,7 +3506,16 @@ struct ggml_backend_cuda_device_context { int device; std::string name; std::string description; + std::string pci_bus_id; std::string id; + int major; + int minor; + int driver_major; + int driver_minor; + int integrated; + int pciBusID; + int pciDeviceID; + int pciDomainID; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { @@ -3235,6 +3536,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); + +#if defined(GGML_USE_HIP) + if (ggml_hip_mgmt_init() == 0) { + int status = ggml_hip_get_device_memory(ctx->pciBusID, ctx->pciDeviceID, free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_hip_mgmt_release(); + return; + } + ggml_hip_mgmt_release(); + } +#else + if (ggml_nvml_init() == 0) { + int status = ggml_nvml_get_device_memory(ctx->id.c_str(), free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_nvml_release(); + return; + } + ggml_nvml_release(); + } +#endif CUDA_CHECK(cudaMemGetInfo(free, total)); } @@ -3243,16 +3566,36 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend return GGML_BACKEND_DEVICE_TYPE_GPU; } +#define GGML_HIP_NAME "HIP" static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + props->name = ggml_backend_cuda_device_get_name(dev); props->description = ggml_backend_cuda_device_get_description(dev); props->id = ggml_backend_cuda_device_get_id(dev); props->type = ggml_backend_cuda_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device). // If you need the memory data, call ggml_backend_dev_memory() explicitly. props->memory_total = props->memory_free = 0; +#if defined(GGML_USE_HIP) + int cc = ggml_cuda_info().devices[ctx->device].cc - GGML_CUDA_CC_OFFSET_AMD; + props->compute_major = cc / 0x100; + props->compute_minor = cc - (props->compute_major * 0x100); +#else + props->compute_major = ctx->major; + props->compute_minor = ctx->minor; +#endif + props->driver_major = ctx->driver_major; + props->driver_minor = ctx->driver_minor; + props->integrated = ctx->integrated; + props->pci_bus_id = ctx->pciBusID; + props->pci_device_id = ctx->pciDeviceID; + props->pci_domain_id = ctx->pciDomainID; + props->library = GGML_CUDA_NAME; + bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY bool events = false; @@ -3437,7 +3780,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) && op->src[0]->type == GGML_TYPE_F32 && - op->src[1]->type == GGML_TYPE_I64; + (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); } break; case GGML_OP_CPY: { @@ -3481,6 +3824,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) { + return true; + } + if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { return true; } @@ -3582,19 +3931,24 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); } case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_ARGSORT: case GGML_OP_ACC: return true; + case GGML_OP_ARGSORT: + // TODO: Support arbitrary column width + return op->src[0]->ne[0] <= 1024; + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: + case GGML_OP_PAD: return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: - case GGML_OP_PAD: + case GGML_OP_PAD_REFLECT_1D: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: @@ -3602,47 +3956,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: return true; - case GGML_OP_FLASH_ATTN_EXT: { -#ifndef FLASH_ATTN_AVAILABLE - return false; -#endif // FLASH_ATTN_AVAILABLE - if (op->src[1]->ne[0] != op->src[2]->ne[0]) { - const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; - if (!turing_mma_available(cc)) { - return false; - } - const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; - return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0; - } - // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS] - if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) - && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { - return false; - } - if (op->src[0]->ne[0] == 192) { - return false; - } - if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { - return false; - } - if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { - return true; - } - if (op->src[0]->ne[0] == 128) { - return true; - } - if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { - return true; - } - if (op->src[3] && op->src[3]->ne[2] != 1) { - return false; - } - return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) && - op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; - } + case GGML_OP_FLASH_ATTN_EXT: + return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: return true; default: return false; @@ -3780,10 +4099,6 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t features.push_back({ "NO_PEER_COPY", "1" }); #endif - #ifdef GGML_CUDA_F16 - features.push_back({ "F16", "1" }); - #endif - #ifdef GGML_CUDA_USE_GRAPHS features.push_back({ "USE_GRAPHS", "1" }); #endif @@ -3843,6 +4158,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; + int driverVersion = 0; + CUDA_CHECK(cudaDriverGetVersion(&driverVersion)); for (int i = 0; i < ggml_cuda_info().device_count; i++) { ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; @@ -3854,6 +4171,18 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { dev_ctx->description = prop.name; dev_ctx->id = ggml_cuda_parse_uuid(prop, i); + char pci_bus_id[16] = {}; + snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); + dev_ctx->pci_bus_id = pci_bus_id; + + dev_ctx->major = prop.major; + dev_ctx->minor = prop.minor; + dev_ctx->driver_major = driverVersion / 1000; + dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10; + dev_ctx->integrated = prop.integrated; + dev_ctx->pciBusID = prop.pciBusID; + dev_ctx->pciDeviceID = prop.pciDeviceID; + dev_ctx->pciDomainID = prop.pciDomainID; ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_cuda_device_interface, /* .reg = */ ®, diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/im2col.cu b/ml/backend/ggml/ggml/src/ggml-cuda/im2col.cu index 16bb9bec..56dc0545 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/im2col.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/im2col.cu @@ -112,3 +112,153 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } } + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template +static __global__ void im2col_3d_kernel( + const float * src, T * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW, + int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW, + int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) { + const int64_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= IC_KD_KH_KW) { + return; + } + GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH); + GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW); + + const int64_t iic = i / KD_KH_KW; + const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const int64_t ikw = i % KW; + + const int64_t iow = blockIdx.y; + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; + + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; + + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template +static void im2col_3d_cuda(const float * src, T* dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t ID_IH_IW = ID*IH*IW; + const int64_t KH_KW = KH*KW; + const int64_t IH_IW = IH*IW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + const int64_t OW_KD_KH_KW = OW*KD*KH*KW; + const int64_t N_OD_OH = N*OD*OH; + const int64_t OD_OH = OD*OH; + const int64_t IC_ID_IH_IW = IC*ID*IH*IW; + const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; + dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + im2col_3d_kernel<<>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, + IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, + OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2); +} + +static void im2col_3d_cuda_f16(const float * src, half * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + + im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +static void im2col_3d_cuda_f32(const float * src, float * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + + im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + const size_t es = ggml_element_size(src1); + const int64_t stride_x = src1->nb[0] / es; + const int64_t stride_y = src1->nb[1] / es; + const int64_t stride_z = src1->nb[2] / es; + const int64_t stride_q = src1->nb[3] / es; + + if(dst->type == GGML_TYPE_F16) { + im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + } else { + im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/im2col.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/im2col.cuh index 1ce8fae4..2da1223d 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/im2col.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/im2col.cuh @@ -3,3 +3,4 @@ #define CUDA_IM2COL_BLOCK_SIZE 256 void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu index 4b238a39..347abc18 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu @@ -1,4 +1,14 @@ #include "mean.cuh" +#include "reduce_rows.cuh" + +#ifdef GGML_CUDA_USE_CUB +#include +using namespace cub; +#endif // GGML_CUDA_USE_CUB + +template __global__ void divide_by_count(T * result, size_t count) { + *result /= static_cast(count); +} void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; @@ -13,7 +23,51 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ncols = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); - const dim3 block_dims(WARP_SIZE, 1, 1); +// Special case for reducing vectors +#ifdef GGML_CUDA_USE_CUB +#ifdef USE_CUDA_GRAPH + cudaStreamCaptureStatus iscapturing; + CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing)); +#endif // USE_CUDA_GRAPH + if ((nrows == 1) && +#ifdef USE_CUDA_GRAPH + // CUDA_GRAPHS_DISABLED + ((ncols > 65536) && + ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || + ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || + ctx.cuda_graph->disable_due_to_failed_graph_capture)) || + // CUDA_GRAPHS ENABLED + ((ncols > 32768) && + !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || + ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || + ctx.cuda_graph->disable_due_to_failed_graph_capture))) { +#else + (ncols > 65536)) { +#endif // USE_CUDA_GRAPH + // Single row - use device-wide reduction + size_t tmp_size = 0; + ggml_cuda_pool & pool = ctx.pool(); + + DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream); + + ggml_cuda_pool_alloc tmp_alloc(pool, tmp_size); + DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream); + + // Divide by ncols + divide_by_count<<<1, 1, 0, stream>>>(dst_d, ncols); + return; + } +#endif // GGML_CUDA_USE_CUB + const dim3 block_nums(nrows, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + + const int id = ggml_cuda_get_device(); + const int nsm = ggml_cuda_info().devices[id].nsm; + if ((nrows / nsm) < 2) { + const dim3 block_dims(512, 1, 1); + reduce_rows_f32<<>>(src0_d, dst_d, ncols); + } else { + const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); + reduce_rows_f32<<>>(src0_d, dst_d, ncols); + } } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh index 83ee16b2..c1f24243 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh @@ -1,3 +1,4 @@ +#pragma once // This file contains primitives that expose the tensor core PTX instructions for CUDA code. // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. // The documentation for the PTX instructions can be found under: @@ -291,9 +292,7 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); #else - GGML_UNUSED(t); - GGML_UNUSED(xs0); - GGML_UNUSED(stride); + GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -315,9 +314,7 @@ namespace ggml_cuda_mma { : "r"(A.x[1]), "r"(B.x[0])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -345,9 +342,7 @@ namespace ggml_cuda_mma { : "r"(A.x[3]), "r"(B.x[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -372,9 +367,7 @@ namespace ggml_cuda_mma { : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -408,9 +401,7 @@ namespace ggml_cuda_mma { : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -425,9 +416,7 @@ namespace ggml_cuda_mma { : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // AMPERE_MMA_AVAILABLE } @@ -452,9 +441,7 @@ namespace ggml_cuda_mma { : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -469,9 +456,7 @@ namespace ggml_cuda_mma { : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // AMPERE_MMA_AVAILABLE } @@ -505,9 +490,7 @@ namespace ggml_cuda_mma { : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -533,9 +516,7 @@ namespace ggml_cuda_mma { 0, 0, 0); #endif // defined(CDNA3) #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } @@ -561,9 +542,7 @@ namespace ggml_cuda_mma { 0, 0, 0); #endif // defined(CDNA3) #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu index 1437367e..599e085e 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu @@ -1,344 +1,12 @@ #include "ggml.h" -#include "common.cuh" -#include "mma.cuh" #include "mmf.cuh" -using namespace ggml_cuda_mma; - -#define MMF_ROWS_PER_BLOCK 32 - -template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<16, 8, float> tile_C; - - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; - constexpr int ntA = rows_per_block / tile_A::I; - constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; - - const int row0 = blockIdx.x * rows_per_block; - const int channel_dst = blockIdx.y; - const int channel_x = channel_dst / channel_ratio; - const int channel_y = channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; - const int sample_y = sample_dst; - - x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; - y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; - dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; - - const float2 * y2 = (const float2 *) y; - - extern __shared__ char data_mmv[]; - - tile_C C[ntA][ntB]; - - T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded); - - for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { - tile_A A[ntA][warp_size / tile_A::J]; -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { -#pragma unroll - for (int i = 0; i < tile_A::I; ++i) { - tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; - } -#pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { - load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); - } - } - -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - - tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; - } - } else if constexpr (std::is_same_v || std::is_same_v) { -#pragma unroll - for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - - const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; - } - } else { - static_assert(std::is_same_v, "unsupported type"); - } -#pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { - tile_B B; - load_ldmatrix(B, tile_xy + k0, tile_k_padded); -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { - mma(C[itA][itB], A[itA][k0/tile_B::J], B); - } - } - } - } - - float * buf_iw = (float *) data_mmv; - constexpr int kiw = nwarps*rows_per_block + 4; - - if (nwarps > 1) { - __syncthreads(); - } -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); - const int j = itB*tile_C::J + tile_C::get_j(l); - buf_iw[j*kiw + i] = C[itA][itB].x[l]; - } - } - } - - if (nwarps > 1) { - __syncthreads(); - } - -#pragma unroll - for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j0 + nwarps > cols_per_block && j >= cols_per_block) { - return; - } - - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); -#pragma unroll - for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; - - sum += buf_iw[j*kiw + i]; - } - dst[j*stride_col_dst + row0 + threadIdx.x] = sum; - } -#else - NO_DEVICE_CODE; - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst); - GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst); - GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst); - GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst); -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) -} - -template -static void mul_mat_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols_x, const int64_t nrows_x, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<16, 8, float> tile_C; - - GGML_ASSERT(!ids && "mul_mat_id not implemented"); - - GGML_ASSERT(ncols_x % 2 == 0); - GGML_ASSERT(stride_row % 2 == 0); - GGML_ASSERT(stride_col_y % 2 == 0); - GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); - GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const int64_t channel_ratio = nchannels_dst / nchannels_x; - const int64_t sample_ratio = nsamples_dst / nsamples_x; - - const int device = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[device].warp_size; - - int64_t nwarps_best = 1; - int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); - int64_t max_block_size = 256; - for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { - const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); - if (niter < niter_best) { - niter_best = niter; - nwarps_best = nwarps; - } - } - - constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; - const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; - const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); - const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst); - const dim3 block_dims(warp_size, nwarps_best, 1); - switch (nwarps_best) { - case 1: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 2: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 3: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 4: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 5: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 6: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 7: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 8: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - -template -static void mul_mat_f_switch_cols_per_block( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - switch (ncols_dst) { - case 1: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 2: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 3: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 4: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 5: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 6: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 7: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 8: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 9: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 10: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 11: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 12: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 13: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 14: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 15: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 16: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_TENSOR_BINARY_OP_LOCALS; const size_t ts_src0 = ggml_type_size(src0->type); @@ -352,9 +20,6 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); GGML_ASSERT( nb0 == ts_dst); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; float * dst_d = (float *) dst->data; @@ -369,55 +34,82 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr const int64_t s13 = src1->nb[3] / ts_src1; const int64_t s3 = dst->nb[3] / ts_dst; + const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; + const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: const int64_t ncols_dst = ids ? ne2 : ne1; - const int64_t nchannels_y = ids ? ne11 : ne12; - const int64_t nchannels_dst = ids ? ne1 : ne2; - const int64_t stride_channel_dst = ids ? s1 : s2; - const int64_t stride_channel_y = ids ? s11 : s12; + const int64_t nchannels_dst = ids ? ne1 : ne2; - GGML_ASSERT(!ids || ncols_dst == 1); + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; + + int64_t stride_channel_y = ids ? s11 : s12; + int64_t nchannels_y = ids ? ne11 : ne12; + + //mul_mat_id: handle broadcast + if (ids && nchannels_y == 1) { + stride_channel_y = 0; + nchannels_y = ids->ne[0]; + } switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; constexpr int vals_per_T = 1; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; constexpr int vals_per_T = 2; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; constexpr int vals_per_T = 2; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); } } -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) { +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) { + + if (ggml_is_quantized(type)) { + return false; + } + if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) { return false; } if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { return false; } - if (ne11 > 16) { - return false; + + if (mul_mat_id) { + if (type == GGML_TYPE_F32 && src1_ncols > 32) { + return false; + } + if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) { + return false; + } + } else { + if (src1_ncols > 16) { + return false; + } } + switch (type) { case GGML_TYPE_F32: return ampere_mma_available(cc); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh index 785f9f21..a6c3adfc 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh @@ -1,5 +1,496 @@ +#pragma once + +#include "mma.cuh" #include "common.cuh" +using namespace ggml_cuda_mma; + +#define MMF_ROWS_PER_BLOCK 32 + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11); +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id); + +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f( + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int stride_col_id, const int stride_row_id, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = warp_size + 4; + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + int expert_idx = 0; + int col_base = 0; + + const int channel_dst = has_ids ? 0 : blockIdx.y; + + if constexpr (has_ids) { + // experts + tiles of ncols_dst are packed in the y dimension + int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block; + const int nchannels_x = gridDim.y / col_tiles; + const int tile_idx = blockIdx.y / nchannels_x; + expert_idx = blockIdx.y - tile_idx * nchannels_x; + col_base = tile_idx * cols_per_block; + } + + const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio); + const int channel_y = channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; + y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y); + dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst); + + if constexpr (has_ids) { + constexpr int y_stride_scale = std::is_same_v ? 1 : 2; + const int64_t col_offset = col_base; + y += col_offset * stride_col_y * y_stride_scale; + dst += col_offset * stride_col_dst; + ids += col_offset * stride_row_id; + } + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + + char * shmem_base = data_mmv; + int * slot_map = (int *) shmem_base; + char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + if constexpr (has_ids) { + int found = 0; + + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (threadIdx.x == 0) { + slot_map[j] = -1; + } + + if (col_base + j >= ncols_dst_total) { + continue; + } + + const int32_t * __restrict__ id_row = ids + j*stride_row_id; + + for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) { + int match = id_row[k*stride_col_id] == expert_idx; + + if (match) { + slot_map[j] = k; + found = 1; + break; + } + } + } + + if (!__syncthreads_or(found)) { + return; + } + } + + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; + } else { + const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0; + tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } else { + const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0; + float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f); + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + 4; + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum = 0.0f; + static_assert(rows_per_block == warp_size, "need loop/check"); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { + const int i = i0 + threadIdx.x; + + sum += buf_iw[j*kiw + i]; + } + + if constexpr (!has_ids) { + dst[j*stride_col_dst + row0 + threadIdx.x] = sum; + } else { + const int slot = (j < cols_per_block) ? slot_map[j] : -1; + if (slot >= 0 && (col_base + j) < ncols_dst_total) { + dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum; + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + NO_DEVICE_CODE; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + +template +static inline void mul_mat_f_switch_ids( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { + if (ids) { + const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block; + dim3 block_nums_ids = block_nums; + block_nums_ids.y *= col_tiles; + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } else { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } +} + +template +void mul_mat_f_cuda( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + + GGML_ASSERT(ncols_x % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(stride_col_y % 2 == 0); + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t nwarps_best = 1; + int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); + int64_t max_block_size = 256; + for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { + const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); + if (niter < niter_best) { + niter_best = niter; + nwarps_best = nwarps; + } + } + + constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; + const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); + const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; + const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; + const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present + + const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); + const dim3 block_dims(warp_size, nwarps_best, 1); + + switch (nwarps_best) { + case 1: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 2: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 3: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 4: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 5: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 6: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 7: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 8: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } + + GGML_UNUSED_VARS(nchannels_y); +} + +template +static void mul_mat_f_switch_cols_per_block( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + + const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst; + + GGML_ASSERT(ids || ncols_dst <= 16); + + switch (ncols_case) { + case 1: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 2: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 3: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 4: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 5: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 6: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 7: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 8: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 9: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 10: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 11: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 12: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 13: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 14: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 15: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 16: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +#define DECL_MMF_CASE_HELPER(T, ncols_dst) \ + template void mul_mat_f_cuda( \ + const T * x, const float * y, const int32_t * ids, float * dst, \ + const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ + const int64_t stride_col_id, const int64_t stride_row_id, \ + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ + cudaStream_t stream); + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#define DECL_MMF_CASE_EXTERN(ncols_dst) \ + extern DECL_MMF_CASE_HELPER(float, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + +#define DECL_MMF_CASE(ncols_dst) \ + DECL_MMF_CASE_HELPER(float, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + +DECL_MMF_CASE_EXTERN(1); +DECL_MMF_CASE_EXTERN(2); +DECL_MMF_CASE_EXTERN(3); +DECL_MMF_CASE_EXTERN(4); +DECL_MMF_CASE_EXTERN(5); +DECL_MMF_CASE_EXTERN(6); +DECL_MMF_CASE_EXTERN(7); +DECL_MMF_CASE_EXTERN(8); +DECL_MMF_CASE_EXTERN(9); +DECL_MMF_CASE_EXTERN(10); +DECL_MMF_CASE_EXTERN(11); +DECL_MMF_CASE_EXTERN(12); +DECL_MMF_CASE_EXTERN(13); +DECL_MMF_CASE_EXTERN(14); +DECL_MMF_CASE_EXTERN(15); +DECL_MMF_CASE_EXTERN(16); +#else +#define DECL_MMF_CASE(ncols_dst) +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu index 384ee761..12bdc629 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu @@ -3,6 +3,140 @@ #include +// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. +struct mmq_ids_helper_store { + uint32_t data; + + __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) { + data = (it & 0x003FFFFF) | (iex_used << 22); + } + + __device__ uint32_t it() const { + return data & 0x003FFFFF; + } + + __device__ uint32_t iex_used() const { + return data >> 22; + } +}; +static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store"); + +// Helper function for mul_mat_id, converts ids to a more convenient format. +// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. +// ids_dst describes the same mapping but for the dst tensor. +// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. +template +__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mmq_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; + const int expert = blockIdx.x; + + extern __shared__ char data_mmq_ids_helper[]; + mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper; + + int nex_prev = 0; // Number of columns for experts with a lower index. + int it_compact = 0; // Running index for the compact slice of this expert. + + if constexpr (n_expert_used_template == 0) { + // Generic implementation: + for (int it = 0; it < n_tokens; ++it) { + int iex_used = -1; // The index at which the expert is used, if any. + for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { + const int expert_used = ids[it*si1 + iex]; + nex_prev += expert_used < expert; + if (expert_used == expert) { + iex_used = iex; + } + } + + if (iex_used != -1) { + store[it_compact] = mmq_ids_helper_store(it, iex_used); + } + + if (warp_reduce_any(iex_used != -1)) { + it_compact++; + } + } + } else { + // Implementation optimized for specific numbers of experts used: + static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); + const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. + for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { + const int it = it0 + threadIdx.x / neu_padded; + + const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. + const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? + ids[it*si1 + iex] : INT_MAX; + const int iex_used = expert_used == expert ? iex : -1; + nex_prev += expert_used < expert; + + // Whether the threads at this token position have used the expert: + const int it_compact_add_self = warp_reduce_any(iex_used != -1); + + // Do a scan over threads at lower token positions in warp to get the correct index for writing data: + int it_compact_add_lower = 0; +#pragma unroll + for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { + const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); + if (threadIdx.x >= static_cast(offset)) { + it_compact_add_lower += tmp; + } + } + + if (iex_used != -1) { + store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used); + } + + // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: + it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); + } + } + nex_prev = warp_reduce_sum(nex_prev); + + for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { + const mmq_ids_helper_store store_it = store[itc]; + const int it = store_it.it(); + const int iex_used = store_it.iex_used(); + ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; + ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + } + + if (threadIdx.x != 0) { + return; + } + + expert_bounds[expert] = nex_prev; + + if (expert < static_cast(gridDim.x) - 1) { + return; + } + + expert_bounds[gridDim.x] = nex_prev + it_compact; +} + +template +static void launch_mmq_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store"); + GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store"); + + const int id = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper, smpbo); + + const dim3 num_blocks(n_experts, 1, 1); + const dim3 block_size(warp_size, 1, 1); + const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store); + GGML_ASSERT(nbytes_shared <= smpbo); + mmq_ids_helper<<>> + (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); +} + static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { case GGML_TYPE_Q4_0: @@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q( ne00, ne01, ne1, s01, ne11, s1, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, - use_stream_k}; + use_stream_k, ne1}; ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); return; } @@ -148,54 +282,50 @@ void ggml_cuda_mul_mat_q( const int64_t n_expert_used = ids->ne[0]; const int64_t ne_get_rows = ne12 * n_expert_used; + GGML_ASSERT(ne1 == n_expert_used); - std::vector ids_host(ggml_nbytes(ids)); - std::vector ids_src1_host; - ids_src1_host.reserve(ne_get_rows); - std::vector ids_dst_host; - ids_dst_host.reserve(ne_get_rows); - std::vector tokens_per_expert_host(ne02); - std::vector expert_bounds_host(ne02 + 1); - ggml_cuda_pool_alloc ids_buf_dev(ctx.pool()); + ggml_cuda_pool_alloc ids_src1(ctx.pool(), ne_get_rows); + ggml_cuda_pool_alloc ids_dst(ctx.pool(), ne_get_rows); + ggml_cuda_pool_alloc expert_bounds(ctx.pool(), ne02 + 1); - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); + { + GGML_ASSERT(ids->nb[0] == ggml_element_size(ids)); + const int si1 = ids->nb[1] / ggml_element_size(ids); + const int sis1 = nb12 / nb11; - for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices - for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens - for (int64_t iex = 0; iex < n_expert_used; ++iex) { - const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]); - assert(expert_to_use >= 0 && expert_to_use < ne02); - if (expert_to_use == i02) { - ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11); - ids_dst_host.push_back(i12*ne1 + iex); - tokens_per_expert_host[i02]++; - break; - } - } + switch (n_expert_used) { + case 2: + launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 4: + launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 6: + launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 8: + launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 16: + launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 32: + launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + default: + launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; } + CUDA_CHECK(cudaGetLastError()); } - int32_t cumsum = 0; - for (int64_t i = 0; i < ne02; ++i) { - expert_bounds_host[i] = cumsum; - cumsum += tokens_per_expert_host[i]; - } - expert_bounds_host[ne02] = cumsum; - - std::vector ids_buf_host; - ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size()); - ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end()); - ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end()); - ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end()); - ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device. - CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); - - const int32_t * ids_src1_dev = ids_buf_dev.ptr; - const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size(); - const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size(); - const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); ggml_cuda_pool_alloc src1_q8_1(ctx.pool(), nbytes_src1_q8_1); @@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q( const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[2] / ts_src1; - quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type, + quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); CUDA_CHECK(cudaGetLastError()); } @@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q( // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. const mmq_args args = { - src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d, + src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d, ne00, ne01, ne_get_rows, s01, ne_get_rows, s1, ne02, ne02, s02, s12, s2, ne03, ne13, s03, s13, s3, - use_stream_k}; + use_stream_k, ne12}; ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); } @@ -262,14 +392,11 @@ void ggml_cuda_op_mul_mat_q( ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, - use_stream_k}; + use_stream_k, src1_ncols}; ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddf_i); - GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size); } bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh index 96129bd8..c9a07e82 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh @@ -1255,7 +1255,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); + GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } @@ -1572,7 +1572,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); + GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } @@ -2301,7 +2301,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); + GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } @@ -2855,12 +2855,14 @@ static __device__ __forceinline__ void mmq_write_back_mma( #else typedef tile<16, 8, int> tile_C; constexpr int rows_per_warp = 2 * granularity; -#endif +#endif // defined(AMD_MFMA_AVAILABLE) constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); +#else + GGML_UNUSED(nwarps); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) #pragma unroll @@ -3136,7 +3138,8 @@ static __global__ void mul_mat_q( const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ncols_max) { // Skip unused template specializations for faster compilation: if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { @@ -3150,7 +3153,7 @@ static __global__ void mul_mat_q( constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); - const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y // Initialize the ids for writing back data with just the index. @@ -3374,7 +3377,8 @@ template static __global__ void mul_mat_q_stream_k_fixup( const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, - const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) { + const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, + const int ncols_max) { constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -3385,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup( float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; - const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; const int nty = (nrows_x + mmq_y - 1) / mmq_y; const int bidx0 = blockIdx.x; @@ -3526,7 +3530,7 @@ struct mmq_args { int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst; int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst; - bool use_stream_k; + bool use_stream_k; int64_t ncols_max; }; template @@ -3556,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; - const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x; + const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x; const int ntzw = args.nchannels_y * args.nsamples_y; const dim3 block_nums_xy_tiling(nty, ntx, ntzw); @@ -3572,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); } else { constexpr bool need_check = true; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); } return; } @@ -3599,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); if (!fixup_needed) { return; @@ -3607,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a mul_mat_q_stream_k_fixup<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); } else { constexpr bool need_check = true; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); if (!fixup_needed) { return; @@ -3622,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a mul_mat_q_stream_k_fixup<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); } } @@ -3647,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda continue; } - const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x; + const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x; if (ntiles_x < ntiles_x_best) { mmq_x_best = mmq_x; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu index 1ad4bc75..5b21ef05 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu @@ -1,5 +1,6 @@ #include "ggml.h" #include "common.cuh" +#include "convert.cuh" #include "mmvf.cuh" template @@ -93,8 +94,8 @@ static __global__ void mul_mat_vec_f( #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += float(reinterpret_cast(&tmpx)[0]) * tmpy.x; - sumf[j] += float(reinterpret_cast(&tmpx)[1]) * tmpy.y; + sumf[j] += ggml_cuda_cast(reinterpret_cast(&tmpx)[0]) * tmpy.x; + sumf[j] += ggml_cuda_cast(reinterpret_cast(&tmpx)[1]) * tmpy.y; } } } else { @@ -432,12 +433,7 @@ void ggml_cuda_op_mul_mat_vec_f( GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); } - GGML_UNUSED(ctx); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddq_i); - GGML_UNUSED(src1_ncols); - GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size); } bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu index 5c8e5c4a..3bf0c9ed 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu @@ -141,9 +141,10 @@ template __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; @@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q( constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. - const int channel_dst = blockIdx.y; - const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio; - const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; - const int sample_y = sample_dst; + const uint32_t channel_dst = blockIdx.y; + const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + const uint32_t sample_dst = blockIdx.z; + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); + const uint32_t sample_y = sample_dst; // partial sum for each thread float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; @@ -219,7 +220,7 @@ static __global__ void mul_mat_vec_q( tmp[j][i] = warp_reduce_sum(tmp[j][i]); } - if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + int(threadIdx.x) < stride_col_dst)) { + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x]; } } @@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst( GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); - const int channel_ratio = nchannels_dst / nchannels_x; - const int sample_ratio = nsamples_dst / nsamples_x; + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); + const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); const int warp_size = ggml_cuda_info().devices[device].warp_size; @@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst( GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { - case 1: - { + case 1: { constexpr int c_ncols_dst = 1; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 2: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 2: { constexpr int c_ncols_dst = 2; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 3: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 3: { constexpr int c_ncols_dst = 3; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 4: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 4: { constexpr int c_ncols_dst = 4; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 5: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 5: { constexpr int c_ncols_dst = 5; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 6: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 6: { constexpr int c_ncols_dst = 6; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 7: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 7: { constexpr int c_ncols_dst = 7; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 8: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 8: { constexpr int c_ncols_dst = 8; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; default: GGML_ABORT("fatal error"); break; @@ -596,9 +582,5 @@ void ggml_cuda_op_mul_mat_vec_q( src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddf_i); - GGML_UNUSED(src1_ncols); - GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu index bddcca51..4f153c57 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu @@ -104,12 +104,30 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } } -template -static __global__ void rms_norm_f32( - const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, - const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0, - const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0, - const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) { +template +static __global__ void rms_norm_f32(const float * x, + float * dst, + const int ncols, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const float eps, + const float * mul = nullptr, + const int64_t mul_stride_row = 0, + const int64_t mul_stride_channel = 0, + const int64_t mul_stride_sample = 0, + const uint3 mul_ncols_packed = make_uint3(0, 0, 0), + const uint3 mul_nrows_packed = make_uint3(0, 0, 0), + const uint3 mul_nchannels_packed = make_uint3(0, 0, 0), + const uint3 mul_nsamples_packed = make_uint3(0, 0, 0), + const float * add = nullptr, + const int64_t add_stride_row = 0, + const int64_t add_stride_channel = 0, + const int64_t add_stride_sample = 0, + const uint3 add_ncols_packed = make_uint3(0, 0, 0), + const uint3 add_nrows_packed = make_uint3(0, 0, 0), + const uint3 add_nchannels_packed = make_uint3(0, 0, 0), + const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) { const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -118,14 +136,23 @@ static __global__ void rms_norm_f32( const int sample = blockIdx.z; const int tid = threadIdx.x; + static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying"); + x += sample*stride_sample + channel*stride_channel + row*stride_row; dst += ((sample*nchannels + channel)*nrows + row)*ncols; if constexpr (do_multiply) { - const int mul_row = row % mul_nrows; - const int mul_channel = channel % mul_nchannels; - const int mul_sample = sample % mul_nsamples; - mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row; + const uint32_t mul_row = fastmodulo(row, mul_nrows_packed); + const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed); + const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed); + mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row; + } + + if constexpr (do_add) { + const int add_row = fastmodulo(row, add_nrows_packed); + const int add_channel = fastmodulo(channel, add_nchannels_packed); + const int add_sample = fastmodulo(sample, add_nsamples_packed); + add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row; } float tmp = 0.0f; // partial sum for thread in warp @@ -138,15 +165,18 @@ static __global__ void rms_norm_f32( // sum up partial sums tmp = warp_reduce_sum(tmp); if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); + static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size"); __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } __syncthreads(); - tmp = s_sum[lane_id]; + tmp = 0.0f; + if (lane_id < (block_size / WARP_SIZE)) { + tmp = s_sum[lane_id]; + } tmp = warp_reduce_sum(tmp); } @@ -154,9 +184,13 @@ static __global__ void rms_norm_f32( const float scale = rsqrtf(mean + eps); for (int col = tid; col < ncols; col += block_size) { - if constexpr (do_multiply) { - const int mul_col = col % mul_ncols; - dst[col] = scale * x[col] * mul[mul_col]; + if constexpr (do_multiply && do_add) { + const int mul_col = fastmodulo(col, mul_ncols_packed); + const int add_col = fastmodulo(col, add_ncols_packed); + dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; + } else if constexpr (do_multiply) { + const int mul_col = fastmodulo(col, mul_ncols_packed); + dst[col] = scale * x[col] * mul[mul_col]; } else { dst[col] = scale * x[col]; } @@ -323,31 +357,87 @@ static void rms_norm_f32_cuda( const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); rms_norm_f32<1024, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } -static void rms_norm_mul_f32_cuda( - const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, - const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, - const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, - const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples, - const float eps, cudaStream_t stream) { +static void rms_norm_mul_f32_cuda(const float * x, + const float * mul, + const float * add, + float * dst, + const int ncols, + const int nrows, + const int nchannels, + const int nsamples, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const int64_t mul_stride_row, + const int64_t mul_stride_channel, + const int64_t mul_stride_sample, + const uint32_t mul_ncols, + const uint32_t mul_nrows, + const uint32_t mul_nchannels, + const uint32_t mul_nsamples, + const int64_t add_stride_row, + const int64_t add_stride_channel, + const int64_t add_stride_sample, + const uint32_t add_ncols, + const uint32_t add_nrows, + const uint32_t add_nchannels, + const uint32_t add_nsamples, + const float eps, + cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); if (mul == nullptr) { rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); return; } - if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + if (add == nullptr) { + const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); + if (ncols < 1024) { + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + } } else { - const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); + + const uint3 add_ncols_packed = init_fastdiv_values(add_ncols); + const uint3 add_nrows_packed = init_fastdiv_values(add_nrows); + const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels); + const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); + if (ncols < 1024) { + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, true, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, + add_nchannels_packed, add_nsamples_packed); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, true, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, + add_nchannels_packed, add_nsamples_packed); + } } } @@ -491,7 +581,102 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * const int mul_nchannels = mul_src->ne[2]; const int mul_nsamples = mul_src->ne[3]; - rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream); + rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d, + ne00, ne01, ne02, ne03, + /*s00*/ s01, s02, s03, + /*mul_s00*/ mul_s01, mul_s02, mul_s03, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, + /*add_s00*/ 0, 0, 0, + 0, 0, 0, 0, + eps, stream); +} + +void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx, + ggml_tensor * dst, + ggml_tensor * mul_tensor, + ggml_tensor * add_tensor) { + const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0]; + float eps = 0.0f; + + memcpy(&eps, dst->op_params, sizeof(float)); + + const float * src0_d = (const float *) rms_norm_src->data; + const float * mul_d = nullptr; + const ggml_tensor * mul_src = nullptr; + + if (mul_tensor->src[0] == dst) { + mul_d = (float *) mul_tensor->src[1]->data; + mul_src = mul_tensor->src[1]; + } else if (mul_tensor->src[1] == dst) { + mul_d = (float *) mul_tensor->src[0]->data; + mul_src = mul_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + const float * add_d = nullptr; + const ggml_tensor * add_src = nullptr; + + if (add_tensor->src[0] == mul_tensor) { + add_d = (float *) add_tensor->src[1]->data; + add_src = add_tensor->src[1]; + } else if (add_tensor->src[1] == mul_tensor) { + add_d = (float *) add_tensor->src[0]->data; + add_src = add_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + float * dst_d = (float *) add_tensor->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(add_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(eps >= 0.0f); + + const int64_t ne00 = rms_norm_src->ne[0]; + const int64_t ne01 = rms_norm_src->ne[1]; + const int64_t ne02 = rms_norm_src->ne[2]; + const int64_t ne03 = rms_norm_src->ne[3]; + + const size_t ts0 = ggml_type_size(rms_norm_src->type); + GGML_ASSERT(rms_norm_src->nb[0] == ts0); + const int64_t s01 = rms_norm_src->nb[1] / ts0; + const int64_t s02 = rms_norm_src->nb[2] / ts0; + const int64_t s03 = rms_norm_src->nb[3] / ts0; + + const size_t ts_mul = ggml_type_size(mul_src->type); + GGML_ASSERT(mul_src->nb[0] == ts_mul); + const int64_t mul_s01 = mul_src->nb[1] / ts_mul; + const int64_t mul_s02 = mul_src->nb[2] / ts_mul; + const int64_t mul_s03 = mul_src->nb[3] / ts_mul; + + const int mul_ncols = mul_src->ne[0]; + const int mul_nrows = mul_src->ne[1]; + const int mul_nchannels = mul_src->ne[2]; + const int mul_nsamples = mul_src->ne[3]; + + const size_t ts_add = ggml_type_size(add_src->type); + GGML_ASSERT(add_src->nb[0] == ts_add); + const int64_t add_s01 = add_src->nb[1] / ts_add; + const int64_t add_s02 = add_src->nb[2] / ts_add; + const int64_t add_s03 = add_src->nb[3] / ts_add; + + const int add_ncols = add_src->ne[0]; + const int add_nrows = add_src->ne[1]; + const int add_nchannels = add_src->ne[2]; + const int add_nsamples = add_src->ne[3]; + + rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d, + ne00,ne01, ne02, ne03, + /*s00*/ s01, s02, s03, + /*mul_s00*/ mul_s01, mul_s02, mul_s03, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, + /*add_s00*/ add_s01, add_s02, add_s03, + add_ncols, add_nrows, add_nchannels, add_nsamples, + eps, stream); } void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh index 7ea7bd4d..a74f6376 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh @@ -8,6 +8,11 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor); +void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx, + ggml_tensor * dst, + ggml_tensor * mul_tensor, + ggml_tensor * add_tensor); + void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/opt-step-sgd.cu b/ml/backend/ggml/ggml/src/ggml-cuda/opt-step-sgd.cu new file mode 100644 index 00000000..460b16de --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/opt-step-sgd.cu @@ -0,0 +1,49 @@ +#include "ggml-impl.h" +#include "opt-step-sgd.cuh" + +#include + +static __global__ void opt_step_sgd_f32( + float * __restrict__ x, const float * __restrict__ g, + const float * __restrict__ pars, const int64_t k) { + + const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; + + if (i >= k) { + return; + } + x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i]; +} + +static void opt_step_sgd_f32_cuda( + float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) { + + const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1); + const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1); + opt_step_sgd_f32<<>>(x, g, pars, k); +} + +void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0_grad = dst->src[1]; + const ggml_tensor * params = dst->src[2]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad->type == GGML_TYPE_F32); + GGML_ASSERT(params->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src0_grad)); + GGML_ASSERT(ggml_is_contiguous(params)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); + GGML_ASSERT(ggml_nelements(params) == 2); + + float * src0_d = (float *) src0->data; + const float * src0_grad_d = (const float *) src0_grad->data; + const float * params_d = (const float *) params->data; + + cudaStream_t stream = ctx.stream(); + + const int64_t ne = ggml_nelements(src0); + + opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream); +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/opt-step-sgd.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/opt-step-sgd.cuh new file mode 100644 index 00000000..f97ab7d9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/opt-step-sgd.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256 + +void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu index 77432b04..29aef33c 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu @@ -1,36 +1,50 @@ #include "pad.cuh" -static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) { - // blockIdx.z: idx of ne2*ne3, aka ne02*ne03 - // blockIdx.y: idx of ne1 - // blockIDx.x: idx of ne0 / BLOCK_SIZE - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { +static __global__ void pad_f32(const float * src, float * dst, + const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, + const int ne0, const int ne1, const int ne2, const int ne3) { + // blockIdx.z: i3*ne2+i2 + // blockIdx.y: i1 + // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE + // gridDim.y: ne1 + int i0 = threadIdx.x + blockIdx.x * blockDim.x; + int i1 = blockIdx.y; + int i2 = blockIdx.z % ne2; + int i3 = blockIdx.z / ne2; + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } // operation - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) { - int offset_src = - nidx + - blockIdx.y * ne00 + - blockIdx.z * ne00 * ne01; - dst[offset_dst] = x[offset_src]; + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + if ((i0 >= lp0 && i0 < ne0 - rp0) && + (i1 >= lp1 && i1 < ne1 - rp1) && + (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3)) { + const int64_t i00 = i0 - lp0; + const int64_t i01 = i1 - lp1; + const int64_t i02 = i2 - lp2; + const int64_t i03 = i3 - lp3; + const int64_t ne02 = ne2 - lp2 - rp2; + const int64_t ne01 = ne1 - lp1 - rp1; + const int64_t ne00 = ne0 - lp0 - rp0; + + const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00; + + dst[dst_idx] = src[src_idx]; } else { - dst[offset_dst] = 0.0f; + dst[dst_idx] = 0.0f; } } -static void pad_f32_cuda(const float * x, float * dst, - const int ne00, const int ne01, const int ne02, const int ne03, +static void pad_f32_cuda(const float * src, float * dst, + const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; dim3 gridDim(num_blocks, ne1, ne2*ne3); - pad_f32<<>>(x, dst, ne0, ne00, ne01, ne02, ne03); + pad_f32<<>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3); } void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int32_t lp0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t rp0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t lp1 = ((const int32_t*)(dst->op_params))[2]; + const int32_t rp1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t lp2 = ((const int32_t*)(dst->op_params))[4]; + const int32_t rp2 = ((const int32_t*)(dst->op_params))[5]; + const int32_t lp3 = ((const int32_t*)(dst->op_params))[6]; + const int32_t rp3 = ((const int32_t*)(dst->op_params))[7]; pad_f32_cuda(src0_d, dst_d, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/pad_reflect_1d.cu b/ml/backend/ggml/ggml/src/ggml-cuda/pad_reflect_1d.cu new file mode 100644 index 00000000..32993eb5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/pad_reflect_1d.cu @@ -0,0 +1,91 @@ +#include "pad_reflect_1d.cuh" + +static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void + pad_reflect_1d_kernel_f32( + const void * __restrict__ src0, + void * __restrict__ dst, + const int64_t ne0, + const int64_t ne00, + const uint3 ne01, + const int64_t ne02, + const int64_t ne03, + const int64_t nb00, + const int64_t nb01, + const int64_t nb02, + const int64_t nb03, + const int64_t nb0, + const int64_t nb1, + const int64_t nb2, + const int64_t nb3, + const int p0, + const int p1) { + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + + const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01); + const int64_t tile1 = div_mod_packed.y; // i1 + const int64_t tile0 = div_mod_packed.x; // nth i0 tile + const int64_t i1 = tile1; + const int64_t i0 = threadIdx.x + tile0 * blockDim.x; + + // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh) + if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) { + return; + } + + const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01; + char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1; + + const int64_t rel_i0 = i0 - p0; // relative i0 in src0 + int64_t src_idx; + + if (rel_i0 < 0) { + // Left padding - reflect + src_idx = -rel_i0; + } else if (rel_i0 < ne00) { + // Middle - copy + src_idx = rel_i0; + } else { + // Right padding - reflect + src_idx = 2 * ne00 - 2 - rel_i0; + } + const float value = *(const float *) (src0_ptr + src_idx * nb00); + *(float *) (dst_ptr + i0 * nb0) = value; + + GGML_UNUSED(p1); +} + +void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int32_t * opts = (const int32_t *) dst->op_params; + const int p0 = opts[0]; + const int p1 = opts[1]; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const uint3 ne01_packed = init_fastdiv_values(ne01); + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + + // sanity: padded length matches + GGML_ASSERT(ne0 == ne00 + p0 + p1); + + constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x) + const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0 + // grid.x covers i1 and all tiles of i0: [ne01 * tiles0] + // grid.y covers i2: [ne02] + // grid.z covers i3: [ne03] + const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03); + const dim3 block_dims((unsigned) bx, 1, 1); + + pad_reflect_1d_kernel_f32<<>>( + src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1); +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/pad_reflect_1d.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/pad_reflect_1d.cuh new file mode 100644 index 00000000..15f2ed17 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/pad_reflect_1d.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256 + +void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu b/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu index a0b03a74..5117f9ff 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu @@ -1,26 +1,27 @@ #include "quantize.cuh" #include +__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) static __global__ void quantize_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t ne0, const int ne1, const int ne2) { + const int64_t ne0, const uint32_t ne1, const uint3 ne2) { const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { return; } + const int64_t i3 = fastdiv(blockIdx.z, ne2); + const int64_t i2 = blockIdx.z - i3*ne2.z; const int64_t i1 = blockIdx.y; - const int64_t i2 = blockIdx.z % ne2; - const int64_t i3 = blockIdx.z / ne2; const int64_t & i00 = i0; const int64_t & i01 = i1; const int64_t & i02 = i2; const int64_t & i03 = i3; - const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0; + const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0; block_q8_1 * y = (block_q8_1 *) vy; @@ -31,10 +32,10 @@ static __global__ void quantize_q8_1( float amax = fabsf(xi); float sum = xi; - amax = warp_reduce_max(amax); - sum = warp_reduce_sum(sum); + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); - const float d = amax / 127; + const float d = amax / 127.0f; const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); y[ib].qs[iqs] = q; @@ -43,8 +44,7 @@ static __global__ void quantize_q8_1( return; } - reinterpret_cast(y[ib].ds.x) = d; - reinterpret_cast(y[ib].ds.y) = sum; + y[ib].ds = make_half2(d, sum); } template @@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda( GGML_ASSERT(!ids); GGML_ASSERT(ne0 % QK8_1 == 0); + const uint3 ne2_fastdiv = init_fastdiv_values(ne2); + const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); GGML_UNUSED(type_src0); } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh new file mode 100644 index 00000000..6bcae9e5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh @@ -0,0 +1,53 @@ +#include "common.cuh" + +// Row reduction kernel template - compute sum (norm=false) or mean (norm=true) +template +static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) { + const int row = blockIdx.x; + const int col = threadIdx.x; + + float sum = 0.0f; + const int num_unroll = 8; + float temp[num_unroll]; + float sum_temp[num_unroll] = { 0.0f }; + for (int i = col; i < ncols;) { + for (int j = 0; j < num_unroll; ++j) { + if (i < ncols) { + temp[j] = x[row * ncols + i]; + } else { + temp[j] = 0; + } + i += blockDim.x; + } + for (int j = 0; j < num_unroll; ++j) { + sum_temp[j] += temp[j]; + } + } + for (int j = 0; j < num_unroll; ++j) { + sum += sum_temp[j]; + } + + // sum up partial sums + sum = warp_reduce_sum(sum); + if (blockDim.x > WARP_SIZE) { + assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); + __shared__ float s_sum[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = sum; + } + __syncthreads(); + sum = 0.0f; + if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { + sum = s_sum[lane_id]; + } + sum = warp_reduce_sum(sum); + } + + if (col != 0) { + return; + } + + dst[row] = norm ? sum / ncols : sum; +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/scale.cu b/ml/backend/ggml/ggml/src/ggml-cuda/scale.cu index 2ee9e588..0ddeff6a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/scale.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/scale.cu @@ -1,18 +1,19 @@ #include "scale.cuh" -static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +#define MAX_GRIDDIM_X 0x7FFFFFFF - if (i >= k) { - return; +static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; + int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; + + for (int64_t i = tid; i < nelements; i += stride) { + dst[i] = scale * x[i] + bias; } - - dst[i] = scale * x[i] + bias; } -static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, bias, k); +static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { + const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; + scale_f32<<>>(x, dst, scale, bias, nelements); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cu b/ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cu index 07983436..1525a159 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cu @@ -3,15 +3,10 @@ typedef void (*set_rows_kernel_t)(const char * src, char * dst); -template -__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { - convert_flt(src_f, dst_f); -} - // Generic quantized set_rows kernel template -template +template static __global__ void k_set_rows_quant( - const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst, + const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, const int64_t s01, const int64_t s02, const int64_t s03, @@ -50,9 +45,9 @@ static __global__ void k_set_rows_quant( } // Template dispatch function for quantized set_rows -template +template static void set_rows_cuda_quant( - const float * src0_d, const int64_t * src1_d, block_type * dst_d, + const float * src0_d, const idx_t * src1_d, block_type * dst_d, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, const size_t nb01, const size_t nb02, const size_t nb03, @@ -69,15 +64,15 @@ static void set_rows_cuda_quant( const int64_t s01 = nb01/sizeof(float); const int64_t s02 = nb02/sizeof(float); const int64_t s03 = nb03/sizeof(float); - const int64_t s10 = nb10/sizeof(int64_t); - const int64_t s11 = nb11/sizeof(int64_t); - const int64_t s12 = nb12/sizeof(int64_t); + const int64_t s10 = nb10/sizeof(idx_t); + const int64_t s11 = nb11/sizeof(idx_t); + const int64_t s12 = nb12/sizeof(idx_t); const int64_t s1 = nb1; const int64_t s2 = nb2; const int64_t s3 = nb3; if (ne_total > 0) { - k_set_rows_quant<<>>( + k_set_rows_quant<<>>( src0_d, src1_d, dst_d, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, @@ -87,9 +82,9 @@ static void set_rows_cuda_quant( } } -template +template static __global__ void k_set_rows( - const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst, + const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, const int64_t s01, const int64_t s02, const int64_t s03, @@ -117,17 +112,15 @@ static __global__ void k_set_rows( const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; - const src_t* src_elem = src0_row + i00; - dst_t* dst_elem = dst_row_ptr + i00; - set_rows_1(src_elem, dst_elem); + dst_row_ptr[i00] = ggml_cuda_cast(src0_row[i00]); GGML_UNUSED(ne10); GGML_UNUSED(ne13); } -template +template static void set_rows_cuda( - const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d, + const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, const size_t nb01, const size_t nb02, const size_t nb03, @@ -144,9 +137,9 @@ static void set_rows_cuda( const int64_t s01 = nb01/sizeof(src_t); const int64_t s02 = nb02/sizeof(src_t); const int64_t s03 = nb03/sizeof(src_t); - const int64_t s10 = nb10/sizeof(int64_t); - const int64_t s11 = nb11/sizeof(int64_t); - const int64_t s12 = nb12/sizeof(int64_t); + const int64_t s10 = nb10/sizeof(idx_t); + const int64_t s11 = nb11/sizeof(idx_t); + const int64_t s12 = nb12/sizeof(idx_t); const int64_t s1 = nb1/sizeof(dst_t); const int64_t s2 = nb2/sizeof(dst_t); const int64_t s3 = nb3/sizeof(dst_t); @@ -162,23 +155,16 @@ static void set_rows_cuda( } } - -void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_I64); +template +static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const src_t * src0_d = (const src_t *)src0->data; + const idx_t * src1_d = (const idx_t *)src1->data; GGML_TENSOR_BINARY_OP_LOCALS - const float * src0_d = (const float *)src0->data; - const int64_t * src1_d = (const int64_t *)src1->data; - cudaStream_t stream = ctx.stream(); - if (dst->type == GGML_TYPE_F32) { set_rows_cuda( src0_d, src1_d, (float*)dst->data, @@ -210,7 +196,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { stream ); } else if (dst->type == GGML_TYPE_Q4_0) { - set_rows_cuda_quant( + set_rows_cuda_quant( src0_d, src1_d, (block_q4_0*)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, @@ -220,7 +206,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { stream ); } else if (dst->type == GGML_TYPE_Q4_1) { - set_rows_cuda_quant( + set_rows_cuda_quant( src0_d, src1_d, (block_q4_1*)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, @@ -230,7 +216,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { stream ); } else if (dst->type == GGML_TYPE_Q5_0) { - set_rows_cuda_quant( + set_rows_cuda_quant( src0_d, src1_d, (block_q5_0*)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, @@ -240,7 +226,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { stream ); } else if (dst->type == GGML_TYPE_Q5_1) { - set_rows_cuda_quant( + set_rows_cuda_quant( src0_d, src1_d, (block_q5_1*)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, @@ -250,7 +236,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { stream ); } else if (dst->type == GGML_TYPE_Q8_0) { - set_rows_cuda_quant( + set_rows_cuda_quant( src0_d, src1_d, (block_q8_0*)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, @@ -260,7 +246,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { stream ); } else if (dst->type == GGML_TYPE_IQ4_NL) { - set_rows_cuda_quant( + set_rows_cuda_quant( src0_d, src1_d, (block_iq4_nl*)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, @@ -273,3 +259,18 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ABORT("unsupported type %s", ggml_type_name(dst->type)); } } + + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32); + + if (src1->type == GGML_TYPE_I64) { + set_rows_cuda(ctx, src0, src1, dst); + } else { + set_rows_cuda(ctx, src0, src1, dst); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu index c9184398..6b424381 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu @@ -1,87 +1,117 @@ +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +#define USE_CUB +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 + +#ifdef USE_CUB +#include +using namespace cub; +#endif // USE_CUB + #include "ssm-scan.cuh" -template -__global__ void __launch_bounds__(splitD, 2) - ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, - const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, +// We would like to keep pragma unroll for cases where L_template is not 0, +// so we suppress the clang transformation warning. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template +__global__ void __launch_bounds__(splitD, 1) + ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2, + const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5, const int32_t * __restrict__ src6, float * __restrict__ dst, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, - const int64_t s_off, const int64_t d_inner, const int64_t L) { + const int64_t s_off, const int64_t d_inner, const int64_t L_param) +{ + const size_t L = L_template == 0 ? L_param : L_template; + const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); + const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float)); + const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float)); + const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1); + const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3)); + const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3)); + float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float)); + float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - const int bidx = blockIdx.x; // split along B (sequences) - const int bidy = blockIdx.y; // split along D (d_inner) - const int tid = threadIdx.x; - const int wid = tid / 32; - const int wtid = tid % 32; - - extern __shared__ float smem[]; - const int stride_sA = N + 1; - const int stride_ss0 = N + 1; - float * smem_A = smem; - float * smem_s0 = smem_A + splitD * stride_sA; - - const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2); - const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float)); - const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); - const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1); - const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3)); - const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3)); - float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float)); - float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2); - - const int stride_s0 = src0_nb2 / sizeof(float); - const int stride_x = src1_nb2 / sizeof(float); + const int stride_x = src1_nb2 / sizeof(float); const int stride_dt = src2_nb1 / sizeof(float); - const int stride_A = src3_nb1 / sizeof(float); - const int stride_B = src4_nb2 / sizeof(float); - const int stride_C = src5_nb2 / sizeof(float); - const int stride_s = stride_s0; - const int stride_y = d_inner; + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); + const int stride_y = d_inner; - // can N not be 16? for example 32? - if (N == 16) { + float regA[N]; + float regs0[N]; + + __shared__ float smemB[N]; + __shared__ float smemC[N]; + +#ifdef USE_CUB + using BlockLoad = cub::BlockLoad; + using BlockStore = cub::BlockStore; + + union CubTempStorage { + typename BlockLoad::TempStorage load_temp; + typename BlockStore::TempStorage store_temp; + }; + __shared__ CubTempStorage cub_temp_storage; + + BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA); + BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0); +#else + const int stride_s0 = src0_nb2 / sizeof(float); + const int stride_A = src3_nb1 / sizeof(float); #pragma unroll - for (size_t i = 0; i < splitD / 4; i += 2) { - float value = A_block[(wid * warp_size + i) * stride_A + wtid]; - // todo: bank conflict - // I am always confused with how to use the swizzling method to solve - // bank conflit. Hoping somebody can tell me. - smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; - } -#pragma unroll - for (size_t i = 0; i < splitD / 4; i += 2) { - float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid]; - smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; - } + for (size_t n = 0; n < N; ++n) + { + regA[n] = A_block[threadIdx.x * stride_A + n]; + regs0[n] = s0_block[threadIdx.x * stride_s0 + n]; } +#endif - __syncthreads(); - - for (int64_t i = 0; i < L; i++) { - float dt_soft_plus = dt_block[i * stride_dt + tid]; - if (dt_soft_plus <= 20.0f) { - dt_soft_plus = log1pf(exp(dt_soft_plus)); - } - float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; - float sumf = 0.0f; #pragma unroll - for (size_t j = 0; j < N; j++) { - float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) + - (B_block[i * stride_B + j] * x_dt); - sumf += state * C_block[i * stride_C + j]; - if (i == L - 1) { - s_block[tid * stride_s + j] = state; - } else { - smem_s0[tid * stride_ss0 + j] = state; - } + for (size_t i = 0; i < L; i++) + { + if (threadIdx.x < N) + { + smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x]; + smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x]; } __syncthreads(); - y_block[i * stride_y + tid] = sumf; + + float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x]; + if (dt_soft_plus <= 20.0f) + { + dt_soft_plus = log1pf(expf(dt_soft_plus)); + } + float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus; + + float sumf = 0.0f; +#pragma unroll + for (size_t n = 0; n < N; n++) + { + float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt; + sumf += state * smemC[n]; + regs0[n] = state; + } + y_block[i * stride_y + threadIdx.x] = sumf; } + +#ifdef USE_CUB + BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0); +#else + const int stride_s = stride_s0; +#pragma unroll + for (size_t n = 0; n < N; ++n) + { + s_block[threadIdx.x * stride_s + n] = regs0[n]; + } +#endif } +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ // assumes as many threads as d_state template @@ -99,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1) const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); const int seq_idx = blockIdx.y; - const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float); + const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); @@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, cudaStream_t stream) { + const int threads = 128; // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! if (src3_nb1 == sizeof(float)) { // Mamba-2 if (d_state == 128) { - const int threads = 128; GGML_ASSERT(d_state % threads == 0); // NOTE: can be any power of two between 4 and 64 const int splitH = 16; @@ -229,7 +259,6 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa GGML_ABORT("doesn't support d_state!=(128 or 256)."); } } else { - const int threads = 128; // Mamba-1 GGML_ASSERT(n_head % threads == 0); GGML_ASSERT(head_dim == 1); @@ -237,10 +266,63 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); if (d_state == 16) { - ssm_scan_f32<128, 16><<>>( - src0, src1, src2, src3, src4, src5, src6, dst, + switch (n_tok) + { + case 1: + ssm_scan_f32<<>>( + src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 2: + ssm_scan_f32<<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 3: + ssm_scan_f32<<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 4: + ssm_scan_f32<<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 5: + ssm_scan_f32<<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 6: + ssm_scan_f32<<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 7: + ssm_scan_f32<<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 8: + ssm_scan_f32<<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + default: + ssm_scan_f32<<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + } } else { GGML_ABORT("doesn't support d_state!=16."); } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu b/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu index eb3d7cdb..c56257b4 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu @@ -1,19 +1,15 @@ -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 -#define USE_CUB -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +#include "sum.cuh" +#include "sumrows.cuh" -#ifdef USE_CUB +#ifdef GGML_CUDA_USE_CUB #include using namespace cub; -#endif // USE_CUB - -#include "sumrows.cuh" -#include "sum.cuh" +#endif // GGML_CUDA_USE_CUB #include void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) { -#ifdef USE_CUB +#ifdef GGML_CUDA_USE_CUB size_t tmp_size = 0; DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream); ggml_cuda_pool_alloc tmp_alloc(pool, tmp_size); @@ -23,7 +19,7 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14. sum_rows_f32_cuda(x, dst, ne, 1, stream); GGML_UNUSED(pool); -#endif // USE_CUB +#endif // GGML_CUDA_USE_CUB } void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/sumrows.cu b/ml/backend/ggml/ggml/src/ggml-cuda/sumrows.cu index 2eee08fa..4025771a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/sumrows.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/sumrows.cu @@ -1,9 +1,17 @@ +#include "reduce_rows.cuh" #include "sumrows.cuh" void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - const dim3 block_dims(WARP_SIZE, 1, 1); + const int id = ggml_cuda_get_device(); + const int nsm = ggml_cuda_info().devices[id].nsm; const dim3 block_nums(nrows, 1, 1); - reduce_rows_f32<<>>(x, dst, ncols); + if ((nrows / nsm) < 2) { + const dim3 block_dims(512, 1, 1); + reduce_rows_f32<<>>(x, dst, ncols); + } else { + const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); + reduce_rows_f32<<>>(x, dst, ncols); + } } void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -19,8 +27,17 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ncols = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); - const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_nums(nrows, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const int id = ggml_cuda_get_device(); + const int nsm = ggml_cuda_info().devices[id].nsm; + if ((nrows / nsm) < 2) { + // Increase num threads to 512 for small nrows to better hide the latency + const dim3 block_dims(512, 1, 1); + reduce_rows_f32<<>>(src0_d, dst_d, ncols); + } else { + // Enough active SMs to hide latency, use smaller blocks to allow better scheduling + const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); + reduce_rows_f32<<>>(src0_d, dst_d, ncols); + } } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu new file mode 100644 index 00000000..a8b15ad7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(112, 112); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu new file mode 100644 index 00000000..1da18105 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(128, 128); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu new file mode 100644 index 00000000..bc65c723 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(256, 256); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu new file mode 100644 index 00000000..10b330fa --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(40, 40); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu new file mode 100644 index 00000000..254b7d2e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(576, 512); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu new file mode 100644 index 00000000..5caffac0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(64, 64); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu new file mode 100644 index 00000000..90abb3b1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(80, 80); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu new file mode 100644 index 00000000..7292c0aa --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(96, 96); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu deleted file mode 100644 index 6696a238..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu deleted file mode 100644 index dd070db2..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu deleted file mode 100644 index 54dcde6f..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu deleted file mode 100644 index 4ec22f79..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu deleted file mode 100644 index 3c15bf7f..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu deleted file mode 100644 index 7e61b5fd..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu deleted file mode 100644 index fdb15b58..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu deleted file mode 100644 index 0f7c417d..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu deleted file mode 100644 index 851f33c4..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu deleted file mode 100644 index 763809cb..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu deleted file mode 100644 index f2a276e5..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu deleted file mode 100644 index cb227f6f..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu deleted file mode 100644 index 97ac0520..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu deleted file mode 100644 index c772b426..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu deleted file mode 100644 index 5cb74308..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu deleted file mode 100644 index 98a709d1..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu deleted file mode 100644 index 4f2f947a..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu deleted file mode 100644 index 11f96b6f..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu deleted file mode 100644 index b39bdc06..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu deleted file mode 100644 index bbd6a2c7..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu deleted file mode 100644 index 9d84ff2b..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu deleted file mode 100644 index bc8a5bff..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu deleted file mode 100644 index a679100c..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu deleted file mode 100644 index 8f21bccf..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu deleted file mode 100644 index 858b00fd..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu deleted file mode 100644 index 0fc8011f..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu deleted file mode 100644 index 261fdf62..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu deleted file mode 100644 index 0fb82473..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu deleted file mode 100644 index a9d9d089..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu deleted file mode 100644 index 7d7b2792..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu deleted file mode 100644 index a092ee2d..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu deleted file mode 100644 index db55927a..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu deleted file mode 100644 index c3c21cef..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu deleted file mode 100644 index 35dd9f52..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu deleted file mode 100644 index 050c22ac..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu deleted file mode 100644 index de4866c5..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu deleted file mode 100644 index 57a10bc4..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu deleted file mode 100644 index e0f08b46..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu deleted file mode 100644 index 1c8e8a46..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu deleted file mode 100644 index cefed83f..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu deleted file mode 100644 index aede6e35..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu deleted file mode 100644 index 1a1a92c7..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu deleted file mode 100644 index ad667473..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu deleted file mode 100644 index c499f455..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu deleted file mode 100644 index 8286ebf3..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu deleted file mode 100644 index 45878688..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu deleted file mode 100644 index d89103ce..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu deleted file mode 100644 index bb75fd42..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu deleted file mode 100644 index b1629817..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu deleted file mode 100644 index d8657604..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu deleted file mode 100644 index 2e5bd2f1..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu deleted file mode 100644 index be5f302d..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu deleted file mode 100644 index 8dd91cd7..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu deleted file mode 100644 index 4cb79150..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu deleted file mode 100644 index 09dea426..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu deleted file mode 100644 index 0fbb6076..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu deleted file mode 100644 index 2aeab83b..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu deleted file mode 100644 index 599415b4..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu deleted file mode 100644 index e4f8e308..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu deleted file mode 100644 index 34d16652..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu deleted file mode 100644 index 4bebef45..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu deleted file mode 100644 index 326468da..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu deleted file mode 100644 index 511b58f4..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu deleted file mode 100644 index d9906d14..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu deleted file mode 100644 index f61c183a..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu deleted file mode 100644 index c10450fd..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu deleted file mode 100644 index 2d5cb195..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu deleted file mode 100644 index b384f34d..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu deleted file mode 100644 index 446e293b..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu deleted file mode 100644 index 6f430298..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu deleted file mode 100644 index 1cd8ba88..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu deleted file mode 100644 index 1ee2eab6..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu deleted file mode 100644 index 2bc77816..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu deleted file mode 100644 index d55ced08..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu deleted file mode 100644 index 8361e99c..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu deleted file mode 100644 index 7507a67c..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu deleted file mode 100644 index 61f050b2..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu deleted file mode 100644 index d4a49d9c..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu deleted file mode 100644 index d1462789..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu deleted file mode 100644 index e73f917a..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu deleted file mode 100644 index d40825df..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu deleted file mode 100644 index b5c6869f..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu deleted file mode 100644 index 4e21b0cc..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu deleted file mode 100644 index 2eac321b..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu deleted file mode 100644 index f7d2c3b4..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu deleted file mode 100644 index a013f400..00000000 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f32.cuh" - -DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu new file mode 100644 index 00000000..c357abd8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu new file mode 100644 index 00000000..4b148656 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu new file mode 100644 index 00000000..ef771575 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu new file mode 100644 index 00000000..9ae11cc5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu new file mode 100644 index 00000000..10ed48af --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu new file mode 100644 index 00000000..4fcc3f33 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu new file mode 100644 index 00000000..7ca50531 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu new file mode 100644 index 00000000..6ef1a48f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu new file mode 100644 index 00000000..4c0532ca --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu new file mode 100644 index 00000000..ed3d7bad --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu new file mode 100644 index 00000000..687f2540 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu new file mode 100644 index 00000000..41107c45 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu new file mode 100644 index 00000000..d523ce01 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu new file mode 100644 index 00000000..8b9ed358 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu new file mode 100644 index 00000000..0553e464 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu new file mode 100644 index 00000000..8390eaf1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu new file mode 100644 index 00000000..f61e19d6 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu new file mode 100644 index 00000000..86a18826 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu new file mode 100644 index 00000000..1d7af474 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu new file mode 100644 index 00000000..837224d3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu new file mode 100644 index 00000000..0dd7dd69 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu new file mode 100644 index 00000000..41b859f4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu new file mode 100644 index 00000000..d2e5ffd0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu new file mode 100644 index 00000000..81ff740b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu new file mode 100644 index 00000000..a38dae19 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu new file mode 100644 index 00000000..2304571e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu new file mode 100644 index 00000000..84b83e55 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu new file mode 100644 index 00000000..39f80e21 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu new file mode 100644 index 00000000..cf4e6611 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu new file mode 100644 index 00000000..65654182 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu new file mode 100644 index 00000000..a1bc3f5a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu new file mode 100644 index 00000000..4b76a9be --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu new file mode 100644 index 00000000..77d04125 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu new file mode 100644 index 00000000..6e170fe3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu new file mode 100644 index 00000000..b617cd73 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu new file mode 100644 index 00000000..a5b768b1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu new file mode 100644 index 00000000..f594d5d5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu new file mode 100644 index 00000000..9cc67725 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(10); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu new file mode 100644 index 00000000..317f487d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(11); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu new file mode 100644 index 00000000..dc003322 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(12); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu new file mode 100644 index 00000000..07821017 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(13); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu new file mode 100644 index 00000000..a23ad6ae --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(14); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu new file mode 100644 index 00000000..0fe3f782 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(15); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu new file mode 100644 index 00000000..54408637 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu new file mode 100644 index 00000000..3b901797 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(2); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu new file mode 100644 index 00000000..56e940bb --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(3); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu new file mode 100644 index 00000000..a7665d49 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu new file mode 100644 index 00000000..3a1dff25 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(5); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu new file mode 100644 index 00000000..400fb7c6 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(6); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu new file mode 100644 index 00000000..954a1c7e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(7); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu new file mode 100644 index 00000000..f1bd09c9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(8); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu new file mode 100644 index 00000000..1255ac2a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(9); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu new file mode 100644 index 00000000..afe4aee2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu @@ -0,0 +1,257 @@ +#include "ggml-cuda/common.cuh" +#include "ggml.h" +#include "topk-moe.cuh" + +#include + +/* + This kernel does the following: + 1. softmax over the logits per token [n_experts, n_tokens] + 2. argmax reduce over the top-k (n_experts_used) logits + 3. write weights + ids to global memory + 4. optionally normalize the weights + + It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models +*/ +template +__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, + float * weights, + int32_t * ids, + const int n_rows, + const int n_expert_used) { + const int row = blockIdx.x * blockDim.y + threadIdx.y; + if (row >= n_rows) { + return; + } + + logits += n_experts * row; + weights += n_expert_used * row; + ids += n_experts * row; + + constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; + + float logits_r[experts_per_thread]; + +#pragma unroll + for (int i = 0; i < n_experts; i += WARP_SIZE) { + const int expert = i + threadIdx.x; + logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY; + } + + float max_val = logits_r[0]; + +#pragma unroll + for (int i = 1; i < experts_per_thread; i++) { + const float val = logits_r[i]; + max_val = max(val, max_val); + } + + max_val = warp_reduce_max(max_val); + + float wt[experts_per_thread]; + float tmp = 0.f; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const float val = logits_r[i]; + wt[i] = expf(val - max_val); + tmp += wt[i]; + } + + tmp = warp_reduce_sum(tmp); + + const float inv_sum = 1.0f / tmp; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + wt[i] = wt[i] * inv_sum; + } + + //at this point, each thread holds a portion of softmax, + //we do the argmax reduce over n_expert_used, each time marking + //the expert weight as -inf to exclude from the next iteration + + float wt_sum = 0.f; + + extern __shared__ float data_topk_shared[]; + float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used; + + for (int k = 0; k < n_expert_used; k++) { + float max_val = wt[0]; + int max_expert = threadIdx.x; + +#pragma unroll + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { + max_val = wt[i]; + max_expert = expert; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val > max_val || (val == max_val && expert < max_expert)) { + max_val = val; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + wt[max_expert / WARP_SIZE] = -INFINITY; + + wt_shared_ptr[k] = max_val; + ids[k] = max_expert; + if constexpr (with_norm) { + wt_sum += max_val; + } + } + } + + if constexpr (with_norm) { + wt_sum = warp_reduce_sum(wt_sum); + const float inv_sum = 1.0f / wt_sum; + + for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) { + wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum; + } + } + + for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) { + weights[i] = wt_shared_ptr[i]; + } +} + +template +static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, + const float * logits, + float * weights, + int32_t * ids, + const int n_rows, + const int n_expert, + const int n_expert_used) { + const int rows_per_block = 4; + dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); + dim3 block_dims(WARP_SIZE, rows_per_block, 1); + cudaStream_t stream = ctx.stream(); + + const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float); + + switch (n_expert) { + case 1: + topk_moe_cuda<1, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 2: + topk_moe_cuda<2, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 4: + topk_moe_cuda<4, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 8: + topk_moe_cuda<8, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 16: + topk_moe_cuda<16, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 32: + topk_moe_cuda<32, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 64: + topk_moe_cuda<64, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 128: + topk_moe_cuda<128, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 256: + topk_moe_cuda<256, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); + break; + case 512: + topk_moe_cuda<512, with_norm> + <<>>(logits, weights, ids, n_rows, n_expert_used); + break; + default: + GGML_ASSERT(false && "fatal error"); + break; + } +} + +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids, + const bool with_norm) { + GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(weights->type == GGML_TYPE_F32); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const int n_experts = logits->ne[0]; + const int n_rows = logits->ne[1]; + + const float * logits_d = (const float *) logits->src[0]->data; + float * weights_d = (float *) weights->data; + int32_t * ids_d = (int32_t *) ids->data; + + GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); + + const int n_expert_used = weights->ne[1]; + + if (with_norm) { + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + } else { + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + } +} + +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) { + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); + + if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) { + return false; + } + + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } + + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { + return false; + } + + const int n_expert = softmax->ne[0]; + // n_expert must be a power of 2 + if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) { + return false; + } + + return true; +} + +std::initializer_list ggml_cuda_topk_moe_ops(bool norm) { + static std::initializer_list norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE }; + + static std::initializer_list no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS }; + + if (norm) { + return norm_ops; + } + return no_norm_ops; +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh new file mode 100644 index 00000000..6613fb56 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh @@ -0,0 +1,14 @@ +#include "common.cuh" +#include "ggml.h" + +#include + +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * top_k, + const bool with_norm); + +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights); + +std::initializer_list ggml_cuda_topk_moe_ops(bool with_norm); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/tsembd.cu b/ml/backend/ggml/ggml/src/ggml-cuda/tsembd.cu index 153ddbcd..b91a26fc 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/tsembd.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/tsembd.cu @@ -7,11 +7,11 @@ static __global__ void timestep_embedding_f32(const float * timesteps, float * d int j = threadIdx.x + blockIdx.x * blockDim.x; float * embed_data = (float *)((char *)dst + i*nb1); - if (dim % 2 != 0 && j == ((dim + 1) / 2)) { - embed_data[dim] = 0.f; + int half = dim / 2; + if (dim % 2 != 0 && j == half) { + embed_data[2 * half] = 0.f; } - int half = dim / 2; if (j >= half) { return; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu index 5aff8a87..3c564566 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu @@ -1,4 +1,5 @@ #include "unary.cuh" +#include "convert.cuh" static __device__ __forceinline__ float op_abs(float x) { return fabsf(x); @@ -375,6 +376,59 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); } +/* CUDA kernel + launcher for xIELU */ + +template +static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float xi = ggml_cuda_cast(x[i]); + + const float gate_pos = (xi > 0.0f); + const float y_pos = alpha_p * xi * xi + beta * xi; + const float min_v_eps = fminf(xi, eps); + const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi; + const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg; + + dst[i] = ggml_cuda_cast(out); +} + +template +static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) { + const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE; + xielu_kernel<<>>(x, dst, k, alpha_n, alpha_p, beta, eps); +} + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + + if (src0->type == GGML_TYPE_F16) { + xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } else { + xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } +} + + + /* silu_back */ static __device__ __forceinline__ float op_silu_back(float grad, float x) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh index da3caf1d..8e7644fc 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh @@ -16,6 +16,7 @@ #define CUDA_SIN_BLOCK_SIZE 256 #define CUDA_COS_BLOCK_SIZE 256 #define CUDA_GLU_BLOCK_SIZE 256 +#define CUDA_XIELU_BLOCK_SIZE 256 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -72,3 +73,5 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vecdotq.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/vecdotq.cuh index d8f9aa5b..6baab117 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vecdotq.cuh @@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32 return ((const int *) x)[i32]; // assume at least 4 byte alignment } +// q4 contains 8 indices with 4 bit each. +// This function selects those bytes from table that are at those indices and returns them as int2. +// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4. static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) { +#if defined(GGML_USE_HIP) + // Load the 16-byte table into four 32-bit unsigned integers. + const uint32_t *values = (const uint32_t *)table; + + const uint32_t q_even = q4; + const uint32_t q_odd = (q4 >> 4); + + // Perform lookups in the lower half of the table (indices 0-7). + uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707); + uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707); + + // Perform lookups in the upper half of the table (indices 8-15). + uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707); + uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707); + + // Select between the low and high results based on the MSB of each index nibble. + uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1); + uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even); + uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1); + uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd); + + return make_int2(res_x, res_y); +#elif !defined(GGML_USE_MUSA) + // CUDA does not have an instruction for selecting bytes with 4 bit indices. + // However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead. + const uint32_t * table32 = (const uint32_t *) table; + + // __byte_perm selects bytes based on the lower 16 bits in its third argument. + // Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift. + // To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits. + // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit. + uint32_t tmp[2]; + const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1)); +#pragma unroll + for (uint32_t i = 0; i < 2; ++i) { + const uint32_t shift = 16 * i; + + const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift); + const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift); + tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift); + } + + // tmp contains the bytes from tyble in the same order as the 4 bit indices in q4. + // However, for the result we need ints with all even/odd 4 bit indices in q4. + // Therefore, 2 more calls to __byte_perm to put the bytes in the correct order. + return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531)); +#else + // Generic implementation. const int q0_32 = (q4 >> 0) & 0x0F0F0F0F; const int8_t * q0_8 = (const int8_t *) &q0_32; const char4 val0_8 = make_char4( @@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]); return make_int2(*((const int *) &val0_8), *((const int *) &val1_8)); +#endif } // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called @@ -87,7 +139,7 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); } -#ifdef GGML_CUDA_F16 +#ifdef FAST_FP16_AVAILABLE const float2 tmp = __half22float2(__hmul2(dm4, ds8)); const float d4d8 = tmp.x; const float m4s8 = tmp.y; @@ -96,7 +148,7 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp const float2 ds8f = __half22float2(ds8); const float d4d8 = dm4f.x * ds8f.x; const float m4s8 = dm4f.y * ds8f.y; -#endif // GGML_CUDA_F16 +#endif // FAST_FP16_AVAILABLE // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); @@ -158,7 +210,7 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values } -#ifdef GGML_CUDA_F16 +#ifdef FAST_FP16_AVAILABLE const float2 tmp = __half22float2(__hmul2(dm5, ds8)); const float d5d8 = tmp.x; const float m5s8 = tmp.y; @@ -167,7 +219,7 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp const float2 ds8f = __half22float2(ds8); const float d5d8 = dm5f.x * ds8f.x; const float m5s8 = dm5f.y * ds8f.y; -#endif // GGML_CUDA_F16 +#endif // FAST_FP16_AVAILABLE // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it return sumi*d5d8 + m5s8 / (QI5_1 / vdr); @@ -201,7 +253,7 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp sumi = ggml_cuda_dp4a(v[i], u[i], sumi); } -#ifdef GGML_CUDA_F16 +#ifdef FAST_FP16_AVAILABLE const float2 tmp = __half22float2(__hmul2(dm8, ds8)); const float d8d8 = tmp.x; const float m8s8 = tmp.y; @@ -210,7 +262,7 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp const float2 ds8f = __half22float2(ds8); const float d8d8 = dm8f.x * ds8f.x; const float m8s8 = dm8f.y * ds8f.y; -#endif // GGML_CUDA_F16 +#endif // FAST_FP16_AVAILABLE // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it return sumi*d8d8 + m8s8 / (QI8_1 / vdr); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h index cf22e60d..2f9ef2dc 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h @@ -1,13 +1,17 @@ #pragma once -#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 +#define HIP_DISABLE_WARP_SYNC_BUILTINS 1 #include #include #include -#include +#include // for rocblas_initialize() #include "rocblas/rocblas.h" +#if defined(GGML_HIP_ROCWMMA_FATTN) +#include +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N @@ -24,7 +28,10 @@ #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) +#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define __all_sync(mask, var) __all(var) +#define __any_sync(mask, var) __any(var) #define cublasCreate hipblasCreate #define cublasDestroy hipblasDestroy #define cublasGemmEx hipblasGemmEx @@ -42,6 +49,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceReset hipDeviceReset #define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaDriverGetVersion hipDriverGetVersion #define cudaError_t hipError_t #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled @@ -138,7 +146,7 @@ #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED -#if HIP_VERSION >= 70000000 +#if HIP_VERSION >= 60500000 #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F @@ -150,7 +158,7 @@ #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define cublasComputeType_t hipblasDatatype_t #define cudaDataType_t hipblasDatatype_t -#endif // HIP_VERSION >= 7000000 +#endif // HIP_VERSION >= 6050000 #if !defined(__HIP_PLATFORM_AMD__) #error "The HIP backend supports only AMD targets" @@ -158,34 +166,41 @@ #define __CUDA_ARCH__ 1300 -#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) -#define GCN -#endif +#if defined(__gfx900__) || defined(__gfx906__) +#define GCN5 +#endif // defined(__gfx900__) || defined(__gfx906__) -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) -#define CDNA // For the entire family -#endif +#if defined(__gfx803__) +#define GCN4 +#endif // defined(__gfx803__) + +#if defined(GCN5) || defined(GCN4) +#define GCN +#endif // defined(GCN5) || defined(GCN4) #if defined(__gfx942__) #define CDNA3 -#endif +#endif // defined(__gfx942__) #if defined(__gfx90a__) #define CDNA2 -#endif +#endif // defined(__gfx90a__) #if defined(__gfx908__) #define CDNA1 -#endif +#endif // defined(__gfx908__) + +#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#define CDNA // For the entire family +#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #if defined(__GFX12__) #define RDNA4 -#endif +#endif // defined(__GFX12__) -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ - defined(__gfx1150__) || defined(__gfx1151__) +#if defined(__GFX11__) #define RDNA3 -#endif +#endif // defined(__GFX11__) #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) @@ -194,14 +209,18 @@ #if defined(__gfx1010__) || defined(__gfx1012__) #define RDNA1 -#endif +#endif // defined(__gfx1010__) || defined(__gfx1012__) + +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1) +#define RDNA // For the entire family +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1) #ifndef __has_builtin #define __has_builtin(x) 0 #endif -typedef hip_bfloat16 nv_bfloat16; -typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); @@ -252,17 +271,3 @@ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigne } return c; } - -#if HIP_VERSION < 50600000 -// __shfl_xor() for half2 was added in ROCm 5.6 -static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) { - typedef union half2_b32 { - half2 val; - int b32; - } half2_b32_t; - half2_b32_t tmp; - tmp.val = var; - tmp.b32 = __shfl_xor(tmp.b32, laneMask, width); - return tmp.val; -} -#endif // HIP_VERSION < 50600000 diff --git a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt index 852de973..934aefdc 100644 --- a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt @@ -39,15 +39,9 @@ endif() find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) -if (GGML_HIP_ROCWMMA_FATTN) - CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA) - if (NOT ${FOUND_ROCWMMA}) - message(FATAL_ERROR "rocwmma has not been found") - endif() -endif() -if (${hip_VERSION} VERSION_LESS 5.5) - message(FATAL_ERROR "At least ROCM/HIP V5.5 is required") +if (${hip_VERSION} VERSION_LESS 6.1) + message(FATAL_ERROR "At least ROCM/HIP V6.1 is required") endif() message(STATUS "HIP and hipBLAS found") @@ -59,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu") +file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") @@ -117,10 +113,6 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() -if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0) - add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12) -endif() - if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif() diff --git a/ml/backend/ggml/ggml/src/ggml-impl.h b/ml/backend/ggml/ggml/src/ggml-impl.h index 19a7adb2..80597b6e 100644 --- a/ml/backend/ggml/ggml/src/ggml-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-impl.h @@ -73,7 +73,7 @@ static inline int ggml_up(int n, int m) { return (n + m - 1) & ~(m - 1); } -// TODO: move to ggml.h? +// TODO: move to ggml.h? (won't be able to inline) static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { if (a->type != b->type) { return false; @@ -89,6 +89,22 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml return true; } +static bool ggml_op_is_empty(enum ggml_op op) { + switch (op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_TRANSPOSE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + return true; + default: + return false; + } +} + +static inline float ggml_softplus(float input) { + return (input > 20.0f) ? input : logf(1 + expf(input)); +} // // logging // @@ -329,6 +345,10 @@ struct ggml_cgraph { // if you need the gradients, get them from the original graph struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1); +// ggml-alloc.c: true if the operation can reuse memory from its sources +GGML_API bool ggml_op_can_inplace(enum ggml_op op); + + // Memory allocation GGML_API void * ggml_aligned_malloc(size_t size); @@ -570,27 +590,27 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n return true; } -// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[] +// Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[] // and are fusable. Nodes are considered fusable according to this function if: // - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses). // - all nodes except the last are a src of the following node. // - all nodes are the same shape. // TODO: Consider allowing GGML_OP_NONE nodes in between -static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) { - if (node_idx + num_ops > cgraph->n_nodes) { - return false; - } - +static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) { for (int i = 0; i < num_ops; ++i) { - struct ggml_tensor * node = cgraph->nodes[node_idx + i]; + if (node_idxs[i] >= cgraph->n_nodes) { + return false; + } + + struct ggml_tensor * node = cgraph->nodes[node_idxs[i]]; if (node->op != ops[i]) { return false; } - if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) { + if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) { return false; } if (i > 0) { - struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1]; + struct ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]]; if (node->src[0] != prev && node->src[1] != prev) { return false; } @@ -602,6 +622,30 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx return true; } +// same as above, for sequential indices starting at node_idx +static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) { + assert(num_ops < 32); + + if (node_idx + num_ops > cgraph->n_nodes) { + return false; + } + + int idxs[32]; + for (int i = 0; i < num_ops; ++i) { + idxs[i] = node_idx + i; + } + + return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops); +} + +// Management libraries for fetching more accurate free VRAM data +GGML_API int ggml_nvml_init(); +GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total); +GGML_API void ggml_nvml_release(); +GGML_API int ggml_hip_mgmt_init(); +GGML_API int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total); +GGML_API void ggml_hip_mgmt_release(); + #ifdef __cplusplus } #endif diff --git a/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt index 0ca8a3c5..63418fe1 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt @@ -5,7 +5,12 @@ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) message(STATUS "Metal framework found") ggml_add_backend_library(ggml-metal - ggml-metal.m + ggml-metal.cpp + ggml-metal-device.m + ggml-metal-device.cpp + ggml-metal-common.cpp + ggml-metal-context.m + ggml-metal-ops.cpp ) target_link_libraries(ggml-metal PRIVATE @@ -18,10 +23,6 @@ if (GGML_METAL_NDEBUG) add_compile_definitions(GGML_METAL_NDEBUG) endif() -if (GGML_METAL_USE_BF16) - add_compile_definitions(GGML_METAL_USE_BF16) -endif() - # copy metal files to bin directory configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp new file mode 100644 index 00000000..95627d38 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -0,0 +1,446 @@ +#include "ggml-metal-common.h" + +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include + +// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb) +// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it) +struct ggml_mem_range { + uint64_t pb; // buffer id + + uint64_t p0; // begin + uint64_t p1; // end + + ggml_mem_range_type pt; +}; + +struct ggml_mem_ranges { + std::vector ranges; + + int debug = 0; +}; + +ggml_mem_ranges_t ggml_mem_ranges_init(int debug) { + auto * res = new ggml_mem_ranges; + + res->ranges.reserve(256); + res->debug = debug; + + return res; +} + +void ggml_mem_ranges_free(ggml_mem_ranges_t mrs) { + delete mrs; +} + +void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) { + mrs->ranges.clear(); +} + +static bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) { + mrs->ranges.push_back(mr); + + return true; +} + +static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) { + // always use the base tensor + tensor = tensor->view_src ? tensor->view_src : tensor; + + GGML_ASSERT(!tensor->view_src); + + ggml_mem_range mr; + + if (tensor->buffer) { + // when the tensor is allocated, use the actual memory address range in the buffer + // + // take the actual allocated size with ggml_backend_buft_get_alloc_size() + // this can be larger than the tensor size if the buffer type allocates extra memory + // ref: https://github.com/ggml-org/llama.cpp/pull/15966 + mr = { + /*.pb =*/ (uint64_t) tensor->buffer, + /*.p0 =*/ (uint64_t) tensor->data, + /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor), + /*.pt =*/ pt, + }; + } else { + // otherwise, the pointer address is used as an unique id of the memory ranges + // that the tensor will be using when it is allocated + mr = { + /*.pb =*/ (uint64_t) tensor, + /*.p0 =*/ 0, // + /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used + /*.pt =*/ pt, + }; + }; + + return mr; +} + +static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) { + return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC); +} + +static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) { + return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST); +} + +static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor); + + if (mrs->debug > 2) { + GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1); + } + + return ggml_mem_ranges_add(mrs, mr); +} + +static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor); + + if (mrs->debug > 2) { + GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1); + } + + return ggml_mem_ranges_add(mrs, mr); +} + +bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i]) { + ggml_mem_ranges_add_src(mrs, tensor->src[i]); + } + } + + return ggml_mem_ranges_add_dst(mrs, tensor); +} + +static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) { + for (size_t i = 0; i < mrs->ranges.size(); i++) { + const auto & cmp = mrs->ranges[i]; + + // two memory ranges cannot intersect if they are in different buffers + if (mr.pb != cmp.pb) { + continue; + } + + // intersecting source ranges are allowed + if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) { + continue; + } + + if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) { + if (mrs->debug > 2) { + GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n", + __func__, + mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", + mr.pb, mr.p0, mr.p1, + cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", + cmp.pb, cmp.p0, cmp.p1); + } + + return false; + } + } + + return true; +} + +static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor); + + const bool res = ggml_mem_ranges_check(mrs, mr); + + return res; +} + +static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor); + + const bool res = ggml_mem_ranges_check(mrs, mr); + + return res; +} + +bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i]) { + if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) { + return false; + } + } + } + + return ggml_mem_ranges_check_dst(mrs, tensor); +} + +struct node_info { + ggml_tensor * node; + + std::vector fused; + + ggml_op op() const { + return node->op; + } + + const ggml_tensor * dst() const { + return fused.empty() ? node : fused.back(); + } + + bool is_empty() const { + return ggml_op_is_empty(node->op); + } + + void add_fused(ggml_tensor * t) { + fused.push_back(t); + } +}; + +static std::vector ggml_metal_graph_optimize_reorder(const std::vector & nodes) { + // helper to add node src and dst ranges + const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node.node->src[i]) { + if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) { + return false; + } + } + } + + // keep track of the sources of the fused nodes as well + for (const auto * fused : node.fused) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (fused->src[i]) { + if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) { + return false; + } + } + } + } + + return ggml_mem_ranges_add_dst(mrs, node.dst()); + }; + + // helper to check if a node can run concurrently with the existing set of nodes + const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node.node->src[i]) { + if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) { + return false; + } + } + } + + for (const auto * fused : node.fused) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (fused->src[i]) { + if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) { + return false; + } + } + } + } + + return ggml_mem_ranges_check_dst(mrs, node.dst()); + }; + + // perform reorders only across these types of ops + // can be expanded when needed + const auto & h_safe = [](ggml_op op) { + switch (op) { + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MUL: + case GGML_OP_ADD: + case GGML_OP_DIV: + case GGML_OP_GLU: + case GGML_OP_SCALE: + case GGML_OP_GET_ROWS: + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: + return true; + default: + return ggml_op_is_empty(op); + } + }; + + const int n = nodes.size(); + + std::vector res; + res.reserve(n); + + std::vector used(n, false); + + // the memory ranges for the set of currently concurrent nodes + ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0); + + // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder + ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0); + + for (int i0 = 0; i0 < n; i0++) { + if (used[i0]) { + continue; + } + + const auto & node0 = nodes[i0]; + + // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e reset mrs0) + // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0 + // + // note: we can always add empty nodes to the concurrent set as they don't read nor write anything + if (!node0.is_empty() && !h_check(mrs0, node0)) { + // this will hold the set of memory ranges from the nodes that haven't been processed yet + // if a node is not concurrent with this set, we cannot reorder it + ggml_mem_ranges_reset(mrs1); + + // initialize it with the current node + h_add(mrs1, node0); + + // that many nodes forward to search for a concurrent node + constexpr int N_FORWARD = 8; + + for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) { + if (used[i1]) { + continue; + } + + const auto & node1 = nodes[i1]; + + // disallow reordering of certain ops + if (!h_safe(node1.op())) { + break; + } + + const bool is_empty = node1.is_empty(); + + // to reorder a node and add it to the concurrent set, it has to be: + // + empty or concurrent with all nodes in the existing concurrent set (mrs0) + // + concurrent with all nodes prior to it that haven't been processed yet (mrs1) + if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) { + // add the node to the existing concurrent set (i.e. reorder it for early execution) + h_add(mrs0, node1); + res.push_back(i1); + + // mark as used, so we skip re-processing it later + used[i1] = true; + } else { + // expand the set of nodes that haven't been processed yet + h_add(mrs1, node1); + } + } + + // finalize the concurrent set and begin a new one + ggml_mem_ranges_reset(mrs0); + } + + // expand the concurrent set with the current node + { + h_add(mrs0, node0); + res.push_back(i0); + } + } + + ggml_mem_ranges_free(mrs0); + ggml_mem_ranges_free(mrs1); + + return res; +} + +void ggml_graph_optimize(ggml_cgraph * gf) { + constexpr int MAX_FUSE = 16; + + const int n = gf->n_nodes; + + enum ggml_op ops[MAX_FUSE]; + + std::vector nodes; + nodes.reserve(gf->n_nodes); + + // fuse nodes: + // we don't want to make reorders that break fusing, so we first pack all fusable tensors + // and perform the reorder over the fused nodes. after the reorder is done, we unfuse + for (int i = 0; i < n; i++) { + node_info node = { + /*.node =*/ gf->nodes[i], + /*.fused =*/ {}, + }; + + // fuse only ops that start with these operations + // can be expanded when needed + if (node.op() == GGML_OP_ADD || + node.op() == GGML_OP_NORM || + node.op() == GGML_OP_RMS_NORM) { + ops[0] = node.op(); + + int f = i + 1; + while (f < n && f < i + MAX_FUSE) { + // conservatively allow fusing only these ops + // can be expanded when needed + if (gf->nodes[f]->op != GGML_OP_ADD && + gf->nodes[f]->op != GGML_OP_MUL && + gf->nodes[f]->op != GGML_OP_NORM && + gf->nodes[f]->op != GGML_OP_RMS_NORM) { + break; + } + ops[f - i] = gf->nodes[f]->op; + f++; + } + + f -= i; + for (; f > 1; f--) { + if (ggml_can_fuse(gf, i, ops, f)) { + break; + } + } + + // add the fused tensors into the node info so we can unfuse them later + for (int k = 1; k < f; k++) { + ++i; + + // the .dst() becomes the last fused tensor + node.add_fused(gf->nodes[i]); + } + } + + nodes.push_back(std::move(node)); + } + +#if 1 + // reorder to improve concurrency + const auto order = ggml_metal_graph_optimize_reorder(nodes); +#else + std::vector order(nodes.size()); + for (size_t i = 0; i < nodes.size(); i++) { + order[i] = i; + } +#endif + + // unfuse + { + int j = 0; + for (const auto i : order) { + const auto & node = nodes[i]; + + gf->nodes[j++] = node.node; + + for (auto * fused : node.fused) { + gf->nodes[j++] = fused; + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.h new file mode 100644 index 00000000..3acbc6ae --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.h @@ -0,0 +1,52 @@ +// helper functions for ggml-metal that are too difficult to implement in Objective-C + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_tensor; +struct ggml_cgraph; + +enum ggml_mem_range_type { + MEM_RANGE_TYPE_SRC = 0, + MEM_RANGE_TYPE_DST = 1, +}; + +// a helper object that can be used for reordering operations to improve concurrency +// +// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they +// don't write to a memory that is being read by another task or written to by another task in the set +// +// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task +// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the +// tasks already in the set) +// +typedef struct ggml_mem_ranges * ggml_mem_ranges_t; + +ggml_mem_ranges_t ggml_mem_ranges_init(int debug); +void ggml_mem_ranges_free(ggml_mem_ranges_t mrs); + +// remove all ranges from the set +void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs); + +// add src or dst ranges to track +bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor); + +// return false if: +// - new src range overlaps with any existing dst range +// - new dst range overlaps with any existing range (src or dst) +bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor); + +// reorder the nodes in the graph to improve concurrency, while respecting fusion +// +// note: this implementation is generic and not specific to metal +// if it proves to work well, we can start using it for other backends in the future +void ggml_graph_optimize(struct ggml_cgraph * gf); + +#ifdef __cplusplus +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-context.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-context.h new file mode 100644 index 00000000..ec2b686b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-context.h @@ -0,0 +1,33 @@ +#pragma once + +#include "ggml-metal-device.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// +// backend context +// + +typedef struct ggml_metal * ggml_metal_t; + +ggml_metal_t ggml_metal_init(ggml_metal_device_t dev); +void ggml_metal_free(ggml_metal_t ctx); + +void ggml_metal_synchronize(ggml_metal_t ctx); + +void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); +void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + +enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf); +void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf); + +void ggml_metal_set_n_cb (ggml_metal_t ctx, int n_cb); +void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data); +bool ggml_metal_supports_family (ggml_metal_t ctx, int family); +void ggml_metal_capture_next_compute(ggml_metal_t ctx); + +#ifdef __cplusplus +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-context.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-context.m new file mode 100644 index 00000000..b47dc787 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-context.m @@ -0,0 +1,605 @@ +#import "ggml-metal-context.h" + +#import "ggml-impl.h" +#import "ggml-backend-impl.h" + +#import "ggml-metal-impl.h" +#import "ggml-metal-common.h" +#import "ggml-metal-ops.h" + +#import + +#import + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// max number of MTLCommandBuffer used to submit a graph for processing +#define GGML_METAL_MAX_COMMAND_BUFFERS 8 + +struct ggml_metal_command_buffer { + id obj; +}; + +struct ggml_metal { + id device; + id queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND] + + ggml_metal_device_t dev; + ggml_metal_library_t lib; + + dispatch_queue_t d_queue; + + // additional, inference-time compiled pipelines + ggml_metal_pipelines_t pipelines_ext; + + bool use_bfloat; + bool use_fusion; + bool use_concurrency; + bool use_graph_optimize; + + int debug_graph; + int debug_fusion; + + // how many times a given op was fused + uint64_t fuse_cnt[GGML_OP_COUNT]; + + // capture state + bool capture_next_compute; + bool capture_started; + + id capture_scope; + + // command buffer state + int n_cb; // number of extra threads used to submit the command buffers + int n_nodes_0; // number of nodes submitted by the main thread + int n_nodes_1; // remaining number of nodes submitted by the n_cb threads + int n_nodes_per_cb; + + struct ggml_cgraph * gf; + + // the callback given to the thread pool + void (^encode_async)(size_t ith); + + // n_cb command buffers + 1 used by the main thread + struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + + // extra command buffers for things like getting, setting and copying tensors + NSMutableArray * cmd_bufs_ext; + + // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend + id cmd_buf_last; + + // abort ggml_metal_graph_compute if callback returns true + ggml_abort_callback abort_callback; + void * abort_callback_data; +}; + +ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { + GGML_LOG_INFO("%s: allocating\n", __func__); + +#if TARGET_OS_OSX && !GGML_METAL_NDEBUG + // Show all the Metal device instances in the system + NSArray * devices = MTLCopyAllDevices(); + for (id device in devices) { + GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]); + } + [devices release]; // since it was created by a *Copy* C method +#endif + + // init context + ggml_metal_t res = calloc(1, sizeof(struct ggml_metal)); + + res->device = ggml_metal_device_get_obj(dev); + + GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[res->device name] UTF8String]); + + // TODO: would it be better to have one queue for the backend and one queue for the device? + // the graph encoders and async ops would use the backend queue while the sync ops would use the device queue? + //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND] + res->queue = ggml_metal_device_get_queue(dev); + if (res->queue == nil) { + GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); + return NULL; + } + + res->dev = dev; + res->lib = ggml_metal_device_get_library(dev); + if (res->lib == NULL) { + GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__); + GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__); + + res->lib = ggml_metal_library_init(dev); + if (res->lib == NULL) { + GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__); + + free(res); + + return NULL; + } + } + + const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + + res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); + + if (@available(macOS 14.0, *)) { + res->use_bfloat = props_dev->has_bfloat; + } else { + res->use_bfloat = false; + } + + res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; + res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil; + + { + const char * val = getenv("GGML_METAL_GRAPH_DEBUG"); + res->debug_graph = val ? atoi(val) : 0; + } + + { + const char * val = getenv("GGML_METAL_FUSION_DEBUG"); + res->debug_fusion = val ? atoi(val) : 0; + } + + res->use_graph_optimize = true; + + if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) { + res->use_graph_optimize = false; + } + + memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt)); + + GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, res->use_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: use fusion = %s\n", __func__, res->use_fusion ? "true" : "false"); + GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false"); + GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false"); + + res->capture_next_compute = false; + res->capture_started = false; + res->capture_scope = nil; + + res->gf = nil; + res->encode_async = nil; + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + res->cmd_bufs[i].obj = nil; + } + + res->cmd_bufs_ext = [[NSMutableArray alloc] init]; + + res->cmd_buf_last = nil; + + res->pipelines_ext = ggml_metal_pipelines_init(); + + return res; +} + +void ggml_metal_free(ggml_metal_t ctx) { + GGML_LOG_INFO("%s: deallocating\n", __func__); + + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + if (ctx->cmd_bufs[i].obj) { + [ctx->cmd_bufs[i].obj release]; + } + } + + for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) { + if (ctx->cmd_bufs_ext[i]) { + [ctx->cmd_bufs_ext[i] release]; + } + } + + [ctx->cmd_bufs_ext removeAllObjects]; + [ctx->cmd_bufs_ext release]; + + if (ctx->pipelines_ext) { + ggml_metal_pipelines_free(ctx->pipelines_ext); + ctx->pipelines_ext = nil; + } + + if (ctx->debug_fusion > 0) { + GGML_LOG_DEBUG("%s: fusion stats:\n", __func__); + for (int i = 0; i < GGML_OP_COUNT; i++) { + if (ctx->fuse_cnt[i] == 0) { + continue; + } + + // note: cannot use ggml_log here + GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]); + } + } + + Block_release(ctx->encode_async); + + //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND] + + dispatch_release(ctx->d_queue); + + free(ctx); +} + +void ggml_metal_synchronize(ggml_metal_t ctx) { + // wait for any backend operations to finish + if (ctx->cmd_buf_last) { + [ctx->cmd_buf_last waitUntilCompleted]; + ctx->cmd_buf_last = nil; + } + + // check status of all command buffers + { + const int n_cb = ctx->n_cb; + + for (int cb_idx = 0; cb_idx <= n_cb; ++cb_idx) { + id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + if (!cmd_buf) { + continue; + } + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, cb_idx, (int) status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + GGML_ABORT("fatal error"); + } + } + } + + // release any completed extra command buffers + if (ctx->cmd_bufs_ext.count > 0) { + for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) { + id cmd_buf = ctx->cmd_bufs_ext[i]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + GGML_ABORT("fatal error"); + } + + [cmd_buf release]; + } + + [ctx->cmd_bufs_ext removeAllObjects]; + } +} + +static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) { + if (!t) { + return (struct ggml_metal_buffer_id) { nil, 0 }; + } + + ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; + + return ggml_metal_buffer_get_id(buffer->context, t); +} + +void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + @autoreleasepool { + // wrap the source data into a Metal buffer + id buf_src = [ctx->device newBufferWithBytes:data + length:size + options:MTLResourceStorageModeShared]; + + GGML_ASSERT(buf_src); + + struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor); + if (bid_dst.metal == nil) { + GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name); + } + + bid_dst.offs += offset; + + // queue the copy operation into the queue of the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:0 + toBuffer:bid_dst.metal + destinationOffset:bid_dst.offs + size:size]; + + [encoder endEncoding]; + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + @autoreleasepool { + id buf_dst = [ctx->device newBufferWithBytesNoCopy:data + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + GGML_ASSERT(buf_dst); + + struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor); + if (bid_src.metal == nil) { + GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name); + } + + bid_src.offs += offset; + + // queue the copy operation into the queue of the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:bid_src.metal + sourceOffset:bid_src.offs + toBuffer:buf_dst + destinationOffset:0 + size:size]; + + [encoder endEncoding]; + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) { + // number of nodes encoded by the main thread (empirically determined) + const int n_main = 64; + + // number of threads in addition to the main thread + const int n_cb = ctx->n_cb; + + // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them + // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread + // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes + // each thread creates it's own command buffer and enqueues the ops in parallel + // + // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2 + + @autoreleasepool { + ctx->gf = gf; + + ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); + ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; + + ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; + + const bool use_capture = ctx->capture_next_compute; + if (use_capture) { + ctx->capture_next_compute = false; + + // make sure all previous computations have finished before starting the capture + if (ctx->cmd_buf_last) { + [ctx->cmd_buf_last waitUntilCompleted]; + ctx->cmd_buf_last = nil; + } + + if (!ctx->capture_started) { + // create capture scope + ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device]; + + MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->capture_scope; + descriptor.destination = MTLCaptureDestinationGPUTraceDocument; + descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + + NSError * error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + } else { + [ctx->capture_scope beginScope]; + ctx->capture_started = true; + } + } + } + + // the main thread commits the first few commands immediately + // cmd_buf[n_cb] + { + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + [cmd_buf retain]; + + if (ctx->cmd_bufs[n_cb].obj) { + [ctx->cmd_bufs[n_cb].obj release]; + } + ctx->cmd_bufs[n_cb].obj = cmd_buf; + + [cmd_buf enqueue]; + + ctx->encode_async(n_cb); + } + + // remember the command buffer for the next iteration + ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj; + + // prepare the rest of the command buffers asynchronously (optional) + // cmd_buf[0.. n_cb) + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + [cmd_buf retain]; + + if (ctx->cmd_bufs[cb_idx].obj) { + [ctx->cmd_bufs[cb_idx].obj release]; + } + ctx->cmd_bufs[cb_idx].obj = cmd_buf; + + // always enqueue the first two command buffers + // enqueue all of the command buffers if we don't need to abort + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [cmd_buf enqueue]; + + // update the pointer to the last queued command buffer + // this is needed to implement synchronize() + ctx->cmd_buf_last = cmd_buf; + } + } + + dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); + + // for debugging: block until graph is computed + //[ctx->cmd_buf_last waitUntilCompleted]; + + // enter here only when capturing in order to wait for all computation to finish + // otherwise, we leave the graph to compute asynchronously + if (!use_capture && ctx->capture_started) { + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id cmd_buf = ctx->cmd_bufs[n_cb].obj; + [cmd_buf waitUntilCompleted]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + } + + for (int i = 0; i < n_cb; ++i) { + id cmd_buf = ctx->cmd_bufs[i].obj; + [cmd_buf waitUntilCompleted]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + + id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); + if (!next_buffer) { + continue; + } + + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; + } + + [ctx->capture_scope endScope]; + [[MTLCaptureManager sharedCaptureManager] stopCapture]; + } + } + + return GGML_STATUS_SUCCESS; +} + +void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) { + //const int64_t t_start = ggml_time_us(); + + if (ctx->use_graph_optimize) { + ggml_graph_optimize(gf); + } + + //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0); +} + +void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) { + if (ctx->n_cb != n_cb) { + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); + + if (ctx->n_cb > 2) { + GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); + } + } + + if (ctx->encode_async) { + Block_release(ctx->encode_async); + } + + ctx->encode_async = Block_copy(^(size_t iter) { + const int cb_idx = iter; + const int n_cb_l = ctx->n_cb; + + const int n_nodes_0 = ctx->n_nodes_0; + const int n_nodes_1 = ctx->n_nodes_1; + + const int n_nodes_per_cb = ctx->n_nodes_per_cb; + + int idx_start = 0; + int idx_end = n_nodes_0; + + if (cb_idx < n_cb_l) { + idx_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); + idx_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); + } + + id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + + ggml_metal_op_t ctx_op = ggml_metal_op_init( + ctx->dev, + cmd_buf, + ctx->gf, + idx_start, + idx_end, + ctx->use_fusion, + ctx->use_concurrency, + ctx->capture_next_compute, + ctx->debug_graph, + ctx->debug_fusion); + + for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) { + const int res = ggml_metal_op_encode(ctx_op, idx); + if (res == 0) { + break; + } + + idx += res - 1; + } + + ggml_metal_op_free(ctx_op); + + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [cmd_buf commit]; + } + }); +} + +void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) { + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = user_data; +} + +bool ggml_metal_supports_family(ggml_metal_t ctx, int family) { + GGML_ASSERT(ctx->device != nil); + + return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; +} + +void ggml_metal_capture_next_compute(ggml_metal_t ctx) { + ctx->capture_next_compute = true; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp new file mode 100644 index 00000000..866cd2da --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -0,0 +1,1540 @@ +#include "ggml-metal-device.h" + +#include "ggml-metal-impl.h" + +#include "ggml-impl.h" + +#include +#include +#include +#include + +struct ggml_metal_device_deleter { + void operator()(ggml_metal_device_t ctx) { + ggml_metal_device_free(ctx); + } +}; + +typedef std::unique_ptr ggml_metal_device_ptr; + +ggml_metal_device_t ggml_metal_device_get(void) { + static ggml_metal_device_ptr ctx { ggml_metal_device_init() }; + + return ctx.get(); +} + +struct ggml_metal_pipelines { + std::unordered_map data; +}; + +ggml_metal_pipelines_t ggml_metal_pipelines_init(void) { + ggml_metal_pipelines_t res = new ggml_metal_pipelines(); + + return res; +} + +void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) { + if (!ppls) { + return; + } + + for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) { + ggml_metal_pipeline_free(it->second); + } + + delete ppls; +} + +void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline) { + ppls->data[name] = pipeline; +} + +ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) { + if (ppls->data.find(name) == ppls->data.end()) { + return nullptr; + } + + return ppls->data[name]; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { + char base[256]; + char name[256]; + + const char * op_str = "undefined"; + switch (op) { + case GGML_OP_ADD_ID: op_str = "add_id"; break; + case GGML_OP_CONCAT: op_str = "concat"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_%s", op_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); + + const char * pool_str = "undefined"; + switch (op_pool) { + case GGML_OP_POOL_AVG: pool_str = "avg"; break; + case GGML_OP_POOL_MAX: pool_str = "max"; break; + default: GGML_ASSERT(false && "not implemented"); + }; + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + + char base[256]; + char name[256]; + + const int64_t n = ggml_nelements(op); + + const char * op_str = "undefined"; + switch (op->op) { + case GGML_OP_SCALE: op_str = "scale"; break; + case GGML_OP_CLAMP: op_str = "clamp"; break; + case GGML_OP_SQR: op_str = "sqr"; break; + case GGML_OP_SQRT: op_str = "sqrt"; break; + case GGML_OP_SIN: op_str = "sin"; break; + case GGML_OP_COS: op_str = "cos"; break; + case GGML_OP_LOG: op_str = "log"; break; + case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_TANH: op_str = "tanh"; break; + case GGML_UNARY_OP_RELU: op_str = "relu"; break; + case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break; + case GGML_UNARY_OP_GELU: op_str = "gelu"; break; + case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break; + case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break; + case GGML_UNARY_OP_SILU: op_str = "silu"; break; + case GGML_UNARY_OP_ELU: op_str = "elu"; break; + case GGML_UNARY_OP_NEG: op_str = "neg"; break; + case GGML_UNARY_OP_ABS: op_str = "abs"; break; + case GGML_UNARY_OP_SGN: op_str = "sgn"; break; + case GGML_UNARY_OP_STEP: op_str = "step"; break; + case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break; + case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; + case GGML_UNARY_OP_EXP: op_str = "exp"; break; + default: GGML_ABORT("fatal error"); + } break; + default: GGML_ABORT("fatal error"); + }; + + const char * suffix = ""; + if (n % 4 == 0) { + suffix = "_4"; + } + + snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); + + char base[256]; + char name[256]; + + const char * op_str = "undefined"; + switch (op->op) { + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: op_str = "reglu"; break; + case GGML_GLU_OP_GEGLU: op_str = "geglu"; break; + case GGML_GLU_OP_SWIGLU: op_str = "swiglu"; break; + case GGML_GLU_OP_SWIGLU_OAI: op_str = "swiglu_oai"; break; + case GGML_GLU_OP_GEGLU_ERF: op_str = "geglu_erf"; break; + case GGML_GLU_OP_GEGLU_QUICK: op_str = "geglu_quick"; break; + default: GGML_ABORT("fatal error"); + } break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_SUM); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + const char * op_str = "undefined"; + switch (op->op) { + case GGML_OP_SUM_ROWS: + op_str = "sum_rows"; break; + case GGML_OP_MEAN: + op_str = "mean"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + const char * suffix = ""; + + if (op->src[0]->ne[0] % 4 == 0) { + suffix = "_4"; + } + + const ggml_type tsrc1 = op->src[1] ? op->src[1]->type : GGML_TYPE_F32; + + snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + char base[256]; + char name[256]; + + const char * suffix = ""; + + if (op->src[1]->ne[0] % 4 == 0) { + suffix = "_4"; + } + + snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + + char base[256]; + char name[256]; + + const int nsg = (ne00 + 31)/32; + + snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d", base, nsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const int64_t C = op->ne[0]; + const int64_t H = op->src[0]->ne[1]; + + switch (op->op) { + case GGML_OP_RWKV_WKV6: + { + GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64); + + snprintf(base, 256, "kernel_rwkv_wkv6_%s", ggml_type_name(op->src[0]->type)); + } break; + case GGML_OP_RWKV_WKV7: + { + GGML_ASSERT(op->src[6]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64); + + snprintf(base, 256, "kernel_rwkv_wkv7_%s", ggml_type_name(op->src[0]->type)); + } break; + default: + GGML_ABORT("fatal error"); + } + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); + snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + + const bool bc_inp = op->src[0]->ne[0] % 32 != 0; + const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0; + + snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); + snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes + ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + + char base[256]; + char name[256]; + + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + + const char * suffix = ""; + + // use custom matrix x vector kernel + switch (tsrc0) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + { + if (ne00 < 32) { + nsg = 1; + nr0 = 32; + nr1 = 1; + suffix = "_short"; + } else { + nsg = std::min(4, (ne00 + 127) / 128); + nr0 = 2; + nr1 = 1; + smem = 32*sizeof(float)*nr0; + suffix = ne00 % 4 == 0 ? "_4" : ""; + } + } break; + case GGML_TYPE_Q4_0: + { + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; + } break; + case GGML_TYPE_Q4_1: + { + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; + } break; + case GGML_TYPE_Q5_0: + { + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; + } break; + case GGML_TYPE_Q5_1: + { + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; + } break; + case GGML_TYPE_Q8_0: + { + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; + smem = 32*sizeof(float)*N_R0_Q8_0; + } break; + case GGML_TYPE_MXFP4: + { + nsg = N_SG_MXFP4; + nr0 = N_R0_MXFP4; + smem = 32*sizeof(float); + } break; + case GGML_TYPE_Q2_K: + { + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; + } break; + case GGML_TYPE_Q3_K: + { + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; + } break; + case GGML_TYPE_Q4_K: + { + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; + } break; + case GGML_TYPE_Q5_K: + { + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; + } break; + case GGML_TYPE_Q6_K: + { + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; + } break; + case GGML_TYPE_IQ2_XXS: + { + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; + } break; + case GGML_TYPE_IQ2_XS: + { + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; + } break; + case GGML_TYPE_IQ3_XXS: + { + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; + } break; + case GGML_TYPE_IQ3_S: + { + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; + } break; + case GGML_TYPE_IQ2_S: + { + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; + } break; + case GGML_TYPE_IQ1_S: + { + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; + } break; + case GGML_TYPE_IQ1_M: + { + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; + } break; + case GGML_TYPE_IQ4_NL: + { + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); + } break; + case GGML_TYPE_IQ4_XS: + { + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); + } break; + default: + { + GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0); + GGML_ABORT("not implemented"); + } + }; + + snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); + snprintf(name, 256, "%s_nsg=%d", base, nsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + ggml_metal_pipeline_set_nr0 (res, nr0); + ggml_metal_pipeline_set_nr1 (res, nr1); + ggml_metal_pipeline_set_nsg (res, nsg); + ggml_metal_pipeline_set_smem(res, smem); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t); + + ggml_metal_pipeline_set_smem(res, smem); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + + const bool bc_inp = op->src[0]->ne[0] % 32 != 0; + + snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); + snprintf(name, 256, "%s_bci=%d", base, bc_inp); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + ggml_metal_pipeline_set_smem(res, 8192); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + + char base[256]; + char name[256]; + + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + + const char * suffix = ""; + + // use custom matrix x vector kernel + switch (tsrc0) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + { + nsg = std::min(4, (ne00 + 127) / 128); + nr0 = 2; + nr1 = 1; + smem = 32*sizeof(float)*nr0; + suffix = ne00 % 4 == 0 ? "_4" : ""; + } break; + case GGML_TYPE_Q4_0: + { + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; + } break; + case GGML_TYPE_Q4_1: + { + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; + } break; + case GGML_TYPE_Q5_0: + { + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; + } break; + case GGML_TYPE_Q5_1: + { + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; + } break; + case GGML_TYPE_Q8_0: + { + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; + smem = 32*sizeof(float)*N_R0_Q8_0; + } break; + case GGML_TYPE_MXFP4: + { + nsg = N_SG_MXFP4; + nr0 = N_R0_MXFP4; + smem = 32*sizeof(float); + } break; + case GGML_TYPE_Q2_K: + { + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; + } break; + case GGML_TYPE_Q3_K: + { + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; + } break; + case GGML_TYPE_Q4_K: + { + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; + } break; + case GGML_TYPE_Q5_K: + { + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; + } break; + case GGML_TYPE_Q6_K: + { + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; + } break; + case GGML_TYPE_IQ2_XXS: + { + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; + } break; + case GGML_TYPE_IQ2_XS: + { + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; + } break; + case GGML_TYPE_IQ3_XXS: + { + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; + } break; + case GGML_TYPE_IQ3_S: + { + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; + } break; + case GGML_TYPE_IQ2_S: + { + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; + } break; + case GGML_TYPE_IQ1_S: + { + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; + } break; + case GGML_TYPE_IQ1_M: + { + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; + } break; + case GGML_TYPE_IQ4_NL: + { + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); + } break; + case GGML_TYPE_IQ4_XS: + { + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); + } break; + default: + { + GGML_LOG_ERROR("Asserting on type %d\n", (int)op->src[2]->type); + GGML_ABORT("not implemented"); + } + }; + + snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); + snprintf(name, 256, "%s_nsg=%d", base, nsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + ggml_metal_pipeline_set_nr0 (res, nr0); + ggml_metal_pipeline_set_nr1 (res, nr1); + ggml_metal_pipeline_set_nsg (res, nsg); + ggml_metal_pipeline_set_smem(res, smem); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t))); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_ARGSORT); + + char base[256]; + char name[256]; + + ggml_sort_order order = (ggml_sort_order) op->op_params[0]; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + int32_t ncpsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + GGML_UNUSED(op); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_%s", + "flash_attn_ext_pad"); + + snprintf(name, 256, "%s_mask=%d_ncpsg=%d", + base, + has_mask, + ncpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); + //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t nqptg, + int32_t ncpsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + GGML_UNUSED(op); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_%s", + "flash_attn_ext_blk"); + + snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d", + base, + nqptg, + ncpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); + ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( + ggml_metal_library_t lib, + const ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + bool has_kvpad, + int32_t nsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + char base[256]; + char name[256]; + + const int32_t dk = (int32_t) op->src[1]->ne[0]; + const int32_t dv = (int32_t) op->src[2]->ne[0]; + + const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; + const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + + // do bounds checks for the mask? + const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0); + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + "flash_attn_ext", + ggml_type_name(op->src[1]->type), + dk, + dv); + + snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d", + base, + has_mask, + has_sinks, + has_bias, + has_scap, + has_kvpad, + bc_mask, + ns10, + ns20, + nsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); + + ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); + + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( + ggml_metal_library_t lib, + const ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + bool has_kvpad, + int32_t nsg, + int32_t nwg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + char base[256]; + char name[256]; + + const int32_t dk = (int32_t) op->src[1]->ne[0]; + const int32_t dv = (int32_t) op->src[2]->ne[0]; + + const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; + const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + "flash_attn_ext_vec", + ggml_type_name(op->src[1]->type), + dk, + dv); + + snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", + base, + has_mask, + has_sinks, + has_bias, + has_scap, + has_kvpad, + ns10, + ns20, + nsg, nwg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); + + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( + ggml_metal_library_t lib, + const ggml_tensor * op, + int32_t dv, + int32_t nwg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce"); + snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; + + GGML_UNUSED(op); +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( + ggml_metal_library_t lib, + ggml_op op, + int32_t n_fuse, + bool row) { + char base[256]; + char name[256]; + + const char * op_str = "undefined"; + switch (op) { + case GGML_OP_ADD: op_str = "add"; break; + case GGML_OP_SUB: op_str = "sub"; break; + case GGML_OP_MUL: op_str = "mul"; break; + case GGML_OP_DIV: op_str = "div"; break; + default: GGML_ABORT("fatal error"); + }; + + if (row) { + snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse); + } else { + snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse); + } + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_L2_NORM); + + GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_l2_norm_f32"); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_GROUP_NORM); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_group_norm_f32"); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { + assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM); + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + char base[256]; + char name[256]; + + const char * suffix = ""; + if (op->ne[0] % 4 == 0) { + suffix = "_4"; + } + + switch (op->op) { + case GGML_OP_NORM: + switch (n_fuse) { + case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break; + case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break; + case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break; + default: GGML_ABORT("fatal error"); + } break; + case GGML_OP_RMS_NORM: + switch (n_fuse) { + case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break; + case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break; + case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break; + default: GGML_ABORT("fatal error"); + } break; + default: GGML_ABORT("fatal error"); + } + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_ROPE); + + char base[256]; + char name[256]; + + const int mode = ((const int32_t *) op->op_params)[2]; + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_neox) { + snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type)); + } else if (is_mrope && !is_vision) { + GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token + snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type)); + } else if (is_vision) { + GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token + snprintf(base, 256, "kernel_rope_vision_%s", ggml_type_name(op->src[0]->type)); + } else { + snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type)); + } + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_IM2COL); + + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_CONV_TRANSPOSE_1D); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_UPSCALE); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_PAD); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_PAD_REFLECT_1D); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_ARANGE); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TIMESTEP_EMBEDDING); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_OPT_STEP_ADAMW); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_OPT_STEP_SGD); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h new file mode 100644 index 00000000..28ae2e17 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h @@ -0,0 +1,243 @@ +#pragma once + +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_metal_buffer_id { + void * metal; // id + size_t offs; +}; + +typedef struct ggml_metal_device * ggml_metal_device_t; + +// +// MTLFunctionConstantValues wrapper +// + +typedef struct ggml_metal_cv * ggml_metal_cv_t; + +ggml_metal_cv_t ggml_metal_cv_init(void); +void ggml_metal_cv_free(ggml_metal_cv_t cv); + +void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx); +void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx); +void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx); + +// +// MTLComputePipelineState wrapper +// + +typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t; + +ggml_metal_pipeline_t ggml_metal_pipeline_init(void); +void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline); + +void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg); +int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline); + +void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0); +int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline); + +void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1); +int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline); + +void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem); +size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline); + +int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline); + +// a collection of pipelines +typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t; + +ggml_metal_pipelines_t ggml_metal_pipelines_init(void); +void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls); + +void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline); +ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name); + +// +// MTLCommandBuffer wrapper +// + +typedef void * ggml_metal_cmd_buf_t; + +// +// MTLComputeCommandEncoder wrapper +// + +typedef struct ggml_metal_encoder * ggml_metal_encoder_t; + +ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent); +void ggml_metal_encoder_free(ggml_metal_encoder_t encoder); + +void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name); +void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder); + +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline); + +void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx); +void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx); + +void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx); + +void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2); + +void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder); + +void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder); + +// +// MTLLibrary wrapper +// + +typedef struct ggml_metal_library * ggml_metal_library_t; + +ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev); +void ggml_metal_library_free(ggml_metal_library_t lib); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); +ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + int32_t ncpsg); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t nqptg, + int32_t ncpsg); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + bool has_kvpad, + int32_t nsg); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + bool has_kvpad, + int32_t nsg, + int32_t nwg); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t dv, + int32_t nwg); + +// +// device +// + +struct ggml_metal_device_props { + char name[128]; + + size_t max_buffer_size; + size_t max_working_set_size; + size_t max_theadgroup_memory_size; + + bool has_simdgroup_reduction; + bool has_simdgroup_mm; + bool has_unified_memory; + bool has_bfloat; + bool use_residency_sets; + bool use_shared_buffers; + + bool supports_gpu_family_apple7; +}; + +ggml_metal_device_t ggml_metal_device_init(void); +void ggml_metal_device_free(ggml_metal_device_t dev); + +// return a singleton that is automatically destroyed when the program exits +ggml_metal_device_t ggml_metal_device_get(void); + +void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id +void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id + +ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev); + +void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total); +bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op); + +const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev); + +// +// device buffers +// + +typedef struct ggml_metal_buffer * ggml_metal_buffer_t; + +ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared); +ggml_metal_buffer_t ggml_metal_buffer_map (ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size); + +void ggml_metal_buffer_free (ggml_metal_buffer_t buf); +void * ggml_metal_buffer_get_base (ggml_metal_buffer_t buf); +bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf); + +void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); +void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); +void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); +void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value); + +// finds the Metal buffer that contains the tensor data on the GPU device +// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the +// Metal buffer based on the host memory pointer +// +struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t); + +#ifdef __cplusplus +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m new file mode 100644 index 00000000..fc508304 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m @@ -0,0 +1,1310 @@ +#import "ggml-metal-device.h" + +#import "ggml-impl.h" +#import "ggml-threading.h" + +#include + +#include + +#ifndef TARGET_OS_VISION +#define TARGET_OS_VISION 0 +#endif + +// create residency sets only on macOS >= 15.0 +#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ + TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \ + TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \ + TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000 +#define GGML_METAL_HAS_RESIDENCY_SETS 1 +#endif + +// overload of MTLGPUFamilyMetal3 (not available in some environments) +static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; + +#if !GGML_METAL_EMBED_LIBRARY +// Here to assist with NSBundle Path Hack +@interface GGMLMetalClass : NSObject +@end +@implementation GGMLMetalClass +@end +#endif + +// +// MTLFunctionConstantValues wrapper +// + +struct ggml_metal_cv { + MTLFunctionConstantValues * obj; +}; + +ggml_metal_cv_t ggml_metal_cv_init(void) { + ggml_metal_cv_t res = calloc(1, sizeof(struct ggml_metal_cv)); + + res->obj = [[MTLFunctionConstantValues alloc] init]; + + return res; +} + +void ggml_metal_cv_free(ggml_metal_cv_t cv) { + [cv->obj release]; + free(cv); +} + +void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) { + [cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx]; +} + +void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) { + [cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx]; +} + +void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) { + [cv->obj setConstantValue:&value type:MTLDataTypeBool atIndex:idx]; +} + +// +// MTLComputePipelineState wrapper +// + +struct ggml_metal_pipeline { + id obj; + + // suggested dispatch sizes + int nsg; + + int nr0; + int nr1; + + size_t smem; +}; + +ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { + ggml_metal_pipeline_t res = calloc(1, sizeof(struct ggml_metal_pipeline)); + + *res = (struct ggml_metal_pipeline) { + /*.obj =*/ nil, + /*.nsg =*/ 0, + /*.nr0 =*/ 0, + /*.nr1 =*/ 0, + /*.smem =*/ 0, + }; + + return res; +} + +void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) { + [pipeline->obj release]; + + free(pipeline); +} + +void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) { + pipeline->nsg = nsg; +} + +int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) { + return pipeline->nsg; +} + +void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) { + pipeline->nr0 = nr0; +} + +int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) { + return pipeline->nr0; +} + +void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) { + pipeline->nr1 = nr1; +} + +int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) { + return pipeline->nr1; +} + +void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) { + pipeline->smem = smem; +} + +size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) { + return pipeline->smem; +} + +int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) { + return pipeline->obj.maxTotalThreadsPerThreadgroup; +} + +struct ggml_metal_library { + id obj; + id device; + + ggml_metal_pipelines_t pipelines; // cache of compiled pipelines +}; + +ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { + id library = nil; + id device = ggml_metal_device_get_obj(dev); + + // load library + // + // - first check if the library is embedded + // - then check if the library is in the bundle + // - if not found, load the source and compile it + // - if that fails, return NULL + // + // TODO: move to a function + { + const int64_t t_start = ggml_time_us(); + + NSError * error = nil; + NSString * src = nil; + +#if GGML_METAL_EMBED_LIBRARY + GGML_LOG_INFO("%s: using embedded metal library\n", __func__); + + extern const char ggml_metallib_start[]; + extern const char ggml_metallib_end[]; + + src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding]; +#else + +#ifdef SWIFT_PACKAGE + NSBundle * bundle = SWIFTPM_MODULE_BUNDLE; +#else + NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; +#endif + + NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"]; + if (path_lib == nil) { + // Try to find the resource in the directory where the current binary located. + NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0]; + NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent]; + + NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]]; + if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) { + GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]); + + NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error]; + if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) { + // Optionally, if this is a symlink, try to resolve it. + path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error]; + if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) { + // It is a relative path, adding the binary directory as directory prefix. + path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]]; + } + if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) { + // Link to the resource could not be resolved. + path_lib_default = nil; + } else { + GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]); + } + } + } else { + // The resource couldn't be found in the binary's directory. + path_lib_default = nil; + } + + path_lib = path_lib_default; + } + + if (path_lib != nil) { + // pre-compiled library found + NSURL * libURL = [NSURL fileURLWithPath:path_lib]; + GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]); + + library = [device newLibraryWithURL:libURL error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return nil; + } + } else { + GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); + + NSString * path_source; + NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; + + GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil"); + + if (path_resource) { + path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"]; + } else { + path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + } + + if (path_source == nil) { + GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); + path_source = @"ggml-metal.metal"; + } + + GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]); + + src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return nil; + } + } +#endif + + if (!library) { + @autoreleasepool { + // dictionary of preprocessor macros + NSMutableDictionary * prep = [NSMutableDictionary dictionary]; + + if (ggml_metal_device_get_props(dev)->has_bfloat) { + [prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"]; + } + +#if GGML_METAL_EMBED_LIBRARY + [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"]; +#endif + + MTLCompileOptions * options = [MTLCompileOptions new]; + options.preprocessorMacros = prep; + + //[options setFastMathEnabled:false]; + + library = [device newLibraryWithSource:src options:options error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return nil; + } + +#if !__has_feature(objc_arc) + [options release]; +#endif + } + } + +#if GGML_METAL_EMBED_LIBRARY + [src release]; +#endif // GGML_METAL_EMBED_LIBRARY + + GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6); + } + + ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); + + res->obj = library; + res->device = device; + res->pipelines = ggml_metal_pipelines_init(); + + return res; +} + +void ggml_metal_library_free(ggml_metal_library_t lib) { + if (!lib) { + return; + } + + if (lib->obj) { + [lib->obj release]; + } + + ggml_metal_pipelines_free(lib->pipelines); + + free(lib); +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { + return ggml_metal_pipelines_get(lib->pipelines, name); +} + +ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { + // note: the pipelines are cached in the library per device, so they are shared across all metal contexts + ggml_critical_section_start(); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + ggml_critical_section_end(); + + return res; + } + + res = ggml_metal_pipeline_init(); + + @autoreleasepool { + NSError * error = nil; + + NSString * base_func = [NSString stringWithUTF8String:base]; + + GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name); + + id mtl_function; + if (!cv) { + mtl_function = [lib->obj newFunctionWithName:base_func]; + } else { + mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error]; + } + if (!mtl_function) { + ggml_critical_section_end(); + + GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + } + + return nil; + } + + res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; + + ggml_metal_pipelines_add(lib->pipelines, name, res); + + [mtl_function release]; + + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj, + (int) res->obj.maxTotalThreadsPerThreadgroup, + (int) res->obj.threadExecutionWidth); + } + + ggml_critical_section_end(); + + return res; +} + +// +// MTLComputeCommandEncoder wrapper +// + +struct ggml_metal_encoder { + id obj; +}; + +ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent) { + ggml_metal_encoder_t res = calloc(1, sizeof(struct ggml_metal_encoder)); + + id cmd_buf = (id) cmd_buf_raw; + + if (concurrent) { + res->obj = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } else { + res->obj = [cmd_buf computeCommandEncoder]; + } + + [res->obj retain]; + + return res; +} + +void ggml_metal_encoder_free(ggml_metal_encoder_t encoder) { + [encoder->obj release]; + free(encoder); +} + +void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name) { + [encoder->obj pushDebugGroup:[NSString stringWithCString:name encoding:NSUTF8StringEncoding]]; +} + +void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) { + [encoder->obj popDebugGroup]; +} + +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) { + [encoder->obj setComputePipelineState:pipeline->obj]; +} + +void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) { + [encoder->obj setBytes:data length:size atIndex:idx]; +} + +void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx) { + [encoder->obj setBuffer:buffer.metal offset:buffer.offs atIndex:idx]; +} + +void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx) { + [encoder->obj setThreadgroupMemoryLength:size atIndex:idx]; +} + +void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2) { + [encoder->obj dispatchThreadgroups:MTLSizeMake(tg0, tg1, tg2) threadsPerThreadgroup:MTLSizeMake(tptg0, tptg1, tptg2)]; +} + +void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder) { + [encoder->obj memoryBarrierWithScope:MTLBarrierScopeBuffers]; +} + +void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) { + [encoder->obj endEncoding]; +} + +struct ggml_metal_device { + id mtl_device; + + // a single global queue shared by all Metal backends + // technically not needed for devices with unified memory, but enables discrete GPUs support + // ref: https://github.com/ggml-org/llama.cpp/pull/15906 + id mtl_queue; + + ggml_metal_library_t library; + + struct ggml_metal_device_props props; +}; + +ggml_metal_device_t ggml_metal_device_init(void) { + ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); + + assert(dev != NULL); + + if (dev->mtl_device == nil) { + dev->mtl_device = MTLCreateSystemDefaultDevice(); + + if (dev->mtl_device) { + dev->mtl_queue = [dev->mtl_device newCommandQueue]; + if (dev->mtl_queue == nil) { + GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); + } + + dev->props.has_simdgroup_reduction = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + + dev->props.has_simdgroup_mm = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.has_unified_memory = dev->mtl_device.hasUnifiedMemory; + + dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6]; + + dev->props.use_residency_sets = true; +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil; +#endif + + dev->props.use_shared_buffers = dev->props.has_unified_memory; + + if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) { + dev->props.use_shared_buffers = false; + } + + dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + + dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; + dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; + dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength; + + strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1); + + dev->library = ggml_metal_library_init(dev); + if (!dev->library) { + GGML_LOG_ERROR("%s: error: failed to create library\n", __func__); + } + + // -------------------------------------------------- + + // print MTL GPU family: + GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name); + + // determine max supported GPU family + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf + { + for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { + if ([dev->mtl_device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { + if ([dev->mtl_device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) { + if ([dev->mtl_device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i); + break; + } + } + } + + GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, dev->props.has_simdgroup_reduction ? "true" : "false"); + GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false"); + GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false"); + GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false"); + GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false"); + +#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) + if (@available(macOS 10.12, iOS 16.0, *)) { + GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, dev->props.max_working_set_size / 1e6); + } +#endif + } + } + + return dev; +} + +void ggml_metal_device_free(ggml_metal_device_t dev) { + assert(dev != NULL); + + ggml_metal_library_free(dev->library); + dev->library = NULL; + + if (dev->mtl_queue) { + [dev->mtl_queue release]; + dev->mtl_queue = nil; + } + + if (dev->mtl_device) { + [dev->mtl_device release]; + dev->mtl_device = nil; + } + + free(dev); +} + +void * ggml_metal_device_get_obj(ggml_metal_device_t dev) { + return dev->mtl_device; +} + +void * ggml_metal_device_get_queue(ggml_metal_device_t dev) { + return dev->mtl_queue; +} + +ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev) { + return dev->library; +} + +void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { + if (@available(macOS 10.12, iOS 16.0, *)) { + *total = dev->mtl_device.recommendedMaxWorkingSetSize; + *free = *total - dev->mtl_device.currentAllocatedSize; + } else { + *free = 0; + *total = 0; + } +} + +bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op) { + const bool has_simdgroup_mm = dev->props.has_simdgroup_mm; + const bool has_simdgroup_reduction = dev->props.has_simdgroup_reduction; + const bool has_bfloat = dev->props.has_bfloat; + + if (!has_bfloat) { + if (op->type == GGML_TYPE_BF16) { + return false; + } + + for (size_t i = 0, n = 3; i < n; ++i) { + if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { + return false; + } + } + } + + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_EXP: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + default: + return false; + } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + default: + return false; + } + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + case GGML_OP_CONCAT: + return true; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_ADD_ID: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ACC: + case GGML_OP_REPEAT: + case GGML_OP_SCALE: + case GGML_OP_CONV_TRANSPOSE_1D: + return true; + case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_LOG: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_SOFT_MAX: + case GGML_OP_GROUP_NORM: + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_L2_NORM: + return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + case GGML_OP_ARGMAX: + return has_simdgroup_reduction; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0])); + case GGML_OP_ROPE: + return true; + case GGML_OP_IM2COL: + return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); + case GGML_OP_POOL_1D: + return false; + case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + case GGML_OP_POOL_2D: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_PAD: + return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && + (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_LEAKY_RELU: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ARGSORT: + // TODO: Support arbitrary column width + return op->src[0]->ne[0] <= 1024; + case GGML_OP_ARANGE: + return true; + case GGML_OP_FLASH_ATTN_EXT: + // for new head sizes, add checks here + if (op->src[0]->ne[0] != 40 && + op->src[0]->ne[0] != 64 && + op->src[0]->ne[0] != 80 && + op->src[0]->ne[0] != 96 && + op->src[0]->ne[0] != 112 && + op->src[0]->ne[0] != 128 && + op->src[0]->ne[0] != 192 && + op->src[0]->ne[0] != 256) { + return false; + } + if (op->src[0]->ne[0] == 576) { + // DeepSeek sizes + // TODO: disabled for now, until optmized + return false; + } + if (op->src[1]->type != op->src[2]->type) { + return false; + } + return has_simdgroup_mm; // TODO: over-restricted for vec-kernels + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + return has_simdgroup_reduction; + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + return true; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + return has_simdgroup_reduction; + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_CONT: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_I32: + return true; + default: + return false; + } + case GGML_TYPE_F16: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } + case GGML_TYPE_BF16: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_BF16: + return true; + default: + return false; + } + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } + case GGML_TYPE_I32: + return op->type == GGML_TYPE_F32; + default: + return false; + }; + } + case GGML_OP_GET_ROWS: + return true; + case GGML_OP_SET_ROWS: + { + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + }; + } + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + return has_simdgroup_reduction; + default: + return false; + } +} + +const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev) { + return &dev->props; +} + +// +// device buffers +// + +// max memory buffers that can be mapped to the device +#define GGML_METAL_MAX_BUFFERS 64 + +struct ggml_metal_buffer_wrapper { + void * data; + size_t size; + + id metal; +}; + +struct ggml_metal_buffer { + void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985 + size_t all_size; + + // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host + bool is_shared; + bool owned; + + // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap + int n_buffers; + struct ggml_metal_buffer_wrapper buffers[GGML_METAL_MAX_BUFFERS]; + + bool use_residency_sets; + + // optional MTLResidencySet + // note: cannot use explicity "id" here because it is not available on certain OSes + id rset; + + // pointers to global device objects + id device; + id queue; +}; + +static void ggml_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG +#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) + if (@available(macOS 10.12, iOS 16.0, *)) { + GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0, + device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + + if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { + GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); + } + } else { + GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); + } +#endif +#endif + GGML_UNUSED(device); + GGML_UNUSED(size_aligned); +} + +// rset init +static bool ggml_metal_buffer_rset_init(ggml_metal_buffer_t buf) { + buf->rset = nil; + + if (!buf->use_residency_sets) { + return true; + } + +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { + MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; + desc.label = @"ggml_metal"; + desc.initialCapacity = buf->n_buffers; + + NSError * error; + buf->rset = [buf->device newResidencySetWithDescriptor:desc error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + [desc release]; + return false; + } + + [desc release]; + + for (int i = 0; i < buf->n_buffers; i++) { + [buf->rset addAllocation:buf->buffers[i].metal]; + } + + [buf->rset commit]; + [buf->rset requestResidency]; + + return true; + } +#endif + + return true; +} + +// rset free +static void ggml_metal_buffer_rset_free(ggml_metal_buffer_t buf) { +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { + if (buf->rset) { + [buf->rset endResidency]; + [buf->rset removeAllAllocations]; + [buf->rset release]; + } + } +#else + GGML_UNUSED(buf); +#endif +} + +static void * ggml_metal_host_malloc(size_t n) { + void * data = NULL; + +#if TARGET_OS_OSX + kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE); + if (err != KERN_SUCCESS) { + GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); + return NULL; + } +#else + const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); + if (result != 0) { + GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); + return NULL; + } +#endif + + return data; +} + +ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) { + ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer)); + + const size_t size_page = sysconf(_SC_PAGESIZE); + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + + shared = shared && props_dev->use_shared_buffers; + + // allocate shared buffer if the device supports it and it is required by the buffer type + if (shared) { + res->all_data = ggml_metal_host_malloc(size_aligned); + res->is_shared = true; + res->owned = true; + } else { + // dummy, non-NULL value - we'll populate this after creating the Metal buffer below + res->all_data = (void *) 0x000000400ULL; + res->is_shared = false; + } + res->all_size = size_aligned; + + res->device = ggml_metal_device_get_obj(dev); + res->queue = ggml_metal_device_get_queue(dev); + + res->n_buffers = 1; + + if (res->all_data != NULL) { + res->buffers[0].size = size; + res->buffers[0].metal = nil; + + if (size_aligned > 0) { + if (props_dev->use_shared_buffers &&shared) { + res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data + length:size_aligned + options:MTLResourceStorageModeShared + deallocator:nil]; + } else { + res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; + + res->all_data = (void *) (res->buffers[0].metal.gpuAddress); + } + } + + res->buffers[0].data = res->all_data; + } + + if (size_aligned > 0 && (res->all_data == NULL || res->buffers[0].metal == nil)) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + free(res); + return NULL; + } + + res->use_residency_sets = props_dev->use_residency_sets; + + if (!ggml_metal_buffer_rset_init(res)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(res); + return NULL; + } + + //ggml_metal_log_allocated_size(device, size_aligned); + + return res; +} + +ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size) { + ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer)); + + res->all_data = ptr; + res->all_size = size; + + res->is_shared = true; + res->owned = false; + + res->n_buffers = 0; + + const size_t size_page = sysconf(_SC_PAGESIZE); + + // page-align the data ptr + { + const uintptr_t offs = (uintptr_t) ptr % size_page; + ptr = (void *) ((char *) ptr - offs); + size += offs; + } + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + res->device = ggml_metal_device_get_obj(dev); + res->queue = ggml_metal_device_get_queue(dev); + + const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + + // the buffer fits into the max buffer size allowed by the device + if (size_aligned <= props_dev->max_buffer_size) { + res->buffers[res->n_buffers].data = ptr; + res->buffers[res->n_buffers].size = size; + res->buffers[res->n_buffers].metal = nil; + + if (size_aligned > 0) { + res->buffers[res->n_buffers].metal = [res->device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (res->buffers[res->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + free(res); + return NULL; + } + } + + ggml_metal_log_allocated_size(res->device, size_aligned); + + ++res->n_buffers; + } else { + // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into + // one of the views + const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case + const size_t size_step = props_dev->max_buffer_size - size_ovlp; + const size_t size_view = props_dev->max_buffer_size; + + for (size_t i = 0; i < size; i += size_step) { + const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); + + res->buffers[res->n_buffers].data = (void *) ((uint8_t *) ptr + i); + res->buffers[res->n_buffers].size = size_step_aligned; + res->buffers[res->n_buffers].metal = nil; + + if (size_step_aligned > 0) { + res->buffers[res->n_buffers].metal = [res->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (res->buffers[res->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); + free(res); + return NULL; + } + } + + ggml_metal_log_allocated_size(res->device, size_step_aligned); + + if (i + size_step < size) { + GGML_LOG_INFO("\n"); + } + + ++res->n_buffers; + } + } + + res->use_residency_sets = props_dev->use_residency_sets; + + if (!ggml_metal_buffer_rset_init(res)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(res); + return NULL; + } + + return res; +} + +void ggml_metal_buffer_free(ggml_metal_buffer_t buf) { + for (int i = 0; i < buf->n_buffers; i++) { + [buf->buffers[i].metal release]; + } + + ggml_metal_buffer_rset_free(buf); + + if (buf->is_shared && buf->owned) { +#if TARGET_OS_OSX + vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)buf->all_data, buf->all_size); +#else + free(buf->all_data); +#endif + } + + free(buf); +} + +void * ggml_metal_buffer_get_base(ggml_metal_buffer_t buf) { + return buf->all_data; +} + +bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) { + return buf->is_shared; +} + +void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + if (buf->is_shared) { + memset((char *)tensor->data + offset, value, size); + return; + } + + @autoreleasepool { + // dst + struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor); + bid_dst.offs += offset; + + id queue = buf->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder fillBuffer:bid_dst.metal + range:NSMakeRange(bid_dst.offs, bid_dst.offs + size) + value:value]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + if (buf->is_shared) { + memcpy((char *)tensor->data + offset, data, size); + return; + } + + @autoreleasepool { + // src + void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data + id buf_src = [buf->device newBufferWithBytesNoCopy:data_ptr + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + GGML_ASSERT(buf_src); + + // dst + struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor); + bid_dst.offs += offset; + + // note: for experimentation purposes, here we use a semaphore to wait for the copy to complete + // this is alternative to waitUntilCompleted, which should be faster, but don't seem to make much difference + dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0); + + id queue = buf->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:0 + toBuffer:bid_dst.metal + destinationOffset:bid_dst.offs + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf addCompletedHandler:^(id cb) { + // TODO: can check for errors here + GGML_UNUSED(cb); + + dispatch_semaphore_signal(completion_semaphore); + }]; + + [cmd_buf commit]; + + dispatch_semaphore_wait(completion_semaphore, DISPATCH_TIME_FOREVER); + dispatch_release(completion_semaphore); + + //[cmd_buf waitUntilCompleted]; + } +} + +void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + if (buf->is_shared) { + memcpy(data, (const char *)tensor->data + offset, size); + return; + } + + @autoreleasepool { + // src + struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf, tensor); + bid_src.offs += offset; + + // dst + id buf_dst = [buf->device newBufferWithBytesNoCopy:data + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + GGML_ASSERT(buf_dst); + + id queue = buf->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:bid_src.metal + sourceOffset:bid_src.offs + toBuffer:buf_dst + destinationOffset:0 + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) { + if (buf->is_shared) { + memset(buf->all_data, value, buf->all_size); + return; + } + + @autoreleasepool { + id queue = buf->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder fillBuffer:buf->buffers[0].metal + range:NSMakeRange(0, buf->buffers[0].size) + value:value]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t) { + struct ggml_metal_buffer_id res = { nil, 0 }; + + const int64_t tsize = ggml_nbytes(t); + + // find the view that contains the tensor fully + for (int i = 0; i < buf->n_buffers; ++i) { + const int64_t ioffs = (int64_t) t->data - (int64_t) buf->buffers[i].data; + + //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf->buffers[i].size); + if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf->buffers[i].size) { + res.metal = buf->buffers[i].metal; + res.offs = (size_t) ioffs; + + //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); + + return res; + } + } + + GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); + + return res; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index dbd955ec..9c0e0c56 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -1905,8 +1905,8 @@ GGML_TABLE_END() #define N_R0_Q5_1 4 #define N_SG_Q5_1 2 -#define N_R0_Q8_0 4 -#define N_SG_Q8_0 2 +#define N_R0_Q8_0 2 +#define N_SG_Q8_0 4 #define N_R0_MXFP4 2 #define N_SG_MXFP4 2 @@ -1917,13 +1917,13 @@ GGML_TABLE_END() #define N_R0_Q3_K 2 #define N_SG_Q3_K 2 -#define N_R0_Q4_K 4 +#define N_R0_Q4_K 2 #define N_SG_Q4_K 2 #define N_R0_Q5_K 2 #define N_SG_Q5_K 2 -#define N_R0_Q6_K 1 +#define N_R0_Q6_K 2 #define N_SG_Q6_K 2 #define N_R0_IQ1_S 4 @@ -1953,6 +1953,22 @@ GGML_TABLE_END() #define N_R0_IQ4_XS 2 #define N_SG_IQ4_XS 2 +// function constants offsets +#define FC_FLASH_ATTN_EXT_PAD 100 +#define FC_FLASH_ATTN_EXT_BLK 200 +#define FC_FLASH_ATTN_EXT 300 +#define FC_FLASH_ATTN_EXT_VEC 400 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500 +#define FC_MUL_MV 600 +#define FC_MUL_MM 700 + +// op-specific constants +#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NCPSG 64 + +#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -2046,6 +2062,17 @@ typedef struct { } ggml_metal_kargs_repeat; typedef struct { + float scale; + float bias; +} ggml_metal_kargs_scale; + +typedef struct { + float min; + float max; +} ggml_metal_kargs_clamp; + +typedef struct { + int64_t nk0; int64_t ne00; int64_t ne01; int64_t ne02; @@ -2112,12 +2139,6 @@ typedef struct { } ggml_metal_kargs_rope; typedef struct { - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; int32_t ne11; int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_3; @@ -2127,6 +2148,44 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_pad; + +typedef struct { + int32_t ne01; + int32_t ne30; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_blk; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + int32_t ns10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ns20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -2134,6 +2193,7 @@ typedef struct { uint64_t nb33; int32_t ne1; int32_t ne2; + int32_t ne3; float scale; float max_bias; float m0; @@ -2142,6 +2202,45 @@ typedef struct { float logit_softcap; } ggml_metal_kargs_flash_attn_ext; +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + int32_t ns10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ns20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; + int32_t ne1; + int32_t ne2; + int32_t ne3; + float scale; + float max_bias; + float m0; + float m1; + int32_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext_vec; + +typedef struct { + int32_t nrows; +} ggml_metal_kargs_flash_attn_ext_vec_reduce; + typedef struct { int32_t ne00; int32_t ne02; @@ -2176,6 +2275,7 @@ typedef struct { uint64_t nb13; int32_t ne0; int32_t ne1; + int32_t nr0; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mv; @@ -2199,46 +2299,34 @@ typedef struct { int32_t ne1; int16_t r2; int16_t r3; - int16_t nsg; - int16_t nxpsg; - int16_t r1ptg; } ggml_metal_kargs_mul_mv_ext; typedef struct { + int32_t ne02; int32_t ne10; int32_t ne11; // n_expert_used (bcast) uint64_t nb11; uint64_t nb12; - int32_t neh11; // n_tokens - uint64_t nbh11; + int32_t ne21; // n_tokens int32_t ne20; // n_expert_used uint64_t nb21; } ggml_metal_kargs_mul_mm_id_map0; -typedef struct { - int32_t ne20; // n_expert_used - int32_t neh0; - int32_t neh1; - uint64_t nbh1; - uint64_t nbh2; - int32_t ne0; - uint64_t nb1; - uint64_t nb2; -} ggml_metal_kargs_mul_mm_id_map1; - typedef struct { int32_t ne00; int32_t ne02; uint64_t nb01; uint64_t nb02; uint64_t nb03; - int32_t neh12; - uint64_t nbh10; - uint64_t nbh11; - uint64_t nbh12; - uint64_t nbh13; - int32_t neh0; - int32_t neh1; + int32_t ne11; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne20; + int32_t ne21; + int32_t ne0; + int32_t ne1; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mm_id; @@ -2263,18 +2351,14 @@ typedef struct { int32_t ne0; int32_t ne1; uint64_t nb1; + int32_t nr0; } ggml_metal_kargs_mul_mv_id; +// NORM +// RMS_NORM typedef struct { int32_t ne00; - int32_t ne00_4; - uint64_t nb01; - float eps; -} ggml_metal_kargs_norm; - -typedef struct { - int32_t ne00; - int32_t ne00_4; + int32_t ne00_t; uint64_t nb1; uint64_t nb2; uint64_t nb3; @@ -2285,7 +2369,7 @@ typedef struct { uint64_t nbf1[3]; uint64_t nbf2[3]; uint64_t nbf3[3]; -} ggml_metal_kargs_rms_norm; +} ggml_metal_kargs_norm; typedef struct { int32_t ne00; @@ -2301,7 +2385,7 @@ typedef struct { uint64_t nb00; uint64_t nb01; uint64_t nb02; - int32_t n_groups; + int32_t ngrp; float eps; } ggml_metal_kargs_group_norm; @@ -2345,6 +2429,10 @@ typedef struct{ float limit; } ggml_metal_kargs_glu; +typedef struct { + uint64_t np; +} ggml_metal_kargs_sum; + typedef struct { int64_t ne00; int64_t ne01; @@ -2354,14 +2442,6 @@ typedef struct { uint64_t nb01; uint64_t nb02; uint64_t nb03; - int64_t ne10; - int64_t ne11; - int64_t ne12; - int64_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; int64_t ne0; int64_t ne1; int64_t ne2; @@ -2395,12 +2475,6 @@ typedef struct { int32_t n_head_log2; } ggml_metal_kargs_soft_max; -typedef struct { - int64_t ne00; - int64_t ne01; - int n_past; -} ggml_metal_kargs_diag_mask_inf; - typedef struct { int64_t ne00; int64_t ne01; @@ -2427,33 +2501,46 @@ typedef struct { int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; - int64_t s_off; + uint64_t s_off; + uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + uint64_t nb10; uint64_t nb11; uint64_t nb12; + uint64_t ns12; uint64_t nb13; + uint64_t nb20; uint64_t nb21; + uint64_t ns21; uint64_t nb22; + int64_t ne30; uint64_t nb31; uint64_t nb41; uint64_t nb42; + uint64_t ns42; uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t ns52; uint64_t nb53; + uint64_t nb0; } ggml_metal_kargs_ssm_scan; typedef struct { - int64_t ne00; + int32_t ne00t; + int32_t ne00; uint64_t nb01; uint64_t nb02; - int64_t ne10; + uint64_t nb03; + int32_t ne10; uint64_t nb10; uint64_t nb11; + uint64_t nb12; uint64_t nb1; uint64_t nb2; + uint64_t nb3; } ggml_metal_kargs_get_rows; typedef struct { @@ -2567,9 +2654,22 @@ typedef struct { int64_t IW; int64_t OH; int64_t OW; - int64_t parallel_elements; + int64_t np; } ggml_metal_kargs_pool_2d; +typedef struct { + int64_t ne00; + uint64_t nb01; +} ggml_metal_kargs_argmax; + +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_adamw; + +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_sgd; + #endif // GGML_METAL_IMPL #include @@ -2580,6 +2680,10 @@ using namespace metal; #define MIN(x, y) ((x) < (y) ? (x) : (y)) #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } +#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1)) + +#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x) + #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf @@ -2588,12 +2692,13 @@ using namespace metal; // .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/ggml-metal.metal // .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal // -#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16) -#undef GGML_METAL_USE_BF16 +#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16) +#undef GGML_METAL_HAS_BF16 #endif -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) typedef matrix bfloat4x4; +typedef matrix bfloat2x4; #endif constexpr constant static float kvalues_iq4nl_f[16] = { @@ -2627,12 +2732,21 @@ static inline float e8m0_to_fp32(uint8_t x) { return as_type(bits); } +static inline float dot(float x, float y) { + return x*y; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); } +template +void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) { + reg = (type4)(*src); +} + template void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); @@ -2643,7 +2757,7 @@ void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { reg = (type4)(*(src)); } -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) template void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); @@ -3484,7 +3598,7 @@ kernel void kernel_add_fuse_impl( typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; -template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; +template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; @@ -3493,7 +3607,7 @@ template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_ template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; -kernel void kernel_sub( +kernel void kernel_sub_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -3519,7 +3633,7 @@ kernel void kernel_sub( } } -kernel void kernel_mul( +kernel void kernel_mul_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -3539,13 +3653,20 @@ kernel void kernel_mul( device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + if (args.ne10 == 1) { + const float x = *((device float *)(src1_ptr)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + } } } -kernel void kernel_div( +kernel void kernel_div_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -3565,9 +3686,16 @@ kernel void kernel_div( device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + if (args.ne10 == 1) { + const float x = 1.0f / *((device float *)(src1_ptr)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + } } } @@ -3638,23 +3766,17 @@ kernel void kernel_add_row_c4_fuse_impl( device const char * src1, device char * dst, uint tpig[[thread_position_in_grid]]) { - const uint nb = args.ne00/4; const uint i = tpig % nb; device const float4 * src0_row = (device const float4 *) (src0); device float4 * dst_row = (device float4 *) (dst); - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - float4 res = src0_row[tpig]; #pragma unroll(F) for (short j = 0; j < F; ++j) { - res += src1_row[j][i]; + res += ((device const float4 *) (src1 + args.o1[j]))[i]; } dst_row[tpig] = res; @@ -3662,7 +3784,7 @@ kernel void kernel_add_row_c4_fuse_impl( typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; -template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; +template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; @@ -3702,7 +3824,7 @@ kernel void kernel_sub_row_c4_fuse_impl( typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; -template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; +template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; template kernel void kernel_mul_row_c4_fuse_impl( @@ -3735,7 +3857,7 @@ kernel void kernel_mul_row_c4_fuse_impl( typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; -template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; +template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; template kernel void kernel_div_row_c4_fuse_impl( @@ -3768,55 +3890,80 @@ kernel void kernel_div_row_c4_fuse_impl( typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; -template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; +template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; -kernel void kernel_scale( +kernel void kernel_scale_f32( + constant ggml_metal_kargs_scale & args, device const float * src0, device float * dst, - constant float & scale, - constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale + bias; + dst[tpig] = src0[tpig] * args.scale + args.bias; } -kernel void kernel_scale_4( +kernel void kernel_scale_f32_4( + constant ggml_metal_kargs_scale & args, device const float4 * src0, device float4 * dst, - constant float & scale, - constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale + bias; + dst[tpig] = src0[tpig] * args.scale + args.bias; } -kernel void kernel_clamp( +kernel void kernel_clamp_f32( + constant ggml_metal_kargs_clamp & args, device const float * src0, device float * dst, - constant float & min, - constant float & max, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); + dst[tpig] = clamp(src0[tpig], args.min, args.max); } -kernel void kernel_relu( +kernel void kernel_clamp_f32_4( + constant ggml_metal_kargs_clamp & args, + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = clamp(src0[tpig], args.min, args.max); +} + +kernel void kernel_relu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = max(0.0f, src0[tpig]); } -kernel void kernel_sigmoid( +kernel void kernel_relu_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_sigmoid_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); } -kernel void kernel_tanh( +kernel void kernel_sigmoid_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + +kernel void kernel_tanh_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = precise::tanh(x); + dst[tpig] = precise::tanh(src0[tpig]); +} + +kernel void kernel_tanh_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = precise::tanh(src0[tpig]); } constant float GELU_COEF_A = 0.044715f; @@ -3824,7 +3971,7 @@ constant float GELU_QUICK_COEF = -1.702f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; constant float SQRT_2_INV = 0.70710678118654752440084436210484f; -kernel void kernel_gelu( +kernel void kernel_gelu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -3833,7 +3980,7 @@ kernel void kernel_gelu( dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } -kernel void kernel_gelu_4( +kernel void kernel_gelu_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -3846,7 +3993,7 @@ kernel void kernel_gelu_4( dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } -kernel void kernel_gelu_quick( +kernel void kernel_gelu_quick_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -3855,7 +4002,7 @@ kernel void kernel_gelu_quick( dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); } -kernel void kernel_gelu_quick_4( +kernel void kernel_gelu_quick_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -3882,7 +4029,7 @@ T erf_approx(T x) { return sign_x * y; } -kernel void kernel_gelu_erf( +kernel void kernel_gelu_erf_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -3891,7 +4038,7 @@ kernel void kernel_gelu_erf( dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); } -kernel void kernel_gelu_erf_4( +kernel void kernel_gelu_erf_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -3900,7 +4047,7 @@ kernel void kernel_gelu_erf_4( dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); } -kernel void kernel_silu( +kernel void kernel_silu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -3908,7 +4055,7 @@ kernel void kernel_silu( dst[tpig] = x / (1.0f + exp(-x)); } -kernel void kernel_silu_4( +kernel void kernel_silu_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -3916,99 +4063,202 @@ kernel void kernel_silu_4( dst[tpig] = x / (1.0f + exp(-x)); } -kernel void kernel_elu( +kernel void kernel_elu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + const float x = src0[tpig]; dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); } -kernel void kernel_sqr( +kernel void kernel_elu_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; + dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); + dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); + dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); + dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); +} + +kernel void kernel_sqr_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * src0[tpig]; } -kernel void kernel_sqrt( +kernel void kernel_sqr_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + +kernel void kernel_sqrt_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = sqrt(src0[tpig]); } -kernel void kernel_sin( +kernel void kernel_sqrt_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sqrt(src0[tpig]); +} + +kernel void kernel_sin_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = sin(src0[tpig]); } -kernel void kernel_cos( +kernel void kernel_sin_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sin(src0[tpig]); +} + +kernel void kernel_cos_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = cos(src0[tpig]); } -kernel void kernel_neg( +kernel void kernel_cos_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = cos(src0[tpig]); +} + +kernel void kernel_log_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = log(src0[tpig]); +} + +kernel void kernel_log_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = log(src0[tpig]); +} + +kernel void kernel_neg_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = -src0[tpig]; } -kernel void kernel_abs( +kernel void kernel_neg_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = -src0[tpig]; +} + +kernel void kernel_abs_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = fabs(src0[tpig]); } -kernel void kernel_sgn( - device const float * src0, - device float * dst, +kernel void kernel_abs_f32_4( + device const float4 * src0, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f); + dst[tpig] = fabs(src0[tpig]); } -kernel void kernel_step( +kernel void kernel_sgn_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f; + dst[tpig] = sign(src0[tpig]); } -kernel void kernel_hardswish( +kernel void kernel_sgn_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sign(src0[tpig]); +} + +kernel void kernel_step_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + dst[tpig] = step(0.0f, src0[tpig]); +} + +kernel void kernel_step_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = step(0.0f, src0[tpig]); +} + +kernel void kernel_hardswish_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + const float x = src0[tpig]; dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); } -kernel void kernel_hardsigmoid( +kernel void kernel_hardswish_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; + dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_hardsigmoid_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + const float x = src0[tpig]; dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); } -kernel void kernel_exp( +kernel void kernel_hardsigmoid_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; + dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_exp_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = exp(src0[tpig]); } -kernel void kernel_reglu( +kernel void kernel_exp_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]); +} + +kernel void kernel_reglu_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -4024,11 +4274,11 @@ kernel void kernel_reglu( } } -kernel void kernel_geglu( +kernel void kernel_geglu_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -4046,11 +4296,11 @@ kernel void kernel_geglu( } } -kernel void kernel_swiglu( +kernel void kernel_swiglu_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -4068,11 +4318,11 @@ kernel void kernel_swiglu( } } -kernel void kernel_swiglu_oai( +kernel void kernel_swiglu_oai_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -4094,11 +4344,11 @@ kernel void kernel_swiglu_oai( } } -kernel void kernel_geglu_erf( +kernel void kernel_geglu_erf_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -4116,11 +4366,11 @@ kernel void kernel_geglu_erf( } } -kernel void kernel_geglu_quick( +kernel void kernel_geglu_quick_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -4138,6 +4388,24 @@ kernel void kernel_geglu_quick( } } +kernel void kernel_op_sum_f32( + constant ggml_metal_kargs_sum & args, + device const float * src0, + device float * dst, + ushort tiitg[[thread_index_in_threadgroup]]) { + + if (tiitg != 0) { + return; + } + + float acc = 0.0f; + for (ulong i = 0; i < args.np; ++i) { + acc += src0[i]; + } + + dst[0] = acc; +} + template kernel void kernel_sum_rows( constant ggml_metal_kargs_sum_rows & args, @@ -4190,16 +4458,16 @@ kernel void kernel_sum_rows( typedef decltype(kernel_sum_rows) kernel_sum_rows_t; -template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows; -template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; template kernel void kernel_soft_max( + constant ggml_metal_kargs_soft_max & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -4301,11 +4569,11 @@ kernel void kernel_soft_max( template kernel void kernel_soft_max_4( + constant ggml_metal_kargs_soft_max & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -4415,53 +4683,12 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; -kernel void kernel_diag_mask_inf( - device const float * src0, - device float * dst, - constant ggml_metal_kargs_diag_mask_inf & args, - uint3 tpig[[thread_position_in_grid]]) { - const int64_t i02 = tpig[2]; - const int64_t i01 = tpig[1]; - const int64_t i00 = tpig[0]; - - if (i00 > args.n_past + i01) { - dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY; - } else { - dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - device const float4 * src0, - device float4 * dst, - constant ggml_metal_kargs_diag_mask_inf & args, - uint3 tpig[[thread_position_in_grid]]) { - - const int64_t i = 2*tpig[0]; - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int64_t i4 = 4*i; - const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01; - const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00; - const int64_t i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= args.n_past + i01) { - break; - } - dst[i+1][k] = -INFINITY; - if (i00 + k > args.n_past + i01) { - dst[i][k] = -INFINITY; - } - } -} - // ref: ggml.c:ggml_compute_forward_ssm_conv_f32 -kernel void kernel_ssm_conv_f32( +kernel void kernel_ssm_conv_f32_f32( + constant ggml_metal_kargs_ssm_conv & args, device const void * src0, device const void * src1, device float * dst, - constant ggml_metal_kargs_ssm_conv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -4488,123 +4715,40 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part -kernel void kernel_ssm_scan_f32( - device const void * src0, - device const void * src1, - device const void * src2, - device const void * src3, - device const void * src4, - device const void * src5, - device const void * src6, - device float * dst, - threadgroup float * shared [[threadgroup(0)]], - constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { +kernel void kernel_ssm_conv_f32_f32_4( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; - const int64_t i0 = tpitg.x; - const int64_t i1 = 0; - const int64_t ir = tgpig.x; // current head - const int64_t i3 = tgpig.y; // current seq + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); + device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; + float sumf = 0.0f; - const int64_t s_off = args.s_off; - - device const int32_t * ids = (device const int32_t *) src6; - - device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - float s0 = s0_buff[i]; - float s = s_buff[i]; - - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); - - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} - - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; - - const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - s = state; - - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - if (sgptg > 1) { - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. - threadgroup_barrier(mem_flags::mem_threadgroup); - - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); - if (tiisg == 0) { - y[0] = sumf; - } - } - } else if (tiisg == 0) { - y[0] = sumf; - } - - // recurse - s0 = s; + for (int64_t i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); } - // Assign the final state to the output buffer - s_buff[i] = s; + x[0] = sumf; } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -kernel void kernel_ssm_scan_f32_group( +kernel void kernel_ssm_scan_f32( + constant ggml_metal_kargs_ssm_scan & args, device const void * src0, device const void * src1, device const void * src2, @@ -4614,103 +4758,88 @@ kernel void kernel_ssm_scan_f32_group( device const void * src6, device float * dst, threadgroup float * shared [[threadgroup(0)]], - constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + constexpr short NW = N_SIMDWIDTH; - const int64_t i0 = tpitg.x; - const int64_t i1 = tgpig.x; - const int64_t ir = tgpig.y; // current head - const int64_t i3 = tgpig.z; // current seq + shared[tpitg.x] = 0.0f; - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); + const int32_t i0 = tpitg.x; + const int32_t i1 = tgpig.x; + const int32_t ir = tgpig.y; // current head + const int32_t i3 = tgpig.z; // current seq - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; + const int32_t nc = args.d_state; + const int32_t nr = args.d_inner; + const int32_t nh = args.n_head; + const int32_t ng = args.n_group; + const int32_t n_t = args.n_seq_tokens; - const int64_t s_off = args.s_off; + const int32_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; + + const int32_t i = i0 + i1*nc; + const int32_t g = ir / (nh / ng); // repeat_interleave + float s0 = s0_buff[i]; - float s = s_buff[i]; + float s = 0.0f; - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh} - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} + const float A0 = A[i0%args.ne30]; - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; - const float dA = exp(dt_soft_plus * A[0]); + device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns} + device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns} - const float state = (s0 * dA) + (B[i0] * x_dt); - s = state; + device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns} - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. + for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float dt0 = dt[0]; + const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; + const float x_dt = x[0] * dtsp; + const float dA = exp(dtsp * A0); + + s = (s0 * dA) + (B[i0] * x_dt); + + const float sumf = simd_sum(s * C[i0]); + if (tiisg == 0) { - y[0] = sumf; + shared[t*NW + sgitg] = sumf; } + + // recurse + s0 = s; + + x += args.ns12; + dt += args.ns21; + B += args.ns42; + C += args.ns52; } - // recurse - s0 = s; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + + if (tiisg == 0 && i2 + sgitg < n_t) { + y[sgitg*nh*nr] = sumf; + } + + y += sgptg*nh*nr; } - // Assign the final state to the output buffer s_buff[i] = s; } @@ -4892,24 +5021,22 @@ kernel void kernel_rwkv_wkv7_f32( } } -kernel void kernel_argmax( - device const void * x, - device int32_t * dst, - constant int64_t & ncols, - constant uint64_t & nb01, - threadgroup float * shared_maxval [[threadgroup(0)]], - threadgroup int32_t * shared_argmax [[threadgroup(1)]], +kernel void kernel_argmax_f32( + constant ggml_metal_kargs_argmax & args, + device const char * src0, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01); + device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01); float lmax = -INFINITY; int32_t larg = -1; - for (int i00 = tpitg; i00 < ncols; i00 += ntg) { + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { if (x_row[i00] > lmax) { lmax = x_row[i00]; larg = i00; @@ -4920,6 +5047,11 @@ kernel void kernel_argmax( float max_val = simd_max(lmax); int32_t arg_val = simd_max(select(-1, larg, lmax == max_val)); + device int32_t * dst_i32 = (device int32_t *) dst; + + threadgroup float * shared_maxval = (threadgroup float *) shmem; + threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH; + if (ntg > N_SIMDWIDTH) { if (sgitg == 0) { shared_maxval[tiisg] = -INFINITY; @@ -4941,38 +5073,51 @@ kernel void kernel_argmax( float max_val_reduced = simd_max(max_val); int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced)); - dst[tgpig] = arg_val_reduced; + dst_i32[tgpig] = arg_val_reduced; return; } - dst[tgpig] = arg_val; + dst_i32[tgpig] = arg_val; } -kernel void kernel_norm( +// F == 1 : norm (no fuse) +// F == 2 : norm + mul +// F == 3 : norm + mul + add +template +kernel void kernel_norm_fuse_impl( constant ggml_metal_kargs_norm & args, device const char * src0, + device const char * src1_0, + device const char * src1_1, device char * dst, threadgroup float * shmem_f32 [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - ushort tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { if (sgitg == 0) { shmem_f32[tiisg] = 0.0f; } - device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + const int i01 = tgpig.x; + const int i02 = tgpig.y; + const int i03 = tgpig.z; - float4 sumf4(0.0f); + device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); + + device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); + device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); + + T sumft(0.0f); float sumf = 0.0f; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { - sumf4 += x[i00]; + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { + sumft += x[i00]; } - sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; + sumf = dot(sumft, T(1.0f)); sumf = simd_sum(sumf); threadgroup_barrier(mem_flags::mem_threadgroup); @@ -4988,10 +5133,10 @@ kernel void kernel_norm( const float mean = sumf/args.ne00; - device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); sumf = 0.0f; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { y[i00] = x[i00] - mean; sumf += dot(y[i00], y[i00]); } @@ -5011,17 +5156,35 @@ kernel void kernel_norm( const float variance = sumf/args.ne00; const float scale = 1.0f/sqrt(variance + args.eps); - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { - y[i00] = y[i00] * scale; + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { + if (F == 1) { + y[i00] = (y[i00]*scale); + } + if (F == 2) { + y[i00] = (y[i00]*scale)*f0[i00]; + } + if (F == 3) { + y[i00] = (y[i00]*scale)*f0[i00] + f1[i00]; + } } } +typedef decltype(kernel_norm_fuse_impl) kernel_norm_fuse_t; + +template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; + +template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; + // F == 1 : rms_norm (no fuse) // F == 2 : rms_norm + mul // F == 3 : rms_norm + mul + add -template +template kernel void kernel_rms_norm_fuse_impl( - constant ggml_metal_kargs_rms_norm & args, + constant ggml_metal_kargs_norm & args, device const char * src0, device const char * src1_0, device const char * src1_1, @@ -5040,15 +5203,15 @@ kernel void kernel_rms_norm_fuse_impl( const int i02 = tgpig.y; const int i03 = tgpig.z; - device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); + device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); - device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); - device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); + device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); + device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); float sumf = 0.0f; // parallel sum - for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { sumf += dot(x[i00], x[i00]); } sumf = simd_sum(sumf); @@ -5067,8 +5230,8 @@ kernel void kernel_rms_norm_fuse_impl( const float mean = sumf/args.ne00; const float scale = 1.0f/sqrt(mean + args.eps); - device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); - for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { if (F == 1) { y[i00] = (x[i00]*scale); } @@ -5081,13 +5244,17 @@ kernel void kernel_rms_norm_fuse_impl( } } -typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t; +typedef decltype(kernel_rms_norm_fuse_impl) kernel_rms_norm_fuse_t; -template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>; -template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>; -template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>; +template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; -kernel void kernel_l2_norm( +template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; + +kernel void kernel_l2_norm_f32( constant ggml_metal_kargs_l2_norm & args, device const char * src0, device char * dst, @@ -5130,10 +5297,10 @@ kernel void kernel_l2_norm( } } -kernel void kernel_group_norm( +kernel void kernel_group_norm_f32( + constant ggml_metal_kargs_group_norm & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_group_norm & args, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -5141,7 +5308,7 @@ kernel void kernel_group_norm( uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { const int64_t ne = args.ne00*args.ne01*args.ne02; - const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups); + const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp); int start = tgpig * gs; int end = start + gs; @@ -5299,7 +5466,52 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } -template +template +static inline void helper_mv_reduce_and_write( + device float * dst_f32, + float sumf[NR0], + const int r0, + const int ne01, + ushort tiisg, + ushort sgitg, + threadgroup char * shmem) { + constexpr short NW = N_SIMDWIDTH; + + threadgroup float * shmem_f32[NR0]; + + for (short row = 0; row < NR0; ++row) { + shmem_f32[row] = (threadgroup float *) shmem + NW*row; + + if (sgitg == 0) { + shmem_f32[row][tiisg] = 0.0f; + } + + sumf[row] = simd_sum(sumf[row]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0; ++row) { + if (tiisg == 0) { + shmem_f32[row][sgitg] = sumf[row]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0 && r0 + row < ne01; ++row) { + float tot = simd_sum(shmem_f32[row][tiisg]); + + if (tiisg == 0 && sgitg == 0) { + dst_f32[r0 + row] = tot; + } + } +} + +constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; +constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; + +template void mul_vec_q_n_f32_impl( args_t args, device const char * src0, @@ -5309,45 +5521,54 @@ void mul_vec_q_n_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NQ = 16; + const int nb = args.ne00/QK4_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr0; + const int r0 = (tgpig.x*NSG + sgitg)*NR0; + //const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q_type * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q_type * ax[NR0]; + FOR_UNROLL (int row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } + float sumf[NR0] = {0.f}; + + const short ix = (tiisg/(NW/NQ)); + const short il = (tiisg%(NW/NQ))*8; + + //const int ib0 = sgitg*NQ + ix; + const int ib0 = ix; + float yl[16]; // src1 vector cache - float sumf[nr0] = {0.f}; - const short ix = (tiisg/2); - const short il = (tiisg%2)*8; - - device const float * yb = y + ix*QK4_0 + il; + //device const float * yb = y + ix*QK4_0 + il; + device const float * yb = y + ib0*QK4_0 + il; // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { + //for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (int ib = ib0; ib < nb; ib += NQ) { float sumy[2] = { 0.f, 0.f }; -#pragma unroll - for (short i = 0; i < 8; i += 2) { + FOR_UNROLL (short i = 0; i < 8; i += 2) { sumy[0] += yb[i + 0] + yb[i + 1]; yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; @@ -5357,21 +5578,23 @@ void mul_vec_q_n_f32_impl( yl[i + 9] = yb[i + 17]/4096.f; } -#pragma unroll - for (short row = 0; row < nr0; row++) { + FOR_UNROLL (short row = 0; row < NR0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } yb += QK4_0 * 16; + //yb += NSG*NQ*QK4_0; } device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - for (int row = 0; row < nr0; ++row) { + //helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); + + for (int row = 0; row < NR0; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; + if (tiisg == 0 && r0 + row < args.ne01) { + dst_f32[r0 + row] = tot; } } } @@ -5381,10 +5604,11 @@ kernel void kernel_mul_mv_q4_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -5392,10 +5616,11 @@ kernel void kernel_mul_mv_q4_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -5403,10 +5628,11 @@ kernel void kernel_mul_mv_q5_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -5414,15 +5640,14 @@ kernel void kernel_mul_mv_q5_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -#define NB_Q8_0 8 - -template +template void kernel_mul_mv_q8_0_f32_impl( args_t args, device const char * src0, @@ -5432,66 +5657,68 @@ void kernel_mul_mv_q8_0_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NQ = 8; + const int nb = args.ne00/QK8_0; - const int r0 = tgpig.x; + const int r0 = tgpig.x*NR0; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; - const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q8_0 * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q8_0 * ax[NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } - float yl[NB_Q8_0]; - float sumf[nr0] = { 0.f }; + float sumf[NR0] = { 0.f }; - const short ix = tiisg/4; - const short il = tiisg%4; + const short ix = tiisg/(NW/NQ); + const short il = tiisg%(NW/NQ); - device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; + const int ib0 = sgitg*NQ + ix; - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (short i = 0; i < NB_Q8_0; ++i) { + float yl[NQ]; + + device const float * yb = y + ib0*QK8_0 + il*NQ; + + // each thread in a SIMD group deals with NQ quants at a time + for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (short i = 0; i < NQ; ++i) { yl[i] = yb[i]; } - for (short row = 0; row < nr0; row++) { - device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; + for (short row = 0; row < NR0; row++) { + device const int8_t * qs = ax[row][ib].qs + il*NQ; + float sumq = 0.f; - for (short iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; + FOR_UNROLL (short i = 0; i < NQ; ++i) { + sumq += qs[i] * yl[i]; } + sumf[row] += sumq*ax[row][ib].d; } - yb += nw*NB_Q8_0; + yb += NSG*NQ*QK8_0; } device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0; ++row) { - const float tot = simd_sum(sumf[row]); - - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; - } - } + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } [[host_name("kernel_mul_mv_q8_0_f32")]] @@ -5500,15 +5727,16 @@ kernel void kernel_mul_mv_q8_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } // mat-vec kernel processing in chunks of float4 // chpb - chunks per quantization block -template +template void kernel_mul_mv_ext_q4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, @@ -5517,6 +5745,9 @@ void kernel_mul_mv_ext_q4_f32_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short NSG = FC_mul_mv_nsg; + const short nxpsg = FC_mul_mv_nxpsg; + const short chpt = 4; // chunks per thread //const short nxpsg = (32); @@ -5525,7 +5756,7 @@ void kernel_mul_mv_ext_q4_f32_impl( const short tx = tiisg%nxpsg; const short ty = tiisg/nxpsg; - const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty; const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; @@ -5566,7 +5797,6 @@ void kernel_mul_mv_ext_q4_f32_impl( #pragma unroll(r1ptg) for (short ir1 = 0; ir1 < r1ptg; ++ir1) { sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]); - } } @@ -5609,7 +5839,7 @@ void kernel_mul_mv_ext_q4_f32_impl( } // mat-vec kernel processing in chunks of float4x4 -template +template void kernel_mul_mv_ext_q4x4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, @@ -5618,6 +5848,9 @@ void kernel_mul_mv_ext_q4x4_f32_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short NSG = FC_mul_mv_nsg; + const short nxpsg = FC_mul_mv_nxpsg; + const short chpt = 1; //const short nxpsg = (32); @@ -5626,7 +5859,7 @@ void kernel_mul_mv_ext_q4x4_f32_impl( const short tx = tiisg%nxpsg; const short ty = tiisg/nxpsg; - const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty; const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; @@ -5723,12 +5956,7 @@ kernel void kernel_mul_mv_ext_q4_f32_disp( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - } + kernel_mul_mv_ext_q4_f32_impl(args, src0, src1, dst, tgpig, tiisg, sgitg); } template @@ -5740,17 +5968,17 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - } + kernel_mul_mv_ext_q4x4_f32_impl(args, src0, src1, dst, tgpig, tiisg, sgitg); } typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t; typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>; + template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>; template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>; template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; @@ -5806,106 +6034,253 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; -#define N_MV_T_T 4 - -template -void kernel_mul_mv_impl( +template +void kernel_mul_mv_t_t_impl( args_t args, device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem, uint3 tgpig, - ushort tiisg) { - const int r0 = tgpig.x; - const int rb = tgpig.y*N_MV_T_T; + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NB = 32; + constexpr short NF = 8; + + const int nb = args.ne00/NB; + + const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const T0 * x = (device const T0 *) (src0 + offset0); + //device const T0 * x = (device const T0 *) (src0 + offset0); + device const T1 * y = (device const T1 *) (src1 + offset1); - device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + // pointers to src0 rows + device const T0 * ax [NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - if (args.ne00 < 128) { - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; - } + ax[row] = (device const T0 *) ((device char *) src0 + offset0); + } - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + float sumf[NR0] = { 0.f }; - device const T1 * y = (device const T1 *) (src1 + offset1); + const short ix = tiisg/(NW/NF); + const short il = tiisg%(NW/NF); - float sumf = 0; - for (int i = tiisg; i < args.ne00; i += 32) { - sumf += (T0) x[i] * (T1) y[i]; - } + const int ib0 = sgitg*NF + ix; - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; - } + T1 yl[NF]; + + device const T1 * yb = y + (ib0*NB + il*NF); + + for (int ib = ib0; ib < nb; ib += NSG*NF) { + for (short i = 0; i < NF; ++i) { + yl[i] = yb[i]; } - } else { - device const T04 * x4 = (device const T04 *) x; - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; + + for (short row = 0; row < NR0; row++) { + device const T0 * xb = ax[row] + (ib*NB + il*NF); + + float sumq = 0.f; + FOR_UNROLL (short i = 0; i < NF; ++i) { + sumq += xb[i] * yl[i]; } - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - - device const T1 * y = (device const T1 *) (src1 + offset1); - device const T14 * y4 = (device const T14 *) y; - - float sumf = 0; - for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((float4) x4[i], (float4) y4[i]); - } - - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); - dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; - } + sumf[row] += sumq; } + + yb += NSG*NF*NW; + } + + for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) { + for (short row = 0; row < NR0; row++) { + sumf[row] += ax[row][i] * y[i]; + } + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); +} + +template +void kernel_mul_mv_t_t_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; } } -template -kernel void kernel_mul_mv( +template +kernel void kernel_mul_mv_t_t( constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_impl( - args, - src0, - src1, - dst, - tgpig, - tiisg); + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_t_t_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv) mul_mv_t; +typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; -template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; #endif -template -void kernel_mul_mv_c4_impl( +template +void kernel_mul_mv_t_t_4_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NB = 32; + constexpr short NF = 16; + constexpr short NF4 = NF/4; + + const int nb = args.ne00/NB; + + const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); + device const T14 * y4 = (device const T14 *) (src1 + offset1); + + // pointers to src0 rows + device const T0 * ax [NR0]; + device const T04 * ax4[NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax [row] = (device const T0 *) ((device char *) src0 + offset0); + ax4[row] = (device const T04 *) ((device char *) src0 + offset0); + } + + float sumf[NR0] = { 0.f }; + + const short ix = tiisg/(NW/NF); + const short il = tiisg%(NW/NF); + + const int ib0 = sgitg*NF + ix; + + T14 yl4[NF4]; + + device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4; + + for (int ib = ib0; ib < nb; ib += NSG*NF) { + for (short i = 0; i < NF4; ++i) { + yl4[i] = yb4[i]; + } + + for (short row = 0; row < NR0; row++) { + device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4; + + float sumq = 0.f; + FOR_UNROLL (short i = 0; i < NF4; ++i) { + sumq += dot(float4(xb4[i]), float4(yl4[i])); + } + + sumf[row] += sumq; + } + + yb4 += NSG*NF*NW/4; + } + + for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) { + for (short row = 0; row < NR0; row++) { + sumf[row] += ax[row][i] * y[i]; + } + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); +} + +template +void kernel_mul_mv_t_t_4_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + }; +} + +template +kernel void kernel_mul_mv_t_t_4( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_t_t_4_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; + +template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +#endif + +template +void kernel_mul_mv_t_t_short_impl( args_t args, device const char * src0, device const char * src1, @@ -5913,7 +6288,7 @@ void kernel_mul_mv_c4_impl( uint3 tgpig, ushort tiisg) { const int r0 = tgpig.x*32 + tiisg; - const int rb = tgpig.y*N_MV_T_T; + const int r1 = tgpig.y; const int im = tgpig.z; if (r0 >= args.ne01) { @@ -5925,33 +6300,32 @@ void kernel_mul_mv_c4_impl( const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - device const T04 * x = (device const T04 *) (src0 + offset0); + device const T0 * x = (device const T0 *) (src0 + offset0); device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; - } + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + device const T1 * y = (device const T1 *) (src1 + offset1); - device const T14 * y = (device const T14 *) (src1 + offset1); + float res = 0.0f; - dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]); + for (int i = 0; i < args.ne00; ++i) { + res += (float) x[i] * (float) y[i]; } + + dst_f32[(uint64_t)r1*args.ne0 + r0] = res; } -template -kernel void kernel_mul_mv_c4( +template +kernel void kernel_mul_mv_t_t_short( constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_c4_impl( + kernel_mul_mv_t_t_short_impl( args, src0, src1, @@ -5960,116 +6334,14 @@ kernel void kernel_mul_mv_c4( tiisg); } -typedef decltype(kernel_mul_mv_c4) mul_mv_c4_t; +typedef decltype(kernel_mul_mv_t_t_short) mul_mv_t_t_short_t; -template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -#endif - -template -kernel void kernel_mul_mv_1row( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]]) { - - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; - - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - - device const T * x = (device const T *) (src0 + offset0); - device const float * y = (device const float *) (src1 + offset1); - - device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - - float sumf = 0; - if (args.ne00 < 128) { - for (int i = tiisg; i < args.ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[r0] = sum_all; - } - } else { - device const T4 * x4 = (device const T4 *) x; - device const float4 * y4 = (device const float4 *) y; - - for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((float4) x4[i], y4[i]); - } - - float sum_all = simd_sum(sumf); - - if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); - dst_f32[r0] = sum_all; - } - } -} - -typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; - -template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; -#endif - -// Assumes row size (ne00) is a multiple of 4 -template -kernel void kernel_mul_mv_l4( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]]) { - - const int nrows = args.ne11; - const int r0 = tgpig.x; - const int im = tgpig.z; - - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; - - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - - device const T4 * x4 = (device const T4 *) (src0 + offset0); - - device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; - - for (int r1 = 0; r1 < nrows; ++r1) { - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - - device const float4 * y4 = (device const float4 *) (src1 + offset1); - - float sumf = 0; - for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((float4) x4[i], y4[i]); - } - - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; - } - } -} - -typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; - -template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #endif static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -6372,9 +6644,9 @@ template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t ker template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; typedef void (im2col_t)( + constant ggml_metal_kargs_im2col & args, device const float * x, device char * dst, - constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -6382,9 +6654,9 @@ typedef void (im2col_t)( template kernel void kernel_im2col( + constant ggml_metal_kargs_im2col & args, device const float * x, device char * dst, - constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -6393,11 +6665,10 @@ kernel void kernel_im2col( const int64_t OH = tgpg[1]; const int64_t OW = tgpg[2]; -// const int64_t N = ntg[0]; const int64_t KH = ntg[1]; const int64_t KW = ntg[2]; - const int64_t in = tpitg[0]; + int64_t in = tpitg[0]; const int64_t ikh = tpitg[1]; const int64_t ikw = tpitg[2]; @@ -6408,88 +6679,102 @@ kernel void kernel_im2col( const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; - const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); + int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); device T * pdst = (device T *) (dst); if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { - pdst[offset_dst] = 0.0f; + while (in < args.N) { + pdst[offset_dst] = 0.0f; + offset_dst += ntg[0]*args.CHW*OH*OW; + + in += ntg[0]; + } } else { - const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; - pdst[offset_dst] = x[offset_src]; + int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; + + while (in < args.N) { + pdst[offset_dst] = x[offset_src]; + + offset_dst += ntg[0]*args.CHW*OH*OW; + offset_src += ntg[0]*args.ofs0; + + in += ntg[0]; + } } } template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; -typedef void (im2col_ext_t)( - device const float * x, - device char * dst, - constant ggml_metal_kargs_im2col & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]); - -template -kernel void kernel_im2col_ext( - device const float * x, - device char * dst, - constant ggml_metal_kargs_im2col & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] - const int64_t KHW = (int64_t)args.KHW; - - const int64_t d = tgpig[0] / args.CHW; - const int64_t chw = tgpig[0] % args.CHW; - const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) - const int64_t HW = tgpig[0] % KHW; - - const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; - if (tpitg_0 >= args.N) { - return; - } - - const int64_t tpitg_1 = HW / args.KW; - const int64_t tpitg_2 = HW % args.KW; - - const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; - const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; - - const int64_t offset_dst = - (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + - (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); - - device T * pdst = (device T *) (dst); - - if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { - pdst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; - pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; - } -} - -template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; -template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; +// TODO: obolete -- remove +//typedef void (im2col_ext_t)( +// constant ggml_metal_kargs_im2col & args, +// device const float * x, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// uint3 tgpg[[threadgroups_per_grid]], +// uint3 tpitg[[thread_position_in_threadgroup]], +// uint3 ntg[[threads_per_threadgroup]]); +// +//template +//kernel void kernel_im2col_ext( +// constant ggml_metal_kargs_im2col & args, +// device const float * x, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW +// uint3 tpitg[[thread_position_in_threadgroup]], +// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] +// const int64_t KHW = (int64_t)args.KHW; +// +// const int64_t d = tgpig[0] / args.CHW; +// const int64_t chw = tgpig[0] % args.CHW; +// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) +// const int64_t HW = tgpig[0] % KHW; +// +// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; +// if (tpitg_0 >= args.N) { +// return; +// } +// +// const int64_t tpitg_1 = HW / args.KW; +// const int64_t tpitg_2 = HW % args.KW; +// +// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; +// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; +// +// const int64_t offset_dst = +// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + +// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); +// +// device T * pdst = (device T *) (dst); +// +// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { +// pdst[offset_dst] = 0.0f; +// } else { +// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; +// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; +// } +//} +// +//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; typedef void (conv_transpose_1d_t)( + constant ggml_metal_kargs_conv_transpose_1d & args, device const float * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); template kernel void kernel_conv_transpose_1d( + constant ggml_metal_kargs_conv_transpose_1d & args, device const T * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]) { @@ -6513,26 +6798,26 @@ kernel void kernel_conv_transpose_1d( template [[host_name("kernel_conv_transpose_1d_f32_f32")]] kernel void kernel_conv_transpose_1d( + constant ggml_metal_kargs_conv_transpose_1d & args, device const float * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); template [[host_name("kernel_conv_transpose_1d_f16_f32")]] kernel void kernel_conv_transpose_1d( + constant ggml_metal_kargs_conv_transpose_1d & args, device const half * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); kernel void kernel_upscale_f32( + constant ggml_metal_kargs_upscale & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_upscale & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -6556,9 +6841,9 @@ kernel void kernel_upscale_f32( } kernel void kernel_pad_f32( + constant ggml_metal_kargs_pad & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_pad & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -6592,9 +6877,9 @@ kernel void kernel_pad_f32( } kernel void kernel_pad_reflect_1d_f32( + constant ggml_metal_kargs_pad_reflect_1d & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_pad_reflect_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -6625,8 +6910,8 @@ kernel void kernel_pad_reflect_1d_f32( } kernel void kernel_arange_f32( - device char * dst, constant ggml_metal_kargs_arange & args, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -6639,9 +6924,9 @@ kernel void kernel_arange_f32( } kernel void kernel_timestep_embedding_f32( + constant ggml_metal_kargs_timestep_embedding & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_timestep_embedding & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -6659,25 +6944,25 @@ kernel void kernel_timestep_embedding_f32( } if (args.dim % 2 != 0 && tpitg.x == 0) { - embed_data[args.dim] = 0.f; + embed_data[2 * half_] = 0.f; } } // bitonic sort implementation following the CUDA kernels as reference typedef void (argsort_t)( - device const float * x, - device int32_t * dst, constant ggml_metal_kargs_argsort & args, + device const float * x, + device int32_t * dst, threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]); template kernel void kernel_argsort_f32_i32( - device const float * x, - device int32_t * dst, constant ggml_metal_kargs_argsort & args, - threadgroup int32_t * shared_values [[threadgroup(0)]], + device const float * x, + device int32_t * dst, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]) { // bitonic sort @@ -6726,17 +7011,236 @@ kernel void kernel_argsort_f32_i32( } } +typedef void (i32_argsort_t)( + constant ggml_metal_kargs_argsort & args, + device const int32_t * x, + device int32_t * dst, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_i32_i32( + constant ggml_metal_kargs_argsort & args, + device const int32_t * x, + device int32_t * dst, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + // bitonic sort + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= args.ncols_pad) return; + + device const int32_t * x_row = x + row * args.ncols; + threadgroup int32_t * dst_row = shared_values; + + // initialize indices + dst_row[col] = col; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= args.ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= args.ncols || + (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= args.ncols || + (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // copy the result to dst without the padding + if (col < args.ncols) { + dst[row * args.ncols + col] = dst_row[col]; + } +} + template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_i32_i32_asc")]] kernel i32_argsort_t kernel_argsort_i32_i32; +template [[host_name("kernel_argsort_i32_i32_desc")]] kernel i32_argsort_t kernel_argsort_i32_i32; kernel void kernel_leaky_relu_f32( + constant ggml_metal_kargs_leaky_relu & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_leaky_relu & args, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope; + const float x = src0[tpig]; + dst[tpig] = x > 0.0f ? x : x * args.slope; } +kernel void kernel_leaky_relu_f32_4( + constant ggml_metal_kargs_leaky_relu & args, + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; + dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); +} + +constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; + +constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; + +// pad the last chunk of C elements of k and v into a an extra pad buffer +kernel void kernel_flash_attn_ext_pad( + constant ggml_metal_kargs_flash_attn_ext_pad & args, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int32_t C = FC_flash_attn_ext_pad_ncpsg; + + device char * k_pad = dst; + device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3; + device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const int32_t icp = args.ne11 % C; + const int32_t ic0 = args.ne11 - icp; + + const int32_t i1 = tgpig[0]; + const int32_t i2 = tgpig[1]; + const int32_t i3 = tgpig[2]; + + if (i2 < args.ne_12_2 && i3 < args.ne_12_3) { + device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3; + device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3; + + device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3; + device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3; + + if (i1 >= icp) { + // here it is not important the exact value that will be used as we rely on masking out the scores in the attention + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = 0; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = 0; + } + } else { + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = k_src[i]; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = v_src[i]; + } + } + } + + if (FC_flash_attn_ext_pad_has_mask) { + if (i2 < args.ne32 && i3 < args.ne33) { + for (int ib = i1; ib < args.ne31; ib += C) { + device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0; + device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3; + + for (int i = tiitg; i < C; i += ntg.x) { + if (i >= icp) { + mask_dst[i] = -MAXHALF; + } else { + mask_dst[i] = mask_src[i]; + } + } + } + } + } +} + +constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]]; +constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]]; + +// scan the blocks of the mask that are not masked +// 0 - masked (i.e. full of -INF, skip) +// 1 - not masked (i.e. at least one element of the mask is not -INF) +kernel void kernel_flash_attn_ext_blk( + constant ggml_metal_kargs_flash_attn_ext_blk & args, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + // block size C x Q + const int32_t Q = FC_flash_attn_ext_blk_nqptg; + const int32_t C = FC_flash_attn_ext_blk_ncpsg; + + constexpr short NW = N_SIMDWIDTH; + + const int32_t i3 = tgpig[2]/args.ne32; + const int32_t i2 = tgpig[2]%args.ne32; + const int32_t i1 = tgpig[1]; + const int32_t i0 = tgpig[0]; + + char res = i0*C + C > args.ne30 ? 1 : 0; + + device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; + + // fast route + if (res == 0) { + if (simd_max(*mask_src) > -MAXHALF/2) { + res = 1; + } + } + + // detailed check of the elements of the block + if ((C > NW || Q > 1) && res == 0) { + half m = -MAXHALF; + + FOR_UNROLL (short j = 0; j < Q; ++j) { + FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { + m = max(m, mask_src[ii*NW]); + } + + mask_src += args.nb31/2; + } + + if (simd_max(m) > -MAXHALF/2) { + res = 1; + } + } + + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne30 + C - 1)/C); + + if (tiisg == 0) { + dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res; + } +} + +constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; +constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; +constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; +constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; +constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]]; + +constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; + +//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; +//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; +//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]]; + +constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]]; +constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]]; +constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]]; + // ref: https://arxiv.org/pdf/2307.08691.pdf template< typename q_t, // query types in shared memory @@ -6751,6 +7255,7 @@ template< typename qk_t, // Q*K types typename qk8x8_t, typename s_t, // soft-max types + typename s2_t, typename s8x8_t, typename o_t, // attention accumulation types typename o4_t, @@ -6761,59 +7266,107 @@ template< typename vd4x4_t, // value type in device memory short nl_v, void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), - short DK, // K head size - short DV, // V head size - short Q = 8, // queries per threadgroup - short KV = 8, // key/value processed per each simdgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext( + short DK, // K head size + short DV, // V head size + short Q, // queries per threadgroup + short C, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_impl( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, device const char * v, device const char * mask, device const char * sinks, + device const char * pad, + device const char * blk, device char * dst, - threadgroup half * shmem_f16 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups + threadgroup half * shmem_f16, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const ushort iq3 = tgpig[2]; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]*Q; - const int iq3 = tgpig[2]; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]*Q; +#define NS10 (FC_flash_attn_ext_ns10) +#define NS20 (FC_flash_attn_ext_ns20) + + // note: I had some concerns that using this instead of the ugly macros above was affecting performance + // need to re-check carefully and if no regressions are observerd - remove the macros + // the concerns is that maybe using const variables requires extra registers? but not sure if the compiler + // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC + //const short NS10 = FC_flash_attn_ext_ns10; + //const short NS20 = FC_flash_attn_ext_ns20; + + constexpr short KV = 8; constexpr short DK4 = DK/4; constexpr short DK8 = DK/8; constexpr short DK16 = DK/16; constexpr short DV4 = DV/4; - constexpr short DV8 = DV/8; + //constexpr short DV8 = DV/8; constexpr short DV16 = DV/16; + constexpr short PV = PAD2(DV, 64); + constexpr short PV4 = PV/4; + constexpr short PV8 = PV/8; + //constexpr short PV16 = PV/16; + constexpr short NW = N_SIMDWIDTH; - constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + constexpr short NQ = Q/NSG; + constexpr short SH = 2*C; // shared memory per simdgroup (s_t == float) - const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = 2*DK + 2*TS; // shared memory size per query in (half) + constexpr short TS = 2*SH; + constexpr short T = DK + 2*PV; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper) + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK); + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix + threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t - threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t - threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o8x8_t lo[DV8]; + // mask storage in shared mem + threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C); + + // per-query mask pointers + device const half2 * pm2[NQ]; + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + } + + { + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne11 + C - 1)/C); + + blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0; + } + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01); for (short i = tiisg; i < DK4; i += NW) { if (iq1 + j < args.ne01) { @@ -6824,43 +7377,30 @@ kernel void kernel_flash_attn_ext( } } - // zero out lo - for (short i = 0; i < DV8; ++i) { - lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); - } + // zero out + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] = 0; + } - // zero out shared memory SH - for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*TS + i] = 0.0f; + ss[j*SH + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); + float S[NQ] = { [0 ... NQ-1] = 0.0f }; + { - float S[Q] = { [0 ... Q-1] = 0.0f }; - float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 }; - - // thread indices inside the simdgroup - // TODO: see if we can utilize quad-group functions for better performance - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) - const short tx = tiisg%4; - const short ty = tiisg/4; - - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); - - const bool has_mask = mask != q; + float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 }; float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -6871,177 +7411,354 @@ kernel void kernel_flash_attn_ext( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; + for (int ic0 = 0; ; ++ic0) { + int ic = ic0*C; if (ic >= args.ne11) { break; } - if (has_mask) { - // used to detect blocks full of -INF - float smax = -INFINITY; + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; - // load the mask in shared memory - #pragma unroll(Q) - for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); - const float m = pm[ic + tiisg]; + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; - ss[j*TS + C + tiisg] = m; - smax = max(smax, m); + if (!FC_flash_attn_ext_has_mask) { + threadgroup half * sm = (threadgroup half *) (sm2); + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < C; i += NW) { + if (ic + i >= args.ne11) { + sm[2*j*SH + i] = -MAXHALF; + } + } + } + } else { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const half *) mask + + (iq1 + j)*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32)); + } } - smax = simd_max(smax); + ic = 0; + } + + // read the mask into shared mem + if (FC_flash_attn_ext_has_mask) { + if (blk[ic0] == 0) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } - if (smax == -INFINITY) { continue; } + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } + + pm2[jj] += NW; + } + +#if 0 + // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // used to detect blocks full of -INF + // skip only when the entire threadgroup is masked + half2 smax2(-MAXHALF/2, -MAXHALF/2); + + FOR_UNROLL (short j = 0; j < Q; ++j) { + smax2 = max(smax2, sm2[j*SH + tiisg]); + } + + smax2 = simd_max(smax2); + + if (max(smax2[0], smax2[1]) <= -MAXHALF/2) { + // this barrier is important + threadgroup_barrier(mem_flags::mem_threadgroup); + + continue; + } +#endif } // Q*K^T - { - for (short cc = 0; cc < C/8; ++cc) { + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) (k + ic*args.nb11); + threadgroup const q_t * pq = sq; + threadgroup s_t * ps = ss; + + pk += sgitg*(8*NS10); + ps += sgitg*(8*1); + + static_assert((C/8) % NSG == 0, ""); + + constexpr short NC = (C/8)/NSG; + + // note: do not unroll for large heads + #pragma unroll (DK <= 64 ? NC : 1) + for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - // this is compile-time check, so it does not have runtime overhead - if (is_same::value) { - // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + if (DK % 16 != 0) { + k8x8_t mk; + q8x8_t mq; - #pragma unroll(DK8) - for (short i = 0; i < DK8; ++i) { - k8x8_t mk; - simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + FOR_UNROLL (short i = 0; i < DK8; ++i) { + simdgroup_barrier(mem_flags::mem_none); + + simdgroup_load(mk, pk + 8*i, NS10, 0, true); + simdgroup_load(mq, pq + 8*i, DK); + + simdgroup_barrier(mem_flags::mem_none); - q8x8_t mq; - simdgroup_load(mq, sq + i*8, DK); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } else { - for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + k8x8_t mk[2]; + q8x8_t mq[2]; - if (DK16%4 == 0) { - // the head is evenly divisible by 4*16 = 64, so no need for bound checks - { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + simdgroup_barrier(mem_flags::mem_none); - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); + simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - k8x8_t mk; - q8x8_t mq; + simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } - } else { - if (ii + tx < DK16) { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); + simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); + } + } - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_store(mqk, ps, SH, 0, false); - for (short k = 0; k < 4 && ii + k < DK16; ++k) { - k8x8_t mk; - q8x8_t mq; + pk += 8*(NSG*NS10); + ps += 8*(NSG); + } + } else { + // TODO: this is the quantized K cache branch - not optimized yet + for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) { + const short cc = ccc*NSG + sgitg; - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + const short tx = tiisg%4; + const short ty = tiisg/4; - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11)); + + if (DK16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short k = 0; k < 4; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + if (ii + tx < DK16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DK16; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } } - // cast qk_t -> s_t - //s8x8_t mqks(1.0f); - //simdgroup_multiply(mqks, mqk, mqks); - //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); - - simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + simdgroup_store(mqk, ss + 8*cc, SH, 0, false); } } + threadgroup_barrier(mem_flags::mem_threadgroup); + // online softmax - { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - // scale and apply the logitcap / mask - float s = ss[j*TS + tiisg]*args.scale; + const float m = M[jj]; - if (args.logit_softcap != 0.0f) { - s = args.logit_softcap*precise::tanh(s); + // scale and apply the logitcap / mask + float2 s2 = ss2[j*SH/2 + tiisg]*args.scale; + + if (FC_flash_attn_ext_has_scap) { + s2 = args.logit_softcap*precise::tanh(s2); + } + + // mqk = mqk + slope*mask + if (FC_flash_attn_ext_has_bias) { + s2 += s2_t(sm2[j*SH + tiisg])*slope; + } else { + s2 += s2_t(sm2[j*SH + tiisg]); + } + + M[jj] = simd_max(max(M[jj], max(s2[0], s2[1]))); + + const float ms = exp(m - M[jj]); + const float2 vs2 = exp(s2 - M[jj]); + + S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]); + + // the P matrix from the paper (Q rows, C columns) + ss2[j*SH/2 + tiisg] = vs2; + + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; + + so4[j*PV4 + i] *= ms; } - - // mqk = mqk + mask*slope - s += slope*ss[j*TS + C + tiisg]; - - M[j] = simd_max(max(M[j], s)); - - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); - - S[j] = S[j]*ms + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[j*TS + tiisg] = vs; - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; + } else { + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); - } - } + threadgroup_barrier(mem_flags::mem_threadgroup); // O = O + (Q*K^T)*V { - for (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, ss + 8*cc, TS, 0, false); + // we can read directly from global memory + if (is_same::value) { + static_assert(PV8 % NSG == 0, ""); - if (is_same::value) { - // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + constexpr short NO = PV8/NSG; - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - v8x8_t mv; - simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 + o8x8_t lo[NO]; - simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); + { + auto sot = so + 8*sgitg; + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_load(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; } - } else { - for (short ii = 0; ii < DV16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + } + + { + device const v_t * pv = (device const v_t *) (v + ic*args.nb21); + + pv += 8*sgitg; + + if (DV <= 64) { + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[2]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); + + simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + } + + pv += 8*NS20; + } + } else { + FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + s8x8_t vs[2]; + + simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); + simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[4]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); + } + + pv += 2*8*NS20; + } + } + } + + { + auto sot = so + 8*sgitg; + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_store(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; + } + } + } else { + // TODO: this is the quantized V cache branch - not optimized yet + + const short tx = tiisg%4; + const short ty = tiisg/4; + + for (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); + + for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -7053,15 +7770,20 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - v8x8_t mv; + FOR_UNROLL (short k = 0; k < 4; ++k) { + v8x8_t mv[2]; + o8x8_t lo[2]; - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); + + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } else { if (ii + tx < DV16) { @@ -7073,236 +7795,252 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); for (short k = 0; k < 4 && ii + k < DV16; ++k) { - v8x8_t mv; + v8x8_t mv[2]; + o8x8_t lo[2]; - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); + + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } } } } } + + threadgroup_barrier(mem_flags::mem_threadgroup); } - if (sinks != q && sgitg == 0) { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; + if (FC_flash_attn_ext_has_sinks) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + const float m = M[jj]; const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; - M[j] = simd_max(max(M[j], s)); + M[jj] = simd_max(max(M[jj], s)); - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); + const float ms = exp(m - M[jj]); + const float vs = exp(s - M[jj]); - S[j] = S[j]*ms + simd_sum(vs); + S[jj] = S[jj]*ms + simd_sum(vs); - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; - } - } - - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = tiisg; j < Q; j += NW) { - ss[j*TS + 0] = S[j]; - ss[j*TS + 1] = M[j]; - } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation - threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); - - // store result to shared memory in F32 - if (sgitg == 0) { - for (short i = 0; i < DV8; ++i) { - //simdgroup_store(lo[i], so + i*8, DV, 0, false); - simdgroup_float8x8 t(1.0f); - simdgroup_multiply(t, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); + // store to global memory + for (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + if (iq1 + j >= args.ne01) { + break; } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // reduce the warps sequentially - for (ushort sg = 1; sg < nsg; ++sg) { - if (sgitg == sg) { - for (short j = tiisg; j < Q; j += NW) { - const float S0 = ss[j*TS - 1*SH + 0]; - const float S1 = ss[j*TS + 0]; - - const float M0 = ss[j*TS - 1*SH + 1]; - const float M1 = ss[j*TS + 1]; - - const float M = max(M0, M1); - - float ms0 = exp(M0 - M); - float ms1 = exp(M1 - M); - - const float S = S0*ms0 + S1*ms1; - - ss[j*TS + 0] = S; - ss[j*TS + 1] = M; - - ss[j*TS + 2*C + j - 1*SH] = ms0; - ss[j*TS + 2*C + j ] = ms1; - } - - //simdgroup_barrier(mem_flags::mem_threadgroup); - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { - s8x8_t ms0; - s8x8_t ms1; - - simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); - simdgroup_load(ms1, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_float8x8 t; - - simdgroup_load (t, so + i*8, DV, 0, false); - simdgroup_multiply(t, ms0, t); - - simdgroup_multiply_accumulate(t, ms1, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); - - // final rescale with 1/S and store to global memory - for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { - const float S = 1.0f/sf[j*TS + 0]; device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - for (short i = tiisg; i < DV4; i += NW) { - dst4[i] = (float4) so4[j*DV4 + i]*S; + const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj]; + + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; + + dst4[i] = (float4) so4[j*PV4 + i]*scale; + } + } else { + for (short i = tiisg; i < DV4; i += NW) { + dst4[i] = (float4) so4[j*PV4 + i]*scale; + } } } + +#undef NS10 +#undef NS20 +} + +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s2_t, + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // value type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short DK, // K head size + short DV, // V head size + short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device const char * pad, + device const char * blk, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_nsg) { + // note: disabled cases to reduce library load time + //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; + //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; + } +#undef FWD_TMPL +#undef FWD_ARGS } // TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as // template to be able to explore different combinations // #define FA_TYPES \ - float, float4, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ - half, half4, simdgroup_half8x8 - //float, float4, simdgroup_float8x8 + float, float2, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 #define FA_TYPES_BF \ bfloat, bfloat4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ + float, float2, simdgroup_float8x8, \ half, half4, simdgroup_half8x8 //float, float4, simdgroup_float8x8 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif -template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES #undef FA_TYPES_BF +constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]]; +constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; +constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; +constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; +constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]]; + +//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; +//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; +//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]]; + +constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]]; +constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]]; +constant int32_t FC_flash_attn_ext_vec_nsg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]]; +constant int32_t FC_flash_attn_ext_vec_nwg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]]; + template< typename q4_t, // query types in shared memory typename k4_t, // key types in shared memory @@ -7319,60 +8057,89 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext_vec( - constant ggml_metal_kargs_flash_attn_ext & args, + short NE, // head elements per thread + short Q, // queries per threadgroup + short C, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_vec_impl( + constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, device const char * k, device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups + static_assert(DK % 32 == 0, "DK must be divisible by 32"); + static_assert(DV % 32 == 0, "DV must be divisible by 32"); - const int iq3 = tgpig[2]; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]; +#define NWG (FC_flash_attn_ext_vec_nwg) + +#define NS10 (FC_flash_attn_ext_vec_ns10) +#define NS20 (FC_flash_attn_ext_vec_ns20) + + const short iwg = tgpig[2]%NWG; + + const ushort iq3 = tgpig[2]/NWG; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]; constexpr short DK4 = DK/4; constexpr short DV4 = DV/4; + + constexpr short PK = PAD2(DK, 128); + constexpr short PK4 = PK/4; + + constexpr short PV = PAD2(DV, 128); + constexpr short PV4 = PV/4; + constexpr short NW = N_SIMDWIDTH; constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads constexpr short SH = 4*C; // shared memory per simdgroup - const short T = DK + nsg*SH; // shared memory size per query in (half) + static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); + static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results + const short T = PK + NSG*SH; // shared memory size per query in (half) - // store the result for all queries in local memory (the O matrix from the paper) - o4_t lo[DV4/NL]; + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results + + // store the result for all queries in shared memory (the O matrix from the paper) + so4 += tiisg; + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q); - for (short i = tiisg; i < DK4; i += NW) { - if (iq1 < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (iq1 < args.ne01 && i < DK4) { sq4[i] = (q4_t) q4[i]; } else { sq4[i] = (q4_t) 0.0f; } } - // zero out lo + // zero out so for (short i = 0; i < DV4/NL; ++i) { - lo[i] = (o4_t) 0.0f; + so4[i*NL] = (o4_t) 0.0f; } // zero out shared memory SH @@ -7384,28 +8151,19 @@ kernel void kernel_flash_attn_ext_vec( { float S = 0.0f; - float M = -__FLT_MAX__/2; + float M = -FLT_MAX/2; // thread indices inside the simdgroup const short tx = tiisg%NL; const short ty = tiisg/NL; - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); - - const bool has_mask = mask != q; - // pointer to the mask device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_vec_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -7416,13 +8174,39 @@ kernel void kernel_flash_attn_ext_vec( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; + for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) { + int ic = ic0*C; if (ic >= args.ne11) { break; } - if (has_mask) { + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_vec_has_mask) { + if (ic + tiisg >= args.ne11) { + sm[tiisg] = -MAXHALF; + } + } else { + pm = (device const half *) (mask) + + iq1*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32); + } + + ic = 0; + } + + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -7433,70 +8217,82 @@ kernel void kernel_flash_attn_ext_vec( // Q*K^T { - // each simdgroup processes 1 query and NE (NW/NL) head elements - for (short cc = 0; cc < C/NE; ++cc) { - qk_t mqk = 0.0f; + device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11); + threadgroup const q4_t * pq4 = sq4; - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + pk4 += ty*NS10/4 + tx; + pq4 += tx; - #pragma unroll(DK4/NL) - for (short ii = 0; ii < DK4; ii += NL) { - const short i = ii + tx; + qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f }; + + // each simdgroup processes 1 query and NE (NW/NL) cache elements + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + if (is_same::value) { + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); + } + } else { + device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; - deq_k_t4(pk + i/nl_k, i%nl_k, mk); - // note: this is less precise than the version below - //mqka[0] += dot(mq[0], mk[0]); - //mqka[1] += dot(mq[1], mk[1]); - //mqka[2] += dot(mq[2], mk[2]); - //mqka[3] += dot(mq[3], mk[3]); + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + const short i = ii*NL + tx; - //q4x4_t mq = sq4x4[i]; - //mqka[0] += dot((float4) mq[0], (float4) mk[0]); - //mqka[1] += dot((float4) mq[1], (float4) mk[1]); - //mqka[2] += dot((float4) mq[2], (float4) mk[2]); - //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + deq_k_t4(pk + i/nl_k, i%nl_k, mk); - mqk += dot((float4) mk, (float4) sq4[i]); + mqk[cc] += dot((float4) mk, (float4) sq4[i]); + } } - static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails - - // simdgroup reduce (NE = 4) - // [ 0 .. 7] -> [ 0] - // [ 8 .. 15] -> [ 8] - // [16 .. 23] -> [16] - // [24 .. 31] -> [24] - if (NE <= 1) { - mqk += simd_shuffle_down(mqk, 16); - } - if (NE <= 2) { - mqk += simd_shuffle_down(mqk, 8); - } - if (NE <= 4) { - mqk += simd_shuffle_down(mqk, 4); - } - if (NE <= 8) { - mqk += simd_shuffle_down(mqk, 2); - } - if (NE <= 16) { - mqk += simd_shuffle_down(mqk, 1); - } - - // mqk = mqk*scale + mask*slope - if (tx == 0) { - mqk *= args.scale; - - if (args.logit_softcap != 0.0f) { - mqk = args.logit_softcap*precise::tanh(mqk); + if (NE == 1) { + mqk[cc] = simd_sum(mqk[cc]); + } else { + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + if (NE <= 1) { + mqk[cc] += simd_shuffle_down(mqk[cc], 16); + } + if (NE <= 2) { + mqk[cc] += simd_shuffle_down(mqk[cc], 8); + } + if (NE <= 4) { + mqk[cc] += simd_shuffle_down(mqk[cc], 4); + } + if (NE <= 8) { + mqk[cc] += simd_shuffle_down(mqk[cc], 2); + } + if (NE <= 16) { + mqk[cc] += simd_shuffle_down(mqk[cc], 1); } - mqk += sm[NE*cc + ty]*slope; - - ss[NE*cc + ty] = mqk; + // broadcast + mqk[cc] = simd_shuffle(mqk[cc], NL*ty); } } + + if (FC_flash_attn_ext_vec_has_mask && + !FC_flash_attn_ext_vec_has_scap && + !FC_flash_attn_ext_vec_has_bias) { + ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]); + } else { + mqk[tx] *= args.scale; + + if (FC_flash_attn_ext_vec_has_scap) { + mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]); + } + + if (FC_flash_attn_ext_vec_has_bias) { + mqk[tx] += (qk_t) sm[NE*tx + ty]*slope; + } else { + mqk[tx] += (qk_t) sm[NE*tx + ty]; + } + + ss[NE*tx + ty] = mqk[tx]; + } } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -7517,9 +8313,10 @@ kernel void kernel_flash_attn_ext_vec( ss[tiisg] = vs; // O = diag(ms)*O - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -7527,26 +8324,84 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { - //#pragma unroll(C/NE) - for (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + o4_t lo[DV4/NL]; + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] = 0.0f; + } - const s4_t ms(ss[NE*cc + ty]); + if (is_same::value) { + device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21); - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - const short i = ii + tx; + pv4 += ty*NS20/4 + tx; - v4_t mv; - deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + const auto sst = ss + ty; - lo[ii/NL] += o4_t(float4(mv)*float4(ms)); + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE])); + } + } + } else { + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21)); + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + const short i = ii*NL + tx; + + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty])); + } + } + } + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + if (NE > 1) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 16); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 16); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 16); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 16); + } + + if (NE > 2) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 8); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 8); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 8); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 8); + } + + if (NE > 4) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 4); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 4); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 4); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 4); + } + + if (NE > 8) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 2); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 2); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 2); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 2); + } + + if (NE > 16) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 1); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 1); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 1); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 1); + } + } + + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] += lo[ii]; } } } } - if (sinks != q && sgitg == 0) { + if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) { const float m = M; const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; @@ -7557,9 +8412,10 @@ kernel void kernel_flash_attn_ext_vec( S = S*ms + simd_sum(vs); -#pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -7570,63 +8426,12 @@ kernel void kernel_flash_attn_ext_vec( } } - // simdgroup reduce (NE = 4) - // [ 0, 8, 16, 24] -> [ 0] - // [ 1, 9, 17, 25] -> [ 1] - // [ 2, 10, 18, 26] -> [ 2] - // [ 3, 11, 19, 27] -> [ 3] - // [ 4, 12, 20, 28] -> [ 4] - // [ 5, 13, 21, 29] -> [ 5] - // [ 6, 14, 22, 30] -> [ 6] - // [ 7, 15, 23, 31] -> [ 7] - for (short ii = 0; ii < DV4; ii += NL) { - if (NE > 1) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); - } - - if (NE > 2) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); - } - - if (NE > 4) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); - } - - if (NE > 8) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); - } - - if (NE > 16) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // store results to shared memory - for (short i = tiisg; i < DV4; i += NL) { - sr4[i] = lo[i/NL]; - } + so4 -= tiisg; threadgroup_barrier(mem_flags::mem_threadgroup); // parallel reduce - for (short r = nsg/2; r > 0; r >>= 1) { + for (short r = NSG/2; r > 0; r >>= 1) { if (sgitg < r) { const float S0 = ss[ 0]; const float S1 = ss[r*(SH/2) + 0]; @@ -7648,23 +8453,87 @@ kernel void kernel_flash_attn_ext_vec( // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 for (short i = tiisg; i < DV4; i += NW) { - sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; + so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1; } } threadgroup_barrier(mem_flags::mem_threadgroup); } - device float4 * dst4 = (device float4 *) dst; - // final rescale with 1/S and store to global memory if (sgitg == 0) { - const float S = ss[0]; + const int64_t nrows = args.ne3*args.ne2*args.ne1; + const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1; + device float4 * dst4 = (device float4 *) dst; + device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results + + const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f; + + // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { - dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S; + dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S; + } + + // store S and M + if (NWG > 1) { + if (tiisg == 0) { + dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0]; + dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1]; + } } } + +#undef NWG +#undef NS10 +#undef NS20 +} + +template< + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory + short nl_k, + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // value type in device memory + short nl_v, + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext_vec & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device const char * pad, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_vec_nsg) { + // note: disabled cases to reduce library load time + case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 2: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 8: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + } +#undef FWD_TMPL +#undef FWD_ARGS } // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem @@ -7680,126 +8549,135 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES -template -kernel void kernel_set( - constant ggml_metal_kargs_set & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i13 = tgpig[2]; - const int i12 = tgpig[1]; - const int i11 = tgpig[0]; +constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]]; +constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]]; - const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10; +kernel void kernel_flash_attn_ext_vec_reduce( + constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args, + device const char * htmp, + device char * dst, + uint tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define NWG (FC_flash_attn_ext_vec_reduce_NWG) +#define DV (FC_flash_attn_ext_vec_reduce_DV) - const int64_t i3 = n / (args.ne12*args.ne11*args.ne10); - const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10); - const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10; + const uint64_t rid = tgpig; - device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs); + const short iwg = tiisg; - for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) { - device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10); - dst_data[i10] = (T) src[0]; + device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG; + + float S = ss[rid*(2*NWG) + 2*iwg + 0]; + float M = ss[rid*(2*NWG) + 2*iwg + 1]; + + const float m = simd_max(M); + const float ms = exp(M - m); + + S = simd_sum(S*ms); + S = S == 0.0f ? 0.0f : 1.0f/S; + + const short DV4 = DV/4; + + device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG; + device float4 * dst4 = (device float4 *) dst + rid*DV4; + + for (short i = sgitg; i < DV4; i += NWG) { + const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms); + + if (iwg == 0) { + dst4[i] = v*S; + } } + +#undef NWG +#undef DV } -typedef decltype(kernel_set) kernel_set_t; - -template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set; -template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set; - template -kernel void kernel_cpy( +kernel void kernel_cpy_t_t( constant ggml_metal_kargs_cpy & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 tptg[[threads_per_threadgroup]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x; - - if (i01 >= args.ne01) { - return; - } + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -7810,188 +8688,70 @@ kernel void kernel_cpy( device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; + break; } } -typedef decltype(kernel_cpy) kernel_cpy_t; +typedef decltype(kernel_cpy_t_t) kernel_cpy_t; -template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -// TODO: templetify these kernels -kernel void kernel_cpy_f32_q8_0( +template +kernel void kernel_cpy_f32_q( constant ggml_metal_kargs_cpy & args, device const char * src0, - device char * dst, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; - device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); - quantize_q8_0(src, dst_data[i00/QK8_0]); + quantize_func(src, dst_data[i00]); + + break; } } -kernel void kernel_cpy_f32_q4_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; +typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_0(src, dst_data[i00/QK4_0]); - } -} - -kernel void kernel_cpy_f32_q4_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_1(src, dst_data[i00/QK4_1]); - } -} - -kernel void kernel_cpy_f32_q5_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_0(src, dst_data[i00/QK5_0]); - } -} - -kernel void kernel_cpy_f32_q5_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_1(src, dst_data[i00/QK5_1]); - } -} - -kernel void kernel_cpy_f32_iq4_nl( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_iq4_nl(src, dst_data[i00/QK4_NL]); - } -} +template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q; template kernel void kernel_cpy_q_f32( @@ -7999,11 +8759,12 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -8015,10 +8776,12 @@ kernel void kernel_cpy_q_f32( device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; + + break; } } @@ -8067,7 +8830,7 @@ kernel void kernel_concat( } } -template +template void kernel_mul_mv_q2_K_f32_impl( args_t args, device const char * src0, @@ -8077,13 +8840,15 @@ void kernel_mul_mv_q2_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -8167,10 +8932,10 @@ kernel void kernel_mul_mv_q2_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q3_K_f32_impl( args_t args, device const char * src0, @@ -8180,6 +8945,7 @@ void kernel_mul_mv_q3_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; @@ -8187,7 +8953,7 @@ void kernel_mul_mv_q3_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -8331,10 +9097,10 @@ kernel void kernel_mul_mv_q3_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q4_K_f32_impl( args_t args, device const char * src0, @@ -8344,9 +9110,11 @@ void kernel_mul_mv_q4_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; + const short NSG = FC_mul_mv_nsg; + + constexpr uint16_t kmask1 = 0x3f3f; + constexpr uint16_t kmask2 = 0x0f0f; + constexpr uint16_t kmask3 = 0xc0c0; const short ix = tiisg/8; // 0...3 const short it = tiisg%8; // 0...7 @@ -8359,7 +9127,7 @@ void kernel_mul_mv_q4_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -8405,7 +9173,7 @@ void kernel_mul_mv_q4_K_f32_impl( float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (short i = 0; i < 4; ++i) { + FOR_UNROLL (short i = 0; i < 4; ++i) { acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F); acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00); acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0); @@ -8416,14 +9184,11 @@ void kernel_mul_mv_q4_K_f32_impl( acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000); } - float dall = dh[0]; - float dmin = dh[1]; - - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); q1 += args.nb01/2; sc += args.nb01/2; @@ -8453,10 +9218,10 @@ kernel void kernel_mul_mv_q4_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q5_K_f32_impl( args_t args, device const char * src0, @@ -8466,6 +9231,7 @@ void kernel_mul_mv_q5_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; @@ -8473,7 +9239,7 @@ void kernel_mul_mv_q5_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -8488,9 +9254,9 @@ void kernel_mul_mv_q5_K_f32_impl( float yl[16], yh[16]; - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; + constexpr uint16_t kmask1 = 0x3f3f; + constexpr uint16_t kmask2 = 0x0f0f; + constexpr uint16_t kmask3 = 0xc0c0; const short tid = tiisg/4; const short ix = tiisg%4; @@ -8536,7 +9302,7 @@ void kernel_mul_mv_q5_K_f32_impl( float4 acc1 = {0.f}; float4 acc2 = {0.f}; - for (short l = 0; l < 8; ++l) { + FOR_UNROLL (short l = 0; l < 8; ++l) { uint8_t h = qh[l]; acc1[0] += yl[l+0] * (q1[l] & 0x0F); acc1[1] += yl[l+8] * (q1[l] & 0xF0); @@ -8547,13 +9313,12 @@ void kernel_mul_mv_q5_K_f32_impl( acc2[2] += h & hm3 ? yh[l+0] : 0.f; acc2[3] += h & hm4 ? yh[l+8] : 0.f; } - const float dall = dh[0]; - const float dmin = dh[1]; - sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + - sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + - sc8[4] * (acc1[2] + 16.f*acc2[2]) + - sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - + dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); q1 += args.nb01; qh += args.nb01; @@ -8584,10 +9349,10 @@ kernel void kernel_mul_mv_q5_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q6_K_f32_impl( args_t args, device const char * src0, @@ -8597,11 +9362,12 @@ void kernel_mul_mv_q6_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; + constexpr uint8_t kmask1 = 0x03; + constexpr uint8_t kmask2 = 0x0C; + constexpr uint8_t kmask3 = 0x30; + constexpr uint8_t kmask4 = 0xC0; const int nb = args.ne00/QK_K; @@ -8609,7 +9375,7 @@ void kernel_mul_mv_q6_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -8652,18 +9418,16 @@ void kernel_mul_mv_q6_K_f32_impl( } for (short row = 0; row < nr0; ++row) { - const float dall = dh[0]; - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (short l = 0; l < 4; ++l) { + FOR_UNROLL (short l = 0; l < 4; ++l) { sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); } - sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); q1 += args.nb01; q2 += args.nb01; @@ -8693,12 +9457,12 @@ kernel void kernel_mul_mv_q6_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit -template +template void kernel_mul_mv_iq2_xxs_f32_impl( args_t args, device const char * src0, @@ -8708,13 +9472,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -8801,10 +9567,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_xs_f32_impl( args_t args, device const char * src0, @@ -8814,13 +9580,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -8918,10 +9686,10 @@ kernel void kernel_mul_mv_iq2_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_xxs_f32_impl( args_t args, device const char * src0, @@ -8931,13 +9699,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -9028,10 +9798,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_s_f32_impl( args_t args, device const char * src0, @@ -9041,13 +9811,15 @@ void kernel_mul_mv_iq3_s_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -9138,10 +9910,10 @@ kernel void kernel_mul_mv_iq3_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_s_f32_impl( args_t args, device const char * src0, @@ -9151,13 +9923,15 @@ void kernel_mul_mv_iq2_s_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -9249,10 +10023,10 @@ kernel void kernel_mul_mv_iq2_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_s_f32_impl( args_t args, device const char * src0, @@ -9262,13 +10036,15 @@ void kernel_mul_mv_iq1_s_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -9346,10 +10122,10 @@ kernel void kernel_mul_mv_iq1_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_m_f32_impl( args_t args, device const char * src0, @@ -9359,6 +10135,7 @@ void kernel_mul_mv_iq1_m_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; @@ -9366,7 +10143,7 @@ void kernel_mul_mv_iq1_m_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -9454,10 +10231,10 @@ kernel void kernel_mul_mv_iq1_m_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -9467,15 +10244,15 @@ void kernel_mul_mv_iq4_nl_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -9486,6 +10263,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK4_NL; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -9493,24 +10273,25 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK4_NL + it * 8; + device const float * yb = y + ix*QK4_NL + it*8; uint32_t aux32[2]; thread const uint8_t * q8 = (thread const uint8_t *)aux32; float4 qf1, qf2; - for (int ib = ix; ib < nb; ib += 16) { + // [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; row++) { - device const block_iq4_nl & xb = x[row*nb + ib]; + for (short row = 0; row < NR0; row++) { + device const block_iq4_nl & xb = x[row*ns01 + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -9541,7 +10322,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -9560,10 +10341,10 @@ kernel void kernel_mul_mv_iq4_nl_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -9573,13 +10354,14 @@ void kernel_mul_mv_iq4_xs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -9590,6 +10372,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_K; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/16; // 0 or 1 const short it = tiisg%16; // 0...15 const short ib = it/2; @@ -9599,7 +10384,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -9608,15 +10393,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; - for (int ibl = ix; ibl < nb; ibl += 2) { + // [TAG_MUL_MV_WEIRD] + for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; + for (short row = 0; row < NR0; ++row) { + device const block_iq4_xs & xb = x[row*ns01 + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -9646,7 +10432,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -9665,10 +10451,10 @@ kernel void kernel_mul_mv_iq4_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_mxfp4_f32_impl( args_t args, device const char * src0, @@ -9678,15 +10464,15 @@ void kernel_mul_mv_mxfp4_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_MXFP4; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -9697,6 +10483,9 @@ void kernel_mul_mv_mxfp4_f32_impl( device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_MXFP4; + const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -9704,20 +10493,22 @@ void kernel_mul_mv_mxfp4_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK_MXFP4 + it * 8; + device const float * yb = y + ix*QK_MXFP4 + it*8; + + // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster + // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { + device const float4 * y4 = (device const float4 *) yb; - for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; -#pragma unroll(nr0) - for (short row = 0; row < nr0; row++) { - device const block_mxfp4 & xb = x[row*nb + ib]; + FOR_UNROLL (short row = 0; row < NR0; row++) { + device const block_mxfp4 & xb = x[row*ns01 + ib]; device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it); float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]); @@ -9735,7 +10526,7 @@ void kernel_mul_mv_mxfp4_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -9754,76 +10545,70 @@ kernel void kernel_mul_mv_mxfp4_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { + auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; + dequantize_func(psrc + ind/nl, ind%nl, temp); + pdst[ind] = temp; + + break; } } -template +template kernel void kernel_get_rows_f( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; + auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { + pdst[ind] = psrc[ind]; + + break; } } -kernel void kernel_get_rows_i32( - constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device int32_t * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; - } -} - -template +template kernel void kernel_set_rows_q32( constant ggml_metal_kargs_set_rows & args, device const void * src0, @@ -9844,7 +10629,7 @@ kernel void kernel_set_rows_q32( } const int32_t i10 = i01; - const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; + const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); @@ -9854,7 +10639,7 @@ kernel void kernel_set_rows_q32( } } -template +template kernel void kernel_set_rows_f( constant ggml_metal_kargs_set_rows & args, device const void * src0, @@ -9875,9 +10660,9 @@ kernel void kernel_set_rows_f( } const int32_t i10 = i01; - const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; + const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; - device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { @@ -9885,6 +10670,9 @@ kernel void kernel_set_rows_f( } } +constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; +constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; + #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_K 32 @@ -9897,7 +10685,7 @@ kernel void kernel_set_rows_f( #define SG_MAT_ROW 8 // each block_q contains 16*nl weights -template +template kernel void kernel_mul_mm( constant ggml_metal_kargs_mul_mm & args, device const char * src0, @@ -9908,8 +10696,8 @@ kernel void kernel_mul_mm( ushort tiitg[[thread_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup T * sa = (threadgroup T *)(shmem); - threadgroup float * sb = (threadgroup float *)(shmem + 4096); + threadgroup S0 * sa = (threadgroup S0 *)(shmem); + threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; @@ -9923,8 +10711,9 @@ kernel void kernel_mul_mm( const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_T8x8 ma[4]; - simdgroup_float8x8 mb[2]; + S0_8x8 ma[4]; + S1_8x8 mb[2]; + simdgroup_float8x8 mc[8]; for (short i = 0; i < 8; i++){ @@ -9942,27 +10731,45 @@ kernel void kernel_mul_mm( device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - device const float * y = (device const float *)(src1 + const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)); + + device const T1 * y = (device const T1 *)(src1 + args.nb13*i13 + args.nb12*i12 + args.nb11*(r1*BLOCK_SIZE_N + thread_col) - + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + args.nb10*iy); for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - T4x4 temp_a; - dequantize_func(x, il, temp_a); + if (is_same::value && FC_mul_mm_bc_inp) { + threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup_barrier(mem_flags::mem_threadgroup); + // no need for dequantization + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0; + } + } else { + S0_4x4 temp_a; + dequantize_func(x, il, temp_a); - #pragma unroll(16) - for (short i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ - + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ - + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } } - *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y); + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0; + } + } else { + *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y)); + } il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -9971,23 +10778,25 @@ kernel void kernel_mul_mm( threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); - threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); #pragma unroll(4) for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(4) for (short i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) for (short i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(8) for (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); @@ -9998,7 +10807,8 @@ kernel void kernel_mul_mm( } } - if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) { + if (!FC_mul_mm_bc_out || ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1)) { + // if no bounds checks on the output are needed, we can directly write to device memory device float * C = (device float *) dst + (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; @@ -10039,124 +10849,111 @@ kernel void kernel_mul_mm( } } -template +template // n_expert_used kernel void kernel_mul_mm_id_map0( constant ggml_metal_kargs_mul_mm_id_map0 & args, - device const char * src1, device const char * src2, - device char * hsrc1, device char * htpe, device char * hids, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int ide = tgpig[0]; // expert id + threadgroup char * shmem [[threadgroup(0)]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort ntg[[threads_per_threadgroup]]) { + const short ide = tpitg; // expert id - int n_all = 0; + uint32_t n_all = 0; - device int32_t * ids_i32 = (device int32_t *) (hids); + device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21; - for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens - device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21); + for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens + if (i21 + tpitg < args.ne21) { + device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21); - for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used - if (src2_i32[i20] != ide) { - continue; + threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20; + + #pragma unroll(ne20) + for (short i20 = 0; i20 < ne20; i20++) { + sids[i20] = src2_i32[i20]; } - - device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11); - device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11); - - for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) { - hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]); - } - - if (tpitg.x == 0) { - ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all; - } - - ++n_all; } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short t = 0; t < ntg; t++) { + if (i21 + t >= args.ne21) { + break; + } + + threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20; + + short sel = 0; + #pragma unroll(ne20) + for (short i20 = 0; i20 < ne20; i20++) { + sel += (sids[i20] == ide)*(i20 + 1); + } + + ids_i32[n_all] = (i21 + t)*ne20 + sel - 1; + + n_all += sel > 0; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); } - if (tpitg.x == 0) { - device int32_t * tpe_i32 = (device int32_t *) (htpe); - tpe_i32[ide] = n_all; - } + device uint32_t * tpe_u32 = (device uint32_t *) (htpe); + tpe_u32[ide] = n_all; } -typedef decltype(kernel_mul_mm_id_map0) kernel_mul_mm_id_map0_t; +typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t; -template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0; +template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>; +template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>; +template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; +template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; +template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; +template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; +template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; -template -kernel void kernel_mul_mm_id_map1( - constant ggml_metal_kargs_mul_mm_id_map1 & args, - device const char * hdst, - device const char * hids, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i20 = tgpig[0]; // used expert - const int i21 = tgpig[1]; // token - - device const int32_t * ids_i32 = (device const int32_t *) (hids); - device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2); - - const int id = ids_i32[i21*args.ne20 + i20]; - - const int ide = id / args.neh1; - const int idt = id % args.neh1; - - device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2); - - for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) { - dst_f32x4[i0] = hdst_f32x4[i0]; - } -} - -typedef decltype(kernel_mul_mm_id_map1) kernel_mul_mm_id_map1_t; - -template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1; - -template +template kernel void kernel_mul_mm_id( constant ggml_metal_kargs_mul_mm_id & args, device const char * src0, device const char * src1, - device const char * tpe, + device const char * htpe, + device const char * hids, device char * dst, threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup T * sa = (threadgroup T *)(shmem); - threadgroup half * sb = (threadgroup half *)(shmem + 4096); + threadgroup S0 * sa = (threadgroup S0 *)(shmem); + threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; - const int im = tgpig.z; + const int im = tgpig.z; // expert - device const int32_t * tpe_i32 = (device const int32_t *) (tpe); + device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe); + device const int32_t * ids_i32 = (device const int32_t *) (hids); - const int neh1 = tpe_i32[im]; + const int32_t neh1 = tpe_u32[im]; if (r1*BLOCK_SIZE_N >= neh1) { return; } // if this block is of 64x32 shape or smaller - const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; - const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_T8x8 ma[4]; - simdgroup_half8x8 mb[2]; + S0_8x8 ma[4]; + S1_8x8 mb[2]; + simdgroup_float8x8 mc[8]; for (short i = 0; i < 8; i++){ @@ -10165,36 +10962,57 @@ kernel void kernel_mul_mm_id( short il = (tiitg % THREAD_PER_ROW); - const int i12 = im%args.neh12; - const int i13 = im/args.neh12; + const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col]; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short i11 = (id % args.ne20) % args.ne11; + const short i12 = (id / args.ne20); + const short i13 = 0; + + const uint64_t offset0 = im*args.nb02 + i13*args.nb03; const short offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - device const half * y = (device const half *)(src1 - + args.nbh13*i13 - + args.nbh12*i12 - + args.nbh11*(r1*BLOCK_SIZE_N + thread_col) - + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)); + + device const T1 * y = (device const T1 *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*i11 + + args.nb10*iy); for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - T4x4 temp_a; - dequantize_func(x, il, temp_a); + if (is_same::value && FC_mul_mm_bc_inp) { + threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup_barrier(mem_flags::mem_threadgroup); + // no need for dequantization + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0; + } + } else { + S0_4x4 temp_a; + dequantize_func(x, il, temp_a); - #pragma unroll(16) - for (short i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ - + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ - + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } } - *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y); + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0; + } + } else { + *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y)); + } il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -10203,8 +11021,8 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); - threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); #pragma unroll(4) for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { @@ -10230,43 +11048,38 @@ kernel void kernel_mul_mm_id( } } - if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) { - device float * C = (device float *) dst + - (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ - (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shmem) \ - + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; - for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; + + #pragma unroll(8) + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short j = sgitg; j < n_cols; j += 4) { + const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j]; + + const short ide = id % args.ne20; + const short idt = id / args.ne20; + + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = tiisg; + for (; i < n_rows/4; i += 32) { + *(D4 + i) = *(C4 + i); } - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (sgitg == 0) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0; - device float4 * D4 = (device float4 *) D; - - threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); - threadgroup float4 * C4 = (threadgroup float4 *) C; - - int i = 0; - for (; i < n_rows/4; i++) { - *(D4 + i) = *(C4 + i); - } - - i *= 4; - for (; i < n_rows; i++) { - *(D + i) = *(C + i); - } - } + i = (4*(n_rows/4)) + tiisg; + for (; i < n_rows; i += 32) { + *(D + i) = *(C + i); } } } @@ -10277,12 +11090,13 @@ kernel void kernel_mul_mm_id( // get rows // -typedef decltype(kernel_get_rows_f) get_rows_f_t; +typedef decltype(kernel_get_rows_f) get_rows_f_t; -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; #endif typedef decltype(kernel_get_rows_q) get_rows_q_t; @@ -10312,93 +11126,153 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get // set rows // -typedef decltype(kernel_set_rows_f) set_rows_f_t; +typedef decltype(kernel_set_rows_f) set_rows_f_t; -template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f; -template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f; #endif -typedef decltype(kernel_set_rows_q32) set_rows_q32_t; +typedef decltype(kernel_set_rows_q32) set_rows_q32_t; -template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32; -template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32; -template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32; -template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32; -template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32; -template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; // // matrix-matrix multiplication // -typedef decltype(kernel_mul_mm) mul_mm_t; +typedef decltype(kernel_mul_mm) mul_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; + +template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm; +#endif +template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication // -typedef decltype(kernel_mul_mm_id) mul_mm_id; +typedef decltype(kernel_mul_mm_id) mul_mm_id; -template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id; #endif -template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +#endif +template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; // // matrix-vector multiplication // -typedef void (kernel_mul_mv_impl_t)( +typedef void (kernel_mul_mv_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -10406,7 +11280,7 @@ typedef void (kernel_mul_mv_impl_t)( uint3 tgpig, ushort tiisg); -typedef void (kernel_mul_mv2_impl_t)( +typedef void (kernel_mul_mv2_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -10416,7 +11290,7 @@ typedef void (kernel_mul_mv2_impl_t)( ushort tiisg, ushort sgitg); -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -10427,10 +11301,10 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, tgpig, tiisg); + disp_fn(args, src0, src1, dst, tgpig, tiisg); } -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -10441,12 +11315,12 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_disp_fn_t; -template +template kernel void kernel_mul_mv_id( constant ggml_metal_kargs_mul_mv_id & args, device const char * src0s, @@ -10493,11 +11367,12 @@ kernel void kernel_mul_mv_id( /*.nb13 =*/ args.nb12, // ne12 == 1 /*.ne0 =*/ args.ne0, /*.ne1 =*/ 1, // args.ne1, + /*.nr0 =*/ args.nr0, /*.r2 =*/ 1, /*.r3 =*/ 1, }; - impl_fn( + disp_fn( args0, /* src0 */ src0_cur, /* src1 */ src1_cur, @@ -10509,44 +11384,52 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; + +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#endif +template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; #endif -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; kernel void kernel_pool_2d_max_f32( + constant ggml_metal_kargs_pool_2d & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_pool_2d & args, uint gid[[thread_position_in_grid]]) { - if (gid >= args.parallel_elements) { + if (gid >= args.np) { return; } @@ -10579,12 +11462,12 @@ kernel void kernel_pool_2d_max_f32( } kernel void kernel_pool_2d_avg_f32( + constant ggml_metal_kargs_pool_2d & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_pool_2d & args, uint gid[[thread_position_in_grid]]) { - if (gid >= args.parallel_elements) { + if (gid >= args.np) { return; } @@ -10618,3 +11501,51 @@ kernel void kernel_pool_2d_avg_f32( o_ptr[cur_oh * args.OW + cur_ow] = res; } + +kernel void kernel_opt_step_adamw_f32( + constant ggml_metal_kargs_opt_step_adamw & args, + device float * x, + device const float * g, + device float * g_m, + device float * g_v, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + const float alpha = pars[0]; + const float beta1 = pars[1]; + const float beta2 = pars[2]; + const float eps = pars[3]; + const float wd = pars[4]; + const float beta1h = pars[5]; + const float beta2h = pars[6]; + + const float gi = g[gid]; + const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1); + const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2); + + g_m[gid] = gmi; + g_v[gid] = gvi; + + const float mh = gmi * beta1h; + const float vh = sqrt(gvi * beta2h) + eps; + + x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh; +} + +kernel void kernel_opt_step_sgd_f32( + constant ggml_metal_kargs_opt_step_sgd & args, + device float * x, + device const float * g, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h index fc6526d6..a448c14f 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h @@ -20,8 +20,8 @@ #define N_R0_Q5_1 4 #define N_SG_Q5_1 2 -#define N_R0_Q8_0 4 -#define N_SG_Q8_0 2 +#define N_R0_Q8_0 2 +#define N_SG_Q8_0 4 #define N_R0_MXFP4 2 #define N_SG_MXFP4 2 @@ -32,13 +32,13 @@ #define N_R0_Q3_K 2 #define N_SG_Q3_K 2 -#define N_R0_Q4_K 4 +#define N_R0_Q4_K 2 #define N_SG_Q4_K 2 #define N_R0_Q5_K 2 #define N_SG_Q5_K 2 -#define N_R0_Q6_K 1 +#define N_R0_Q6_K 2 #define N_SG_Q6_K 2 #define N_R0_IQ1_S 4 @@ -68,6 +68,22 @@ #define N_R0_IQ4_XS 2 #define N_SG_IQ4_XS 2 +// function constants offsets +#define FC_FLASH_ATTN_EXT_PAD 100 +#define FC_FLASH_ATTN_EXT_BLK 200 +#define FC_FLASH_ATTN_EXT 300 +#define FC_FLASH_ATTN_EXT_VEC 400 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500 +#define FC_MUL_MV 600 +#define FC_MUL_MM 700 + +// op-specific constants +#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NCPSG 64 + +#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -161,6 +177,17 @@ typedef struct { } ggml_metal_kargs_repeat; typedef struct { + float scale; + float bias; +} ggml_metal_kargs_scale; + +typedef struct { + float min; + float max; +} ggml_metal_kargs_clamp; + +typedef struct { + int64_t nk0; int64_t ne00; int64_t ne01; int64_t ne02; @@ -227,12 +254,6 @@ typedef struct { } ggml_metal_kargs_rope; typedef struct { - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; int32_t ne11; int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_3; @@ -242,6 +263,44 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_pad; + +typedef struct { + int32_t ne01; + int32_t ne30; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_blk; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + int32_t ns10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ns20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -249,6 +308,7 @@ typedef struct { uint64_t nb33; int32_t ne1; int32_t ne2; + int32_t ne3; float scale; float max_bias; float m0; @@ -257,6 +317,45 @@ typedef struct { float logit_softcap; } ggml_metal_kargs_flash_attn_ext; +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + int32_t ns10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ns20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; + int32_t ne1; + int32_t ne2; + int32_t ne3; + float scale; + float max_bias; + float m0; + float m1; + int32_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext_vec; + +typedef struct { + int32_t nrows; +} ggml_metal_kargs_flash_attn_ext_vec_reduce; + typedef struct { int32_t ne00; int32_t ne02; @@ -291,6 +390,7 @@ typedef struct { uint64_t nb13; int32_t ne0; int32_t ne1; + int32_t nr0; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mv; @@ -314,46 +414,34 @@ typedef struct { int32_t ne1; int16_t r2; int16_t r3; - int16_t nsg; - int16_t nxpsg; - int16_t r1ptg; } ggml_metal_kargs_mul_mv_ext; typedef struct { + int32_t ne02; int32_t ne10; int32_t ne11; // n_expert_used (bcast) uint64_t nb11; uint64_t nb12; - int32_t neh11; // n_tokens - uint64_t nbh11; + int32_t ne21; // n_tokens int32_t ne20; // n_expert_used uint64_t nb21; } ggml_metal_kargs_mul_mm_id_map0; -typedef struct { - int32_t ne20; // n_expert_used - int32_t neh0; - int32_t neh1; - uint64_t nbh1; - uint64_t nbh2; - int32_t ne0; - uint64_t nb1; - uint64_t nb2; -} ggml_metal_kargs_mul_mm_id_map1; - typedef struct { int32_t ne00; int32_t ne02; uint64_t nb01; uint64_t nb02; uint64_t nb03; - int32_t neh12; - uint64_t nbh10; - uint64_t nbh11; - uint64_t nbh12; - uint64_t nbh13; - int32_t neh0; - int32_t neh1; + int32_t ne11; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne20; + int32_t ne21; + int32_t ne0; + int32_t ne1; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mm_id; @@ -378,18 +466,14 @@ typedef struct { int32_t ne0; int32_t ne1; uint64_t nb1; + int32_t nr0; } ggml_metal_kargs_mul_mv_id; +// NORM +// RMS_NORM typedef struct { int32_t ne00; - int32_t ne00_4; - uint64_t nb01; - float eps; -} ggml_metal_kargs_norm; - -typedef struct { - int32_t ne00; - int32_t ne00_4; + int32_t ne00_t; uint64_t nb1; uint64_t nb2; uint64_t nb3; @@ -400,7 +484,7 @@ typedef struct { uint64_t nbf1[3]; uint64_t nbf2[3]; uint64_t nbf3[3]; -} ggml_metal_kargs_rms_norm; +} ggml_metal_kargs_norm; typedef struct { int32_t ne00; @@ -416,7 +500,7 @@ typedef struct { uint64_t nb00; uint64_t nb01; uint64_t nb02; - int32_t n_groups; + int32_t ngrp; float eps; } ggml_metal_kargs_group_norm; @@ -460,6 +544,10 @@ typedef struct{ float limit; } ggml_metal_kargs_glu; +typedef struct { + uint64_t np; +} ggml_metal_kargs_sum; + typedef struct { int64_t ne00; int64_t ne01; @@ -469,14 +557,6 @@ typedef struct { uint64_t nb01; uint64_t nb02; uint64_t nb03; - int64_t ne10; - int64_t ne11; - int64_t ne12; - int64_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; int64_t ne0; int64_t ne1; int64_t ne2; @@ -510,12 +590,6 @@ typedef struct { int32_t n_head_log2; } ggml_metal_kargs_soft_max; -typedef struct { - int64_t ne00; - int64_t ne01; - int n_past; -} ggml_metal_kargs_diag_mask_inf; - typedef struct { int64_t ne00; int64_t ne01; @@ -542,33 +616,46 @@ typedef struct { int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; - int64_t s_off; + uint64_t s_off; + uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + uint64_t nb10; uint64_t nb11; uint64_t nb12; + uint64_t ns12; uint64_t nb13; + uint64_t nb20; uint64_t nb21; + uint64_t ns21; uint64_t nb22; + int64_t ne30; uint64_t nb31; uint64_t nb41; uint64_t nb42; + uint64_t ns42; uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t ns52; uint64_t nb53; + uint64_t nb0; } ggml_metal_kargs_ssm_scan; typedef struct { - int64_t ne00; + int32_t ne00t; + int32_t ne00; uint64_t nb01; uint64_t nb02; - int64_t ne10; + uint64_t nb03; + int32_t ne10; uint64_t nb10; uint64_t nb11; + uint64_t nb12; uint64_t nb1; uint64_t nb2; + uint64_t nb3; } ggml_metal_kargs_get_rows; typedef struct { @@ -682,7 +769,20 @@ typedef struct { int64_t IW; int64_t OH; int64_t OW; - int64_t parallel_elements; + int64_t np; } ggml_metal_kargs_pool_2d; +typedef struct { + int64_t ne00; + uint64_t nb01; +} ggml_metal_kargs_argmax; + +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_adamw; + +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_sgd; + #endif // GGML_METAL_IMPL diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp new file mode 100644 index 00000000..a61ea8fb --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -0,0 +1,3509 @@ +#include "ggml-metal-ops.h" + +#include "ggml.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-metal-impl.h" +#include "ggml-metal-common.h" +#include "ggml-metal-device.h" + +#include +#include + +static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) { + if (!t) { + return { nullptr, 0 }; + } + + ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; + + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t) buffer->context; + + return ggml_metal_buffer_get_id(ctx, t); +} + +struct ggml_metal_op { + ggml_metal_op( + ggml_metal_device_t dev, + ggml_metal_cmd_buf_t cmd_buf, + ggml_cgraph * gf, + int idx_start, + int idx_end, + bool use_fusion, + bool use_concurrency, + bool use_capture, + int debug_graph, + int debug_fusion) { + this->dev = dev; + this->lib = ggml_metal_device_get_library(dev); + this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency); + this->mem_ranges = ggml_mem_ranges_init(debug_graph); + this->idx_start = idx_start; + this->idx_end = idx_end; + this->use_fusion = use_fusion; + this->use_concurrency = use_concurrency; + this->use_capture = use_capture; + this->debug_graph = debug_graph; + this->debug_fusion = debug_fusion; + this->gf = gf; + + idxs.reserve(gf->n_nodes); + + // filter empty nodes + // TODO: this can be removed when the allocator starts filtering them earlier + // https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830 + for (int i = idx_start; i < idx_end; i++) { + if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) { + idxs.push_back(i); + } + } + } + + ~ggml_metal_op() { + ggml_metal_encoder_end_encoding(this->enc); + ggml_metal_encoder_free(this->enc); + ggml_mem_ranges_free(this->mem_ranges); + } + + int n_nodes() const { + return idxs.size(); + } + + ggml_tensor * node(int i) const { + assert(i >= 0 && i < (int) idxs.size()); + return ggml_graph_node(gf, idxs[i]); + } + + bool can_fuse(int i0, const ggml_op * ops, int n_ops) const { + assert(use_fusion); + assert(i0 >= 0 && i0 < n_nodes()); + + if (i0 + n_ops > n_nodes()) { + return false; + } + + return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops); + } + + ggml_metal_device_t dev; + ggml_metal_library_t lib; + ggml_metal_encoder_t enc; + ggml_mem_ranges_t mem_ranges; + + bool use_fusion; + bool use_concurrency; + bool use_capture; + + int debug_graph; + int debug_fusion; + +private: + ggml_cgraph * gf; + + int idx_start; + int idx_end; + + // non-empty node indices + std::vector idxs; +}; + +ggml_metal_op_t ggml_metal_op_init( + ggml_metal_device_t dev, + ggml_metal_cmd_buf_t cmd_buf, + ggml_cgraph * gf, + int idx_start, + int idx_end, + bool use_fusion, + bool use_concurrency, + bool use_capture, + int debug_graph, + int debug_fusion) { + ggml_metal_op_t res = new ggml_metal_op( + dev, + cmd_buf, + gf, + idx_start, + idx_end, + use_fusion, + use_concurrency, + use_capture, + debug_graph, + debug_fusion); + + return res; +} + +void ggml_metal_op_free(ggml_metal_op_t ctx) { + delete ctx; +} + +int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) { + return ctx->n_nodes(); +} + +static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) { + if (!ctx->mem_ranges) { + return true; + } + + ggml_metal_encoder_memory_barrier(ctx->enc); + + ggml_mem_ranges_reset(ctx->mem_ranges); + + return true; +} + +static bool ggml_metal_op_concurrency_check(ggml_metal_op_t ctx, const ggml_tensor * node) { + if (!ctx->mem_ranges) { + return false; + } + + return ggml_mem_ranges_check(ctx->mem_ranges, node); +} + +static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor * node) { + if (!ctx->mem_ranges) { + return true; + } + + return ggml_mem_ranges_add(ctx->mem_ranges, node); +} + +static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { + struct ggml_tensor * node = ctx->node(idx); + + //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); + + if (ggml_is_empty(node)) { + return 1; + } + + switch (node->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + { + // noop -> next node + if (ctx->debug_graph > 0) { + GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)"); + } + } return 1; + default: + { + } break; + } + + if (!ggml_metal_device_supports_op(ctx->dev, node)) { + GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(node)); + GGML_ABORT("unsupported op"); + } + + int n_fuse = 1; + + // check if the current node can run concurrently with other nodes before it + // the condition is that: + // - the current node cannot write to any previous src or dst ranges + // - the current node cannot read from any previous dst ranges + // + // if the condition is not satisfied, we put a memory barrier and clear all ranges + // otherwise, we add the new ranges to the encoding context and process the node concurrently + // + { + const bool is_concurrent = ggml_metal_op_concurrency_check(ctx, node); + + if (!is_concurrent) { + ggml_metal_op_concurrency_reset(ctx); + } + + if (ctx->debug_graph > 0) { + GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), is_concurrent ? "(concurrent)" : ""); + } + if (ctx->debug_graph > 1) { + GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb); + GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb); + GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb); + GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb); + GGML_TENSOR_LOCALS( int64_t, ne, node, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, node, nb); + + if (node->src[0]) { + GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, + ggml_is_contiguous(node->src[0]), node->src[0]->name); + } + if (node->src[1]) { + GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + ggml_is_contiguous(node->src[1]), node->src[1]->name); + } + if (node->src[2]) { + GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23, + ggml_is_contiguous(node->src[2]), node->src[2]->name); + } + if (node->src[3]) { + GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33, + ggml_is_contiguous(node->src[3]), node->src[3]->name); + } + if (node) { + GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, + node->name); + } + } + } + + switch (node->op) { + case GGML_OP_CONCAT: + { + n_fuse = ggml_metal_op_concat(ctx, idx); + } break; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + { + n_fuse = ggml_metal_op_bin(ctx, idx); + } break; + case GGML_OP_ADD_ID: + { + n_fuse = ggml_metal_op_add_id(ctx, idx); + } break; + case GGML_OP_REPEAT: + { + n_fuse = ggml_metal_op_repeat(ctx, idx); + } break; + case GGML_OP_ACC: + { + n_fuse = ggml_metal_op_acc(ctx, idx); + } break; + case GGML_OP_SCALE: + { + n_fuse = ggml_metal_op_scale(ctx, idx); + } break; + case GGML_OP_CLAMP: + { + n_fuse = ggml_metal_op_clamp(ctx, idx); + } break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_LOG: + case GGML_OP_UNARY: + { + n_fuse = ggml_metal_op_unary(ctx, idx); + } break; + case GGML_OP_GLU: + { + n_fuse = ggml_metal_op_glu(ctx, idx); + } break; + case GGML_OP_SUM: + { + n_fuse = ggml_metal_op_sum(ctx, idx); + } break; + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + { + n_fuse = ggml_metal_op_sum_rows(ctx, idx); + } break; + case GGML_OP_SOFT_MAX: + { + n_fuse = ggml_metal_op_soft_max(ctx, idx); + } break; + case GGML_OP_SSM_CONV: + { + n_fuse = ggml_metal_op_ssm_conv(ctx, idx); + } break; + case GGML_OP_SSM_SCAN: + { + n_fuse = ggml_metal_op_ssm_scan(ctx, idx); + } break; + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + { + n_fuse = ggml_metal_op_rwkv(ctx, idx); + } break; + case GGML_OP_MUL_MAT: + { + n_fuse = ggml_metal_op_mul_mat(ctx, idx); + } break; + case GGML_OP_MUL_MAT_ID: + { + n_fuse = ggml_metal_op_mul_mat_id(ctx, idx); + } break; + case GGML_OP_GET_ROWS: + { + n_fuse = ggml_metal_op_get_rows(ctx, idx); + } break; + case GGML_OP_SET_ROWS: + { + n_fuse = ggml_metal_op_set_rows(ctx, idx); + } break; + case GGML_OP_L2_NORM: + { + n_fuse = ggml_metal_op_l2_norm(ctx, idx); + } break; + case GGML_OP_GROUP_NORM: + { + n_fuse = ggml_metal_op_group_norm(ctx, idx); + } break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + { + n_fuse = ggml_metal_op_norm(ctx, idx); + } break; + case GGML_OP_ROPE: + { + n_fuse = ggml_metal_op_rope(ctx, idx); + } break; + case GGML_OP_IM2COL: + { + n_fuse = ggml_metal_op_im2col(ctx, idx); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx); + } break; + case GGML_OP_UPSCALE: + { + n_fuse = ggml_metal_op_upscale(ctx, idx); + } break; + case GGML_OP_PAD: + { + n_fuse = ggml_metal_op_pad(ctx, idx); + } break; + case GGML_OP_PAD_REFLECT_1D: + { + n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx); + } break; + case GGML_OP_ARANGE: + { + n_fuse = ggml_metal_op_arange(ctx, idx); + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + n_fuse = ggml_metal_op_timestep_embedding(ctx, idx); + } break; + case GGML_OP_ARGSORT: + { + n_fuse = ggml_metal_op_argsort(ctx, idx); + } break; + case GGML_OP_LEAKY_RELU: + { + n_fuse = ggml_metal_op_leaky_relu(ctx, idx); + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); + } break; + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + { + n_fuse = ggml_metal_op_cpy(ctx, idx); + } break; + case GGML_OP_POOL_2D: + { + n_fuse = ggml_metal_op_pool_2d(ctx, idx); + } break; + case GGML_OP_ARGMAX: + { + n_fuse = ggml_metal_op_argmax(ctx, idx); + } break; + case GGML_OP_OPT_STEP_ADAMW: + { + n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx); + } break; + case GGML_OP_OPT_STEP_SGD: + { + n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx); + } break; + default: + { + GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); + GGML_ABORT("fatal error"); + } + } + + if (ctx->debug_graph > 0) { + if (n_fuse > 1) { + GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse); + } + } + + // update the mem ranges in the encoding context + for (int i = 0; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) { + ggml_metal_op_concurrency_reset(ctx); + } + } + + return n_fuse; +} + +int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) { + if (ctx->use_capture) { + ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx))); + } + + int res = ggml_metal_op_encode_impl(ctx, idx); + if (idx + res > ctx->n_nodes()) { + GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s", + "https://github.com/ggml-org/llama.cpp/pull/14849"); + } + + if (ctx->use_capture) { + ggml_metal_encoder_debug_group_pop(ctx->enc); + } + + return res; +} + +int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t dim = ((const int32_t *) op->op_params)[0]; + + ggml_metal_kargs_concat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); + + ggml_metal_kargs_repeat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + const size_t pnb1 = ((const int32_t *) op->op_params)[0]; + const size_t pnb2 = ((const int32_t *) op->op_params)[1]; + const size_t pnb3 = ((const int32_t *) op->op_params)[2]; + const size_t offs = ((const int32_t *) op->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } + + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + /*.o1 =*/ { 0 }, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float scale; + float bias; + memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float)); + memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float)); + + ggml_metal_kargs_scale args = { + /*.scale =*/ scale, + /*.bias =*/ bias, + }; + + int64_t n = ggml_nelements(op); + + if (n % 4 == 0) { + n /= 4; + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float min; + float max; + memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float)); + memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float)); + + ggml_metal_kargs_clamp args = { + /*.min =*/ min, + /*.max =*/ max, + }; + + int64_t n = ggml_nelements(op); + + if (n % 4 == 0) { + n /= 4; + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + int64_t n = ggml_nelements(op); + + if (n % 4 == 0) { + n /= 4; + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + if (op->src[1]) { + GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1])); + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op); + + const int32_t swp = ggml_get_op_params_i32(op, 1); + const float alpha = ggml_get_op_params_f32(op, 2); + const float limit = ggml_get_op_params_f32(op, 3); + + const int32_t i00 = swp ? ne0 : 0; + const int32_t i10 = swp ? 0 : ne0; + + ggml_metal_kargs_glu args = { + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.ne10 =*/ op->src[1] ? ne10 : ne00, + /*.nb11 =*/ op->src[1] ? nb11 : nb01, + /*.ne0 =*/ ne0, + /*.nb1 =*/ nb1, + /*.i00 =*/ op->src[1] ? 0 : i00, + /*.i10 =*/ op->src[1] ? 0 : i10, + /*.alpha=*/ alpha, + /*.limit=*/ limit + }; + + const int64_t nrows = ggml_nrows(op->src[0]); + + const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2); + + //[encoder setComputePipelineState:pipeline]; + //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + //if (src1) { + // [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + //} else { + // [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + //} + //[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + //[encoder setBytes:&args length:sizeof(args) atIndex:3]; + + //[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + if (op->src[1]) { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + } else { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 2); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const uint64_t n = (uint64_t) ggml_nelements(op->src[0]); + + ggml_metal_kargs_sum args = { + /*.np =*/ n, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_sum_rows args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + //[encoder setComputePipelineState:pipeline]; + //[encoder setBytes:&args length:sizeof(args) atIndex:0]; + //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + //[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + //[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); + + ggml_metal_kargs_get_rows args = { + /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00, + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const int nw0 = (args.ne00t + nth - 1)/nth; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); + + const int32_t nk0 = ne0/ggml_blck_size(op->type); + + int nth = 32; // SIMD width + + while (nth < nk0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + int nrptg = 1; + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nrptg--; + } + } + + nth = std::min(nth, nk0); + + ggml_metal_kargs_set_rows args = { + /*.nk0 =*/ nk0, + /*.ne01 =*/ ne01, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); + + return 1; +} + +int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float scale; + float max_bias; + + memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias)); + + const uint32_t n_head = op->src[0]->ne[2]; + const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // softmax + + ggml_metal_kargs_soft_max args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); + + int nth = 32; // SIMD width + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + } + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + if (op->src[1]) { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + } else { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2); + } + if (op->src[2]) { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3); + } else { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3); + } + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_ssm_conv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne); + GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb); + GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne); + GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb); + GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne); + GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const ggml_tensor * src3 = op->src[3]; + const ggml_tensor * src4 = op->src[4]; + const ggml_tensor * src5 = op->src[5]; + const ggml_tensor * src6 = op->src[6]; + + GGML_ASSERT(src3); + GGML_ASSERT(src4); + GGML_ASSERT(src5); + GGML_ASSERT(src6); + + const int64_t d_state = ne00; + const int64_t d_inner = ne01; + const int64_t n_head = ne02; + const int64_t n_group = ne41; + const int64_t n_seq_tokens = ne12; + const int64_t n_seqs = ne13; + + ggml_metal_kargs_ssm_scan args = { + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_head =*/ n_head, + /*.n_group =*/ n_group, + /*.n_seq_tokens =*/ n_seq_tokens, + /*.n_seqs =*/ n_seqs, + /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float), + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ns12 =*/ nb12/nb10, + /*.nb13 =*/ nb13, + /*.nb20 =*/ nb20, + /*.nb21 =*/ nb21, + /*.ns21 =*/ nb21/nb20, + /*.nb22 =*/ nb22, + /*.ne30 =*/ ne30, + /*.nb31 =*/ nb31, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.ns42 =*/ nb42/nb40, + /*.nb43 =*/ nb43, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + /*.ns52 =*/ nb52/nb50, + /*.nb53 =*/ nb53, + /*.nb0 =*/ nb0, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + + GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const size_t sms = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); + + return 1; +} + +int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1]; + const int64_t T = op->src[0]->ne[2]; + const int64_t C = op->ne[0]; + const int64_t H = op->src[0]->ne[1]; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); + if (op->op == GGML_OP_RWKV_WKV7) { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++); + + ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1); + + return 1; +} + +int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); + + int64_t nk0 = ne00; + if (ggml_is_quantized(op->src[0]->type)) { + nk0 = ne00/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne00/ggml_blck_size(op->type); + } + + int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + + // TODO: relax this constraint in the future + if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) { + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nrptg--; + } + } + } + + nth = std::min(nth, nk0); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); + + return 1; +} + +int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t * opts = op->op_params; + ggml_op_pool op_pool = (ggml_op_pool) opts[0]; + + const int32_t k0 = opts[1]; + const int32_t k1 = opts[2]; + const int32_t s0 = opts[3]; + const int32_t s1 = opts[4]; + const int32_t p0 = opts[5]; + const int32_t p1 = opts[6]; + + const int64_t IH = op->src[0]->ne[1]; + const int64_t IW = op->src[0]->ne[0]; + + const int64_t N = op->ne[3]; + const int64_t OC = op->ne[2]; + const int64_t OH = op->ne[1]; + const int64_t OW = op->ne[0]; + + const int64_t np = N * OC * OH * OW; + + ggml_metal_kargs_pool_2d args_pool_2d = { + /* .k0 = */ k0, + /* .k1 = */ k1, + /* .s0 = */ s0, + /* .s1 = */ s1, + /* .p0 = */ p0, + /* .p1 = */ p1, + /* .IH = */ IH, + /* .IW = */ IW, + /* .OH = */ OH, + /* .OW = */ OW, + /* .np = */ np + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); + const int ntg = (np + nth - 1) / nth; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args_pool_2d, sizeof(args_pool_2d), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + GGML_ASSERT(ne00 == ne10); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const int16_t r2 = ne12/ne02; + const int16_t r3 = ne13/ne03; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + const int ne11_mm_min = 8; + + // first try to use small-batch mat-mv kernels + // these should be efficient for BS [2, ~8] + if (op->src[1]->type == GGML_TYPE_F32 && (ne00%128 == 0) && + ( + ( + ( + op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function + op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_Q4_0 || + op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_Q5_0 || + op->src[0]->type == GGML_TYPE_Q5_1 || + op->src[0]->type == GGML_TYPE_Q8_0 || + op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_IQ4_NL || + false) && (ne11 >= 2 && ne11 <= 8) + ) || + ( + ( + op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || + op->src[0]->type == GGML_TYPE_Q6_K || + false) && (ne11 >= 4 && ne11 <= 8) + ) + ) + ) { + // TODO: determine the optimal parameters based on grid utilization + // I still don't know why we should not always use the maximum available threads: + // + // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32 + // + // my current hypothesis is that the work grid is not evenly divisible for different nsg + // values and there can be some tail effects when nsg is high. need to confirm this + // + const int nsg = 2; // num simdgroups per threadgroup + + // num threads along row per simdgroup + int16_t nxpsg = 0; + if (ne00 % 256 == 0 && ne11 < 3) { + nxpsg = 16; + } else if (ne00 % 128 == 0) { + nxpsg = 8; + } else { + nxpsg = 4; + } + + const int16_t nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time) + const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup + int16_t r1ptg = 4; // num src1 rows per threadgroup + + // note: not sure how optimal are those across all different hardware. there might be someting cleverer + switch (ne11) { + case 2: + r1ptg = 2; break; + case 3: + case 6: + r1ptg = 3; break; + case 4: + case 7: + case 8: + r1ptg = 4; break; + case 5: + r1ptg = 5; break; + default: + GGML_ABORT("unsupported ne11"); + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + + ggml_metal_kargs_mul_mv_ext args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1); + } else if ( + !ggml_is_transposed(op->src[0]) && + !ggml_is_transposed(op->src[1]) && + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) { + //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + //switch (op->src[0]->type) { + // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + // default: break; + //} + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); + + ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); + } else { + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nr0 =*/ nr0, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q8_0) { + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); + } else { + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); + } + } + + return 1; +} + +size_t ggml_metal_op_mul_mat_id_extra_tpe(const ggml_tensor * op) { + assert(op->op == GGML_OP_MUL_MAT_ID); + + const int64_t ne02 = op->src[0]->ne[2]; // n_expert + + return ggml_type_size(GGML_TYPE_I32)*ne02; +} + +size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) { + assert(op->op == GGML_OP_MUL_MAT_ID); + + const int64_t ne02 = op->src[0]->ne[2]; // n_expert + const int64_t ne21 = op->src[2]->ne[1]; // n_token + + return ggml_type_size(GGML_TYPE_I32)*ne02*ne21; +} + +int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + // src2 = ids + GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(op->src[0])); + GGML_ASSERT(!ggml_is_transposed(op->src[1])); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + const uint32_t r2 = 1; + const uint32_t r3 = 1; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + // ne20 = n_used_experts + // ne21 = n_rows (batch size) + const int ne21_mm_id_min = 32; + + if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) { + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + //switch (op->src[0]->type) { + // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + // default: break; + //} + + // extra buffers for intermediate id mapping + ggml_metal_buffer_id bid_tpe = bid_dst; + bid_tpe.offs += ggml_nbytes(op); + + ggml_metal_buffer_id bid_ids = bid_tpe; + bid_ids.offs += ggml_metal_op_mul_mat_id_extra_tpe(op); + + { + ggml_metal_kargs_mul_mm_id_map0 args = { + ne02, + ne10, + ne11, // n_expert_used (bcast) + nb11, + nb12, + ne21, // n_tokens + ne20, // n_expert_used + nb21, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src2, 1); + ggml_metal_encoder_set_buffer (enc, bid_tpe, 2); + ggml_metal_encoder_set_buffer (enc, bid_ids, 3); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1); + } + + // this barrier is always needed because the next kernel has to wait for the id maps to be computed + ggml_metal_op_concurrency_reset(ctx); + + { + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); + + ggml_metal_kargs_mul_mm_id args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, // n_expert_used (bcast) + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne20 =*/ ne20, // n_expert_used + /*.ne21 =*/ ne21, // n_tokens + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_tpe, 3); + ggml_metal_encoder_set_buffer (enc, bid_ids, 4); + ggml_metal_encoder_set_buffer (enc, bid_dst, 5); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); + } + } else { + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_kargs_mul_mv_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb1 =*/ nb1, + /*.nr0 =*/ nr0, + }; + + if (ggml_is_quantized(op->src[0]->type)) { + GGML_ASSERT(ne00 >= nsg*nr0); + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, bid_src0, 1); + ggml_metal_encoder_set_buffer(enc, bid_src1, 2); + ggml_metal_encoder_set_buffer(enc, bid_dst, 3); + ggml_metal_encoder_set_buffer(enc, bid_src2, 4); + + const int64_t _ne1 = 1; + const int64_t ne123 = ne20*ne21; + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q8_0) { + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); + } else { + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); + } + } + + return 1; +} + +int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_kargs_add_id args = { + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb11 =*/ nb11, + /*.nb21 =*/ nb21, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1); + + return 1; +} + +bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + const int64_t ne00 = op->src[0]->ne[0]; // head size + const int64_t ne01 = op->src[0]->ne[1]; // batch size + + // use vec kernel if the batch size is small and if the head size is supported + return (ne01 < 20) && (ne00 % 32 == 0); +} + +size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; + + if (has_kvpad) { + res += OP_FLASH_ATTN_EXT_VEC_NCPSG*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } else { + const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0; + + if (has_kvpad) { + res += OP_FLASH_ATTN_EXT_NCPSG*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } + + return res; +} + +size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (!has_mask) { + return res; + } + + const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op); + + // this optimization is not useful for the vector kernels + if (is_vec) { + return res; + } + + const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG; + const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; + + const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; + const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg; + + res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32); + + return res; +} + +size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + const int64_t nwg = 32; + + // temp buffer for writing the results from each workgroup + // - ne20: the size of the Value head + // - + 2: the S and M values for each intermediate result + res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); + } + + return res; +} + +int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS( int32_t, nb, op, nb); + + GGML_ASSERT(ne00 % 4 == 0); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == op->src[2]->type); + + //GGML_ASSERT(ggml_are_same_shape (src1, src2)); + GGML_ASSERT(ne11 == ne21); + GGML_ASSERT(ne12 == ne22); + + GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16); + GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] && + "the Flash-Attention Metal kernel requires the mask to be at least n_queries big"); + + float scale; + float max_bias; + float logit_softcap; + + memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias)); + memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const bool has_mask = op->src[3] != NULL; + const bool has_sinks = op->src[4] != NULL; + const bool has_bias = max_bias != 0.0f; + const bool has_scap = logit_softcap != 0.0f; + + const uint32_t n_head = op->src[0]->ne[2]; + const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + GGML_ASSERT(ne01 < 65536); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); + ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0; + ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0; + + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_pad = bid_dst; + bid_pad.offs += ggml_nbytes(op); + + ggml_metal_buffer_id bid_blk = bid_pad; + bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + + ggml_metal_buffer_id bid_tmp = bid_blk; + bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op); + + if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { + // half8x8 kernel + const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + bool need_sync = false; + + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); + } + + if (has_mask) { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0); + + ggml_metal_kargs_flash_attn_ext_blk args0 = { + /*.ne01 =*/ ne01, + /*.ne30 =*/ ne30, + /*.ne31 =*/ ne31, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src3, 1); + ggml_metal_encoder_set_buffer (enc, bid_blk, 2); + + const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg); + const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg); + + ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0); + } + + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + + const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0; + + // 2*(2*ncpsg) + // ncpsg soft_max values + ncpsg mask values + // + // 16*32*(nsg) + // the shared memory needed for the simdgroups to load the KV cache + // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) + + //int64_t nsgmax = 4; + // + //if (is_q) { + // nsgmax = 2; + // while (true) { + // const size_t smem = FATTN_SMEM(nsgmax); + // if (smem > props_dev->max_theadgroup_memory_size) { + // break; + // } + // nsgmax *= 2; + // } + // nsgmax /= 2; + //} + + // simdgroups per threadgroup (a.k.a. warps) + //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + int32_t nsg = 4; + + const size_t smem = FATTN_SMEM(nsg); + + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ int32_t(nb11/nb10), + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ int32_t(nb21/nb20), + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); + ggml_metal_encoder_set_buffer (enc, bid_pad, 6); + ggml_metal_encoder_set_buffer (enc, bid_blk, 7); + ggml_metal_encoder_set_buffer (enc, bid_dst, 8); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1); +#undef FATTN_SMEM + } else { + // half4x4 kernel + const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! + const int nkpsg = 1*ncpsg; + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + bool need_sync = false; + + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); + } + + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + + // ne00 + 2*ncpsg*(nsg) + // for each query, we load it as f16 in shared memory (ne00) + // and store the soft_max values and the mask + // + // ne20*(nsg) + // each simdgroup has a full f32 head vector in shared mem to accumulate results + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes + if (smem > props_dev->max_theadgroup_memory_size/2) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32))); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + // workgroups + // each workgroup handles nsg*nkpsg cache values + int32_t nwg = 1; + if (false) { + // for small KV caches, we could launch a single workgroup and write the results directly to dst/ + // however, this does not lead to significant improvement, so disabled + nwg = 1; + nsg = 4; + } else { + nwg = 32; + nsg = 1; + while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) { + nsg *= 2; + } + } + + ggml_metal_kargs_flash_attn_ext_vec args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ int32_t(nb11/nb10), + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ int32_t(nb21/nb20), + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); + + GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax); + GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); + + if (nwg == 1) { + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0); + + // using 1 workgroup -> write the result directly into dst + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_dst, 7); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + } else { + // sanity checks + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); + + GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); + GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31)); + + // write the results from each workgroup into a temp buffer + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + + // sync the 2 kernels + ggml_metal_op_concurrency_reset(ctx); + + // reduce the results from the workgroups + { + const int32_t nrows = ne1*ne2*ne3; + + ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = { + nrows, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1); + } + } +#undef FATTN_SMEM + } + + return 1; +} + +int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const bool use_fusion = ctx->use_fusion; + + const int debug_fusion = ctx->debug_fusion; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); + + bool bcast_row = false; + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ 0, + /*.o1 =*/ { bid_src1.offs }, + }; + + ggml_op fops[8]; + + int n_fuse = 1; + + // c[0] = add(a, b[0]) + // c[1] = add(c[0], b[1]) + // c[2] = add(c[1], b[2]) + // ... + if (use_fusion) { + fops[0] = GGML_OP_ADD; + fops[1] = GGML_OP_ADD; + fops[2] = GGML_OP_ADD; + fops[3] = GGML_OP_ADD; + fops[4] = GGML_OP_ADD; + fops[5] = GGML_OP_ADD; + fops[6] = GGML_OP_ADD; + fops[7] = GGML_OP_ADD; + + // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops + // across splits. idx_end indicates the last node in the current split + for (n_fuse = 0; n_fuse <= 6; ++n_fuse) { + if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) { + break; + } + + ggml_tensor * f0 = ctx->node(idx + n_fuse); + ggml_tensor * f1 = ctx->node(idx + n_fuse + 1); + + if (f0 != f1->src[0]) { + break; + } + + // b[0] === b[1] === ... + if (!ggml_are_same_layout(f0->src[1], f1->src[1])) { + break; + } + + // only fuse ops if src1 is in the same Metal buffer + ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]); + if (bid_fuse.metal != bid_src1.metal) { + break; + } + + //ctx->fuse_cnt[ops[n_fuse + 1]->op]++; + + args.o1[n_fuse + 1] = bid_fuse.offs; + } + + ++n_fuse; + + if (debug_fusion > 1 && n_fuse > 1) { + GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse); + } + } + + // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer + bid_src1.offs = 0; + + ggml_metal_pipeline_t pipeline = nullptr; + + if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true); + + bcast_row = true; + } else { + pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false); + } + + if (n_fuse > 1) { + bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); + + for (int i = 1; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) { + ggml_metal_op_concurrency_reset(ctx); + + break; + } + } + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_dst, 3); + + if (bcast_row) { + const int64_t n = ggml_nelements(op)/4; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + } else { + int nth = 32; + + while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + } + + return n_fuse; +} + +int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float eps; + memcpy(&eps, op->op_params, sizeof(float)); + + int nth = 32; // SIMD width + + ggml_metal_kargs_l2_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); + + while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00/4); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + const int64_t nrows = ggml_nrows(op->src[0]); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t ngrp = ((const int32_t *) op->op_params)[0]; + + float eps; + memcpy(&eps, op->op_params + 1, sizeof(float)); + + ggml_metal_kargs_group_norm args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ngrp =*/ ngrp, + /*.eps =*/ eps, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); + + int nth = 32; // SIMD width + //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + // nth *= 2; + //} + + //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + //nth = std::min(nth, ne00/4); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const bool use_fusion = ctx->use_fusion; + + const int debug_fusion = ctx->debug_fusion; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float eps; + memcpy(&eps, op->op_params, sizeof(float)); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_kargs_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, + /*.nef1 =*/ { ne01 }, + /*.nef2 =*/ { ne02 }, + /*.nef3 =*/ { ne03 }, + /*.nbf1 =*/ { nb01 }, + /*.nbf2 =*/ { nb02 }, + /*.nbf3 =*/ { nb03 }, + }; + + ggml_op fops[8]; + + int n_fuse = 1; + + ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 }; + + // d[0] = norm(a) + // d[1] = mul(d[0], b) + // d[2] = add(d[1], c) + if (use_fusion) { + fops[0] = op->op; + fops[1] = GGML_OP_MUL; + fops[2] = GGML_OP_ADD; + + for (n_fuse = 0; n_fuse <= 1; ++n_fuse) { + if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) { + break; + } + + ggml_tensor * f0 = ctx->node(idx + n_fuse); + ggml_tensor * f1 = ctx->node(idx + n_fuse + 1); + + if (f0 != f1->src[0]) { + break; + } + + if (f1->src[1]->ne[0] != op->ne[0]) { + break; + } + + if (!ggml_is_contiguous_rows(f1->src[1])) { + break; + } + + if (f1->type != GGML_TYPE_F32) { + break; + } + + //ctx->fuse_cnt[f1->op]++; + + bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]); + + args.nef1[n_fuse + 1] = f1->src[1]->ne[1]; + args.nef2[n_fuse + 1] = f1->src[1]->ne[2]; + args.nef3[n_fuse + 1] = f1->src[1]->ne[3]; + + args.nbf1[n_fuse + 1] = f1->src[1]->nb[1]; + args.nbf2[n_fuse + 1] = f1->src[1]->nb[2]; + args.nbf3[n_fuse + 1] = f1->src[1]->nb[3]; + } + + ++n_fuse; + + if (debug_fusion > 1 && n_fuse > 1) { + if (n_fuse == 2) { + GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op)); + } + if (n_fuse == 3) { + GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op)); + } + } + } + + if (n_fuse > 1) { + bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); + + for (int i = 1; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) { + ggml_metal_op_concurrency_reset(ctx); + + break; + } + } + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); + + int nth = 32; // SIMD width + + while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, args.ne00_t); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2); + ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3); + ggml_metal_encoder_set_buffer (enc, bid_dst, 4); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return n_fuse; +} + +int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + // make sure we have one or more position id(ne10) per token(ne02) + GGML_ASSERT(ne10 % ne02 == 0); + GGML_ASSERT(ne10 >= ne02); + + const int nth = std::min(1024, ne00); + + const int n_past = ((const int32_t *) op->op_params)[0]; + const int n_dims = ((const int32_t *) op->op_params)[1]; + //const int mode = ((const int32_t *) op->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_ctx_orig = ((const int32_t *) op->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (const int32_t *) op->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const int32_t *) op->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const int32_t *) op->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const int32_t *) op->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const int32_t *) op->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const int32_t *) op->op_params + 10, sizeof(float)); + + // mrope + const int sect_0 = ((const int32_t *) op->op_params)[11]; + const int sect_1 = ((const int32_t *) op->op_params)[12]; + const int sect_2 = ((const int32_t *) op->op_params)[13]; + const int sect_3 = ((const int32_t *) op->op_params)[14]; + + ggml_metal_kargs_rope args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.n_past =*/ n_past, + /*.n_dims =*/ n_dims, + /*.n_ctx_orig =*/ n_ctx_orig, + /*.freq_base =*/ freq_base, + /*.freq_scale =*/ freq_scale, + /*.ext_factor =*/ ext_factor, + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, + /* sect_0 =*/ sect_0, + /* sect_1 =*/ sect_1, + /* sect_2 =*/ sect_2, + /* sect_3 =*/ sect_3, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + if (op->src[2]) { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + } else { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 3); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + const int32_t s1 = ((const int32_t *)(op->op_params))[1]; + const int32_t p0 = ((const int32_t *)(op->op_params))[2]; + const int32_t p1 = ((const int32_t *)(op->op_params))[3]; + const int32_t d0 = ((const int32_t *)(op->op_params))[4]; + const int32_t d1 = ((const int32_t *)(op->op_params))[5]; + + const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1; + + const int32_t N = op->src[1]->ne[is_2D ? 3 : 2]; + const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? op->src[1]->ne[1] : 1; + const int32_t IW = op->src[1]->ne[0]; + + const int32_t KH = is_2D ? op->src[0]->ne[1] : 1; + const int32_t KW = op->src[0]->ne[0]; + + const int32_t OH = is_2D ? op->ne[2] : 1; + const int32_t OW = op->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4; + const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4; + + ggml_metal_kargs_im2col args = { + /*.ofs0 =*/ ofs0, + /*.ofs1 =*/ ofs1, + /*.IW =*/ IW, + /*.IH =*/ IH, + /*.CHW =*/ CHW, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.p0 =*/ p0, + /*.p1 =*/ p1, + /*.d0 =*/ d0, + /*.d1 =*/ d1, + /*.N =*/ N, + /*.KH =*/ KH, + /*.KW =*/ KW, + /*.KHW =*/ KH * KW, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); + + GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + + return 1; +} + +int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + + const int32_t IC = op->src[1]->ne[1]; + const int32_t IL = op->src[1]->ne[0]; + + const int32_t K = op->src[0]->ne[0]; + + const int32_t OL = op->ne[0]; + const int32_t OC = op->ne[1]; + + ggml_metal_kargs_conv_transpose_1d args = { + /*.IC =*/ IC, + /*.IL =*/ IL, + /*.K =*/ K, + /*.s0 =*/ s0, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const float sf0 = (float)ne0/op->src[0]->ne[0]; + const float sf1 = (float)ne1/op->src[0]->ne[1]; + const float sf2 = (float)ne2/op->src[0]->ne[2]; + const float sf3 = (float)ne3/op->src[0]->ne[3]; + + ggml_metal_kargs_upscale args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.sf0 =*/ sf0, + /*.sf1 =*/ sf1, + /*.sf2 =*/ sf2, + /*.sf3 =*/ sf3 + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_pad args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3 + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_pad_reflect_1d args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.p0 =*/ ((const int32_t *)(op->op_params))[0], + /*.p1 =*/ ((const int32_t *)(op->op_params))[1] + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float start; + float step; + + memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float)); + memcpy(&step, ((const int32_t *) op->op_params) + 2, sizeof(float)); + + ggml_metal_kargs_arange args = { + /*.ne0 =*/ ne0, + /*.start =*/ start, + /*.step =*/ step + }; + + const int nth = std::min(1024, ne0); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op); + + //[encoder setComputePipelineState:pipeline]; + //[encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + //[encoder setBytes:&args length:sizeof(args) atIndex:1]; + + //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int dim = op->op_params[0]; + const int max_period = op->op_params[1]; + + ggml_metal_kargs_timestep_embedding args = { + /*.nb1 =*/ nb1, + /*.dim =*/ dim, + /*.max_period =*/ max_period, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); + + const int nth = std::max(1, std::min(1024, dim/2)); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_argmax args = { + /*.ne00 = */ ne00, + /*.nb01 = */ nb01, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); + + const int64_t nrows = ggml_nrows(op->src[0]); + + int nth = 32; // SIMD width + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); + + const int64_t nrows = ggml_nrows(op->src[0]); + + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16); + + ggml_metal_kargs_argsort args = { + /*.ncols =*/ ne00, + /*.ncols_pad =*/ ne00_padded + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1); + + return 1; +} + +int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float slope; + memcpy(&slope, op->op_params, sizeof(float)); + + ggml_metal_kargs_leaky_relu args = { + /*.slope =*/ slope + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + int64_t n = ggml_nelements(op); + + if (n % 4 == 0) { + n /= 4; + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); + + const int64_t np = ggml_nelements(op->src[0]); + ggml_metal_kargs_opt_step_adamw args = { + /*.np =*/ np, + }; + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int64_t n = (np + nth - 1) / nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); + + const int64_t np = ggml_nelements(op->src[0]); + ggml_metal_kargs_opt_step_sgd args = { + /*.np =*/ np, + }; + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int64_t n = (np + nth - 1) / nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); + + return 1; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h new file mode 100644 index 00000000..f3527386 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h @@ -0,0 +1,87 @@ +#pragma once + +#include "ggml-metal-device.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct ggml_metal_op * ggml_metal_op_t; + +ggml_metal_op_t ggml_metal_op_init( + ggml_metal_device_t dev, + ggml_metal_cmd_buf_t cmd_buf, + struct ggml_cgraph * gf, + int idx_start, + int idx_end, + bool use_fusion, + bool use_concurrency, + bool use_capture, + int debug_graph, + int debug_fusion); + +void ggml_metal_op_free(ggml_metal_op_t ctx); + +int ggml_metal_op_n_nodes(ggml_metal_op_t ctx); + +int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx); + +// +// available ops: +// + +// tokens per expert +size_t ggml_metal_op_mul_mat_id_extra_tpe(const struct ggml_tensor * op); + +// id map [n_tokens, n_expert] +size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op); + +// return true if we should use the FA vector kernel for this op +bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op); + +size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op); +size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op); +size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); + +int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); +int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); + +#ifdef __cplusplus +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp new file mode 100644 index 00000000..f356e4a0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp @@ -0,0 +1,723 @@ +#include "ggml-metal.h" + +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-metal-device.h" +#include "ggml-metal-context.h" +#include "ggml-metal-ops.h" + +// globals + +// initialized in ggml_backend_metal_reg +static ggml_backend_reg g_ggml_metal_reg; +static ggml_backend_device g_ggml_metal_device; + +//////////////////////////////////////////////////////////////////////////////// +// backend interface +//////////////////////////////////////////////////////////////////////////////// + +// shared buffer + +static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_free(ctx); + delete buffer; +} + +static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + return ggml_metal_buffer_get_base(ctx); +} + +static void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size); +} + +static void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size); +} + +static void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size); +} + +static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + GGML_UNUSED(buffer); + GGML_UNUSED(src); + GGML_UNUSED(dst); + + return false; +} + +static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_clear(ctx, value); +} + +static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_shared_clear, + /* .reset = */ NULL, +}; + +// private buffer + +static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_free(ctx); + delete buffer; +} + +static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + return ggml_metal_buffer_get_base(ctx); +} + +static void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size); +} + +static void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size); +} + +static void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size); +} + +static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + GGML_UNUSED(buffer); + GGML_UNUSED(src); + GGML_UNUSED(dst); + + return false; +} + +static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_clear(ctx, value); +} + +static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_private_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, + /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_private_clear, + /* .reset = */ NULL, +}; + +// +// buffer types +// + +// common method for allocating shread or private Metal buffers +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; + ggml_metal_buffer_t res = ggml_metal_buffer_init(ctx_dev, size, shared); + + ggml_backend_buffer_i buf_i = ggml_metal_buffer_is_shared(res) + ? ggml_backend_metal_buffer_shared_i + : ggml_backend_metal_buffer_private_i; + + return ggml_backend_buffer_init(buft, buf_i, res, size); +} + +static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + size_t res = ggml_nbytes(tensor); + + // some operations require additional memory for fleeting data: + switch (tensor->op) { + case GGML_OP_MUL_MAT_ID: + { + res += ggml_metal_op_mul_mat_id_extra_tpe(tensor); + res += ggml_metal_op_mul_mat_id_extra_ids(tensor); + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + res += ggml_metal_op_flash_attn_ext_extra_pad(tensor); + res += ggml_metal_op_flash_attn_ext_extra_blk(tensor); + res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); + } break; + default: + break; + } + + return res; + + GGML_UNUSED(buft); +} + +// default (shared) buffer type + +static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { + return "Metal"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true); +} + +static size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; + + return ggml_metal_device_get_props(ctx_dev)->max_buffer_size; +} + +static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); +} + +static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) { + static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, + }, + /* .device = */ &g_ggml_metal_device, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_type_metal; +} + +// default (private) buffer type + +static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) { + return "Metal_Private"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false); +} + +static size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; + + return ggml_metal_device_get_props(ctx_dev)->max_buffer_size; +} + +static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); +} + +static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) { + static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, + }, + /* .device = */ &g_ggml_metal_device, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_type_metal; +} + +// mapped buffer type + +static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) { + return "Metal_Mapped"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + // for mapped buffers, prefer shared memory + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true); +} + +static size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; + + return ggml_metal_device_get_props(ctx_dev)->max_buffer_size; +} + +static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); +} + +static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) { + // note: not obvious, but this buffer type still needs to implement .alloc_buffer: + // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 + static ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, + }, + /* .device = */ &g_ggml_metal_device, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_type_mapped_metal; +} + +// backend + +static const char * ggml_backend_metal_name(ggml_backend_t backend) { + return "Metal"; + + GGML_UNUSED(backend); +} + +static void ggml_backend_metal_free(ggml_backend_t backend) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + // wait for any ongoing async operations to finish + ggml_metal_synchronize(ctx); + + ggml_metal_free(ctx); + + free(backend); +} + +static void ggml_backend_metal_synchronize(ggml_backend_t backend) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_synchronize(ctx); +} + +static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_set_tensor_async(ctx, tensor, data, offset, size); +} + +static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_get_tensor_async(ctx, tensor, data, offset, size); +} + +static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { + return false; + + GGML_UNUSED(backend_src); + GGML_UNUSED(backend_dst); + GGML_UNUSED(src); + GGML_UNUSED(dst); +} + +static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + return ggml_metal_graph_compute(ctx, cgraph); +} + +static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_graph_optimize(ctx, cgraph); +} + +static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_set_n_cb(ctx, n_cb); + +} + +static ggml_backend_i ggml_backend_metal_i = { + /* .get_name = */ ggml_backend_metal_name, + /* .free = */ ggml_backend_metal_free, + /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, + /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups + /* .synchronize = */ ggml_backend_metal_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_metal_graph_compute, + + // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal + // in any case, these docs seem relevant if we ever decide to implement it: + // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ ggml_backend_metal_graph_optimize, +}; + +static ggml_guid_t ggml_backend_metal_guid(void) { + static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 }; + return &guid; +} + +ggml_backend_t ggml_backend_metal_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0); + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_t ctx = ggml_metal_init(ctx_dev); + if (ctx == NULL) { + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + return NULL; + } + + ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend)); + + *backend = { + /* .guid = */ ggml_backend_metal_guid(), + /* .interface = */ ggml_backend_metal_i, + /* .device = */ dev, + /* .context = */ ctx, + }; + + ggml_backend_metal_set_n_cb(backend, 1); + + return backend; +} + +bool ggml_backend_is_metal(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid()); +} + +void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_set_abort_callback(ctx, abort_callback, user_data); +} + +bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + return ggml_metal_supports_family(ctx, family); +} + +void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_capture_next_compute(ctx); +} + +// backend device + +static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { + return "Metal"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + return ggml_metal_device_get_props(ctx_dev)->name; +} + +static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_device_get_memory(ctx_dev, free, total); +} + +static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +#define GGML_METAL_NAME "Metal" +static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_metal_device_get_name(dev); + props->description = ggml_backend_metal_device_get_description(dev); + props->id = "0"; + props->type = ggml_backend_metal_device_get_type(dev); + + ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); + + props->library = GGML_METAL_NAME; + props->caps = { + /* .async = */ true, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_t ctx = ggml_metal_init(ctx_dev); + if (ctx == NULL) { + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + return NULL; + } + + ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend)); + + *backend = { + /* .guid = */ ggml_backend_metal_guid(), + /* .interface = */ ggml_backend_metal_i, + /* .device = */ dev, + /* .context = */ ctx, + }; + + ggml_backend_metal_set_n_cb(backend, 1); + + return backend; + + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); + + return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private(); +} + +static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size); + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, res, size); +} + +static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + return ggml_metal_device_supports_op(ctx_dev, op); +} + +static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return + buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name; + + GGML_UNUSED(dev); +} + +static int64_t get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + +static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + const int min_batch_size = 32; + + return (op->op == GGML_OP_MUL_MAT || + op->op == GGML_OP_MUL_MAT_ID) && + get_op_batch_size(op) >= min_batch_size; + + GGML_UNUSED(dev); + GGML_UNUSED(op); +} + +static ggml_backend_device_i ggml_backend_metal_device_i = { + /* .get_name = */ ggml_backend_metal_device_get_name, + /* .get_description = */ ggml_backend_metal_device_get_description, + /* .get_memory = */ ggml_backend_metal_device_get_memory, + /* .get_type = */ ggml_backend_metal_device_get_type, + /* .get_props = */ ggml_backend_metal_device_get_props, + /* .init_backend = */ ggml_backend_metal_device_init, + /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped, + /* .supports_op = */ ggml_backend_metal_device_supports_op, + /* .supports_buft = */ ggml_backend_metal_device_supports_buft, + /* .offload_op = */ ggml_backend_metal_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend registry + +static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { + return "Metal"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { + return 1; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + return &g_ggml_metal_device; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static ggml_backend_feature g_ggml_backend_metal_features[] = { +#if defined(GGML_METAL_EMBED_LIBRARY) + { "EMBED_LIBRARY", "1" }, +#endif + { NULL, NULL }, +}; + +static ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) { + return g_ggml_backend_metal_features; + + GGML_UNUSED(reg); +} + +static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (strcmp(name, "ggml_backend_get_features") == 0) { + return (void *)ggml_backend_metal_get_features; + } + + return NULL; + + GGML_UNUSED(reg); +} + +static ggml_backend_reg_i ggml_backend_metal_reg_i = { + /* .get_name = */ ggml_backend_metal_reg_get_name, + /* .device_count = */ ggml_backend_metal_reg_device_count, + /* .device_get = */ ggml_backend_metal_reg_device_get, + /* .get_proc_address = */ ggml_backend_metal_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_metal_reg(void) { + { + g_ggml_metal_reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_metal_reg_i, + /* .context = */ NULL, + }; + + g_ggml_metal_device = { + /* .iface = */ ggml_backend_metal_device_i, + /* .reg = */ &g_ggml_metal_reg, + /* .context = */ ggml_metal_device_get(), + }; + } + + return &g_ggml_metal_reg; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg) diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m deleted file mode 100644 index e4c31268..00000000 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m +++ /dev/null @@ -1,6781 +0,0 @@ -#import "ggml-metal.h" - -#import "ggml-impl.h" -#import "ggml-backend-impl.h" -#import "ggml-metal-impl.h" - -#import - -#import - -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -// max memory buffers that can be mapped to the device -#define GGML_METAL_MAX_BUFFERS 64 - -// max number of MTLCommandBuffer used to submit a graph for processing -#define GGML_METAL_MAX_COMMAND_BUFFERS 8 - -#ifndef TARGET_OS_VISION -#define TARGET_OS_VISION 0 -#endif - -// create residency sets only on macOS >= 15.0 -#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ - TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \ - TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \ - TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000 -#define GGML_METAL_HAS_RESIDENCY_SETS 1 -#endif - -// globals - -// overload of MTLGPUFamilyMetal3 (not available in some environments) -static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; - -// initialized in ggml_backend_metal_reg -static struct ggml_backend_reg g_ggml_backend_metal_reg; -static struct ggml_backend_device g_ggml_backend_metal_device; - -// information about a Metal device -// note: assumes single GPU device - the default one -// TODO: support multiple GPU devices -static struct ggml_backend_metal_device_context { - id mtl_device; - int mtl_device_ref_count; - id mtl_library; - - NSLock * mtl_lock; - - bool has_simdgroup_reduction; - bool has_simdgroup_mm; - bool has_residency_sets; - bool has_bfloat; - bool use_bfloat; - bool use_fusion; - - int debug_fusion; - - // how many times a given op was fused - uint64_t fuse_cnt[GGML_OP_COUNT]; - - size_t max_size; - - char name[128]; -} g_ggml_ctx_dev_main = { - /*.mtl_device =*/ nil, - /*.mtl_device_ref_count =*/ 0, - /*.mtl_library =*/ nil, - /*.mtl_lock =*/ nil, - /*.has_simdgroup_reduction =*/ false, - /*.has_simdgroup_mm =*/ false, - /*.has_residency_sets =*/ false, - /*.has_bfloat =*/ false, - /*.use_bfloat =*/ false, - /*.use_fusion =*/ true, - /*.debug_fusion =*/ 0, - /*.fuse_cnt =*/ { 0 }, - /*.max_size =*/ 0, - /*.name =*/ "", -}; - -// acquire -static id ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) { - assert(ctx != NULL); - - if (ctx->mtl_lock == nil) { - ctx->mtl_lock = [[NSLock alloc] init]; - } - - if (ctx->mtl_device == nil) { - ctx->mtl_device = MTLCreateSystemDefaultDevice(); - - ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; - ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; - - ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; - -#if defined(GGML_METAL_HAS_RESIDENCY_SETS) - ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil; -#endif - - ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; - ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6]; - -#if defined(GGML_METAL_USE_BF16) - if (@available(macOS 14.0, *)) { - ctx->use_bfloat = ctx->has_bfloat; - } else { - ctx->use_bfloat = false; - } -#else - ctx->use_bfloat = false; -#endif - ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; - - { - const char * val = getenv("GGML_METAL_FUSION_DEBUG"); - ctx->debug_fusion = val ? atoi(val) : 0; - } - - memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt)); - - ctx->max_size = ctx->mtl_device.maxBufferLength; - - strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1); - } - - ctx->mtl_device_ref_count++; - - return ctx->mtl_device; -} - -// release -static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) { - assert(ctx != NULL); - assert(ctx->mtl_device_ref_count > 0); - - ctx->mtl_device_ref_count--; - - if (ctx->mtl_device_ref_count == 0) { - if (ctx->debug_fusion > 0) { - fprintf(stderr, "%s: fusion stats:\n", __func__); - for (int i = 0; i < GGML_OP_COUNT; i++) { - if (ctx->fuse_cnt[i] == 0) { - continue; - } - - // note: cannot use ggml_log here - fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]); - } - } - - if (ctx->mtl_lock) { - [ctx->mtl_lock release]; - ctx->mtl_lock = nil; - } - - if (ctx->mtl_library) { - [ctx->mtl_library release]; - ctx->mtl_library = nil; - } - - if (ctx->mtl_device) { - [ctx->mtl_device release]; - ctx->mtl_device = nil; - } - } -} - -// kernels - -struct ggml_metal_kernel { - id pipeline; -}; - -enum ggml_metal_kernel_type { - GGML_METAL_KERNEL_TYPE_ADD, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, - GGML_METAL_KERNEL_TYPE_SUB, - GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, - GGML_METAL_KERNEL_TYPE_MUL, - GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, - GGML_METAL_KERNEL_TYPE_DIV, - GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, - GGML_METAL_KERNEL_TYPE_ADD_ID, - GGML_METAL_KERNEL_TYPE_REPEAT_F32, - GGML_METAL_KERNEL_TYPE_REPEAT_F16, - GGML_METAL_KERNEL_TYPE_REPEAT_I32, - GGML_METAL_KERNEL_TYPE_REPEAT_I16, - GGML_METAL_KERNEL_TYPE_SCALE, - GGML_METAL_KERNEL_TYPE_SCALE_4, - GGML_METAL_KERNEL_TYPE_CLAMP, - GGML_METAL_KERNEL_TYPE_TANH, - GGML_METAL_KERNEL_TYPE_RELU, - GGML_METAL_KERNEL_TYPE_SIGMOID, - GGML_METAL_KERNEL_TYPE_GELU, - GGML_METAL_KERNEL_TYPE_GELU_4, - GGML_METAL_KERNEL_TYPE_GELU_ERF, - GGML_METAL_KERNEL_TYPE_GELU_ERF_4, - GGML_METAL_KERNEL_TYPE_GELU_QUICK, - GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, - GGML_METAL_KERNEL_TYPE_SILU, - GGML_METAL_KERNEL_TYPE_SILU_4, - GGML_METAL_KERNEL_TYPE_ELU, - GGML_METAL_KERNEL_TYPE_ABS, - GGML_METAL_KERNEL_TYPE_SGN, - GGML_METAL_KERNEL_TYPE_STEP, - GGML_METAL_KERNEL_TYPE_HARDSWISH, - GGML_METAL_KERNEL_TYPE_HARDSIGMOID, - GGML_METAL_KERNEL_TYPE_EXP, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, - GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, - GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, - GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, - GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, - GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, - GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, - GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, - GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, - GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, - GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, - GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, - GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, - GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, - GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, - GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, - GGML_METAL_KERNEL_TYPE_RMS_NORM, - GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, - GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, - GGML_METAL_KERNEL_TYPE_L2_NORM, - GGML_METAL_KERNEL_TYPE_GROUP_NORM, - GGML_METAL_KERNEL_TYPE_NORM, - GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, - GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, - GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, - GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, - GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, - GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, - GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, - GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, - GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, - //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, - //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, - //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, - GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, - GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, - GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, - GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, - GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, - GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, - GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, - GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, - GGML_METAL_KERNEL_TYPE_IM2COL_F16, - GGML_METAL_KERNEL_TYPE_IM2COL_F32, - GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, - GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, - GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, - GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, - GGML_METAL_KERNEL_TYPE_UPSCALE_F32, - GGML_METAL_KERNEL_TYPE_PAD_F32, - GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, - GGML_METAL_KERNEL_TYPE_ARANGE_F32, - GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, - GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, - GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, - GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_SET_I32, - GGML_METAL_KERNEL_TYPE_SET_F32, - GGML_METAL_KERNEL_TYPE_CPY_F32_F32, - GGML_METAL_KERNEL_TYPE_CPY_F32_F16, - GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, - GGML_METAL_KERNEL_TYPE_CPY_F16_F16, - GGML_METAL_KERNEL_TYPE_CPY_F16_F32, - GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, - GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, - GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, - GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, - GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, - GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, - GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, - GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, - GGML_METAL_KERNEL_TYPE_CONCAT, - GGML_METAL_KERNEL_TYPE_SQR, - GGML_METAL_KERNEL_TYPE_SQRT, - GGML_METAL_KERNEL_TYPE_SIN, - GGML_METAL_KERNEL_TYPE_COS, - GGML_METAL_KERNEL_TYPE_NEG, - GGML_METAL_KERNEL_TYPE_REGLU, - GGML_METAL_KERNEL_TYPE_GEGLU, - GGML_METAL_KERNEL_TYPE_SWIGLU, - GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, - GGML_METAL_KERNEL_TYPE_GEGLU_ERF, - GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, - GGML_METAL_KERNEL_TYPE_SUM_ROWS, - GGML_METAL_KERNEL_TYPE_MEAN, - GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, - GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, - GGML_METAL_KERNEL_TYPE_ARGMAX, - - GGML_METAL_KERNEL_TYPE_COUNT -}; - -// -// ggml_metal_heap -// - -struct ggml_metal_heap { - // number of times the heap was unused - int n_unused; - - // total number of buffer allocations in this heap across all computes - int64_t n_alloc; - - // current offset in the heap - we reset this after each node in order to reuse the memory - size_t offs; - - // the currently allocated MTLBuffer objects in this heap - id obj; - - NSMutableArray * bufs; -}; - -static struct ggml_metal_heap * ggml_metal_heap_init(id device, size_t size) { - struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap)); - - MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init]; - desc.storageMode = MTLStorageModePrivate; - desc.cpuCacheMode = MTLCPUCacheModeDefaultCache; - desc.type = MTLHeapTypePlacement; - desc.size = size; - - heap->n_unused = 0; - heap->n_alloc = 0; - - heap->obj = [device newHeapWithDescriptor:desc]; - if (!heap->obj) { - GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size); - - free(heap); - - return false; - } - - [desc release]; - - heap->bufs = [[NSMutableArray alloc] init]; - - return heap; -} - -static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) { - heap->offs = 0; - - // count how many graph computes the heap ended up being unused - if ([heap->bufs count] > 0) { - heap->n_unused = 0; - } else { - heap->n_unused++; - } - - for (id buf in heap->bufs) { - [buf release]; - } - [heap->bufs removeAllObjects]; - - // tell the OS that it can reuse this memory if needed - // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc - [heap->obj setPurgeableState:MTLPurgeableStateVolatile]; -} - -static void ggml_metal_heap_free(struct ggml_metal_heap * heap) { - if (heap == nil) { - return; - } - - ggml_metal_heap_reset(heap); - - [heap->obj release]; - [heap->bufs release]; - - free(heap); -} - -@interface ggml_metal_heap_ptr : NSObject - -@property (nonatomic, assign) struct ggml_metal_heap * data; - -@end - -@implementation ggml_metal_heap_ptr -@end - -// -// ggml_metal_mem_pool -// - -struct ggml_metal_mem_pool { - id device; - - int n_heaps; // total number of heaps ever created (including those that were removed) - - NSMutableArray * heaps; - NSMutableArray * heaps_to_remove; -}; - -static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) { - struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool)); - - mem_pool->n_heaps = 0; - - mem_pool->heaps = [[NSMutableArray alloc] init]; - mem_pool->heaps_to_remove = [[NSMutableArray alloc] init]; - - return mem_pool; -} - -static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) { - GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps); - - size_t size_all = 0; - size_t size_cur = 0; - - for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { - GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data); - GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc); - GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused); - GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0); - GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]); - - if ([ptr.data->bufs count] > 0) { - size_cur += [ptr.data->obj size]; - } - size_all += [ptr.data->obj size]; - - ggml_metal_heap_free(ptr.data); - [ptr release]; - } - [mem_pool->heaps release]; - [mem_pool->heaps_to_remove release]; - - if (size_all > 0) { - GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0); - GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0); - } - - free(mem_pool); -} - -static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) { - for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) { - ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i]; - - struct ggml_metal_heap * heap = ptr.data; - ggml_metal_heap_reset(heap); - - // if the heap hasn't been used for a while, remove it - if (heap->n_unused >= 128) { - [mem_pool->heaps_to_remove addObject:@(i)]; - } - } - - if (mem_pool->heaps_to_remove.count > 0) { - // remove in reverse order - for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) { - NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue]; - ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index]; - - struct ggml_metal_heap * heap = ptr.data; - ggml_metal_heap_free(heap); - - [mem_pool->heaps removeObjectAtIndex:index]; - [ptr release]; - - if (i == 0) { - break; - } - } - - [mem_pool->heaps_to_remove removeAllObjects]; - } -} - -static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) { - for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { - ptr.data->offs = 0; - } -} - -static id ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) { - const size_t alignment = 256; - - const size_t size_aligned = GGML_PAD(size, alignment); - - // try one of the existing heaps - for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { - struct ggml_metal_heap * heap = ptr.data; - if (heap->offs + size_aligned <= [heap->obj size]) { - // if this is the first buffer in the heap for the current command buffer, tell the OS that - // it cannot free the memory used by the heap - // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc - if ([heap->bufs count] == 0) { - [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; - } - - id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; - if (buf == nil) { - GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); - return nil; - } - - heap->n_alloc++; - heap->offs += size_aligned; - - [heap->bufs addObject:buf]; - - return buf; - } - } - - // create a new heap that can fit this buffer - ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new]; - - struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned); - if (heap == NULL) { - GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned); - return NULL; - } - - //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]); - - heap_ptr.data = heap; - ggml_metal_heap_reset(heap); - - [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; - id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; - if (buf == nil) { - GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); - return NULL; - } - - heap->n_alloc++; - heap->offs += size_aligned; - - [heap->bufs addObject:buf]; - - [mem_pool->heaps addObject:heap_ptr]; - mem_pool->n_heaps++; - - return buf; -} - -struct ggml_metal_command_buffer { - id obj; - - // each command buffer has a memory pool from which it can allocate temporary buffers during the compute - struct ggml_metal_mem_pool * mem_pool; -}; - -struct ggml_backend_metal_context { - id device; - id queue; - - dispatch_queue_t d_queue; - - struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; - - // capture state - bool capture_next_compute; - bool capture_started; - - id capture_scope; - - // command buffer state - int n_cb; // number of extra threads used to submit the command buffers - int n_nodes_0; // number of nodes submitted by the main thread - int n_nodes_1; // remaining number of nodes submitted by the n_cb threads - int n_nodes_per_cb; - - struct ggml_cgraph * gf; - - // the callback given to the thread pool - void (^encode_async)(size_t ith); - - // n_cb command buffers + 1 used by the main thread - struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; - - // abort ggml_metal_graph_compute if callback returns true - ggml_abort_callback abort_callback; - void * abort_callback_data; -}; - -// MSL code -// TODO: move the contents here when ready -// for now it is easier to work in a separate file -// static NSString * const msl_library_source = @"see metal.metal"; - -#if !GGML_METAL_EMBED_LIBRARY -// Here to assist with NSBundle Path Hack -@interface GGMLMetalClass : NSObject -@end -@implementation GGMLMetalClass -@end -#endif - -static void * ggml_metal_host_malloc(size_t n) { - void * data = NULL; - -#if TARGET_OS_OSX - kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE); - if (err != KERN_SUCCESS) { - GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); - return NULL; - } -#else - const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); - if (result != 0) { - GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); - return NULL; - } -#endif - - return data; -} - -// load library -// -// - first check if the library is embedded -// - then check if the library is in the bundle -// - if not found, load the source and compile it -// - if that fails, return NULL -static id ggml_metal_load_library(id device, bool use_bfloat) { - id metal_library = nil; - NSError * error = nil; - NSString * src = nil; - -#if GGML_METAL_EMBED_LIBRARY - GGML_LOG_INFO("%s: using embedded metal library\n", __func__); - - extern const char ggml_metallib_start[]; - extern const char ggml_metallib_end[]; - - src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding]; - -#else - -#ifdef SWIFT_PACKAGE - NSBundle * bundle = SWIFTPM_MODULE_BUNDLE; -#else - NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; -#endif - - NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"]; - if (path_lib == nil) { - // Try to find the resource in the directory where the current binary located. - NSString * current_binary = [[NSProcessInfo processInfo] arguments][0]; - NSString * bin_dir = [current_binary stringByDeletingLastPathComponent]; - NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]]; - if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) { - GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]); - NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error]; - if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) { - // Optionally, if this is a symlink, try to resolve it. - default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error]; - if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) { - // It is a relative path, adding the binary directory as directory prefix. - default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]]; - } - if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) { - // Link to the resource could not be resolved. - default_metallib_path = nil; - } else { - GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]); - } - } - } else { - // The resource couldn't be found in the binary's directory. - default_metallib_path = nil; - } - path_lib = default_metallib_path; - } - - if (path_lib != nil) { - // pre-compiled library found - NSURL * libURL = [NSURL fileURLWithPath:path_lib]; - GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]); - - metal_library = [device newLibraryWithURL:libURL error:&error]; - if (error) { - GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - } else { - GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); - - NSString * path_source; - NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; - - GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil"); - - if (path_resource) { - path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"]; - } else { - path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; - } - - if (path_source == nil) { - GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); - path_source = @"ggml-metal.metal"; - } - - GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]); - - src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error]; - if (error) { - GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - } -#endif - - if (!metal_library) { - @autoreleasepool { - // dictionary of preprocessor macros - NSMutableDictionary * prep = [NSMutableDictionary dictionary]; - - if (use_bfloat) { - [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"]; - } - -#if GGML_METAL_EMBED_LIBRARY - [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"]; -#endif - - MTLCompileOptions * options = [MTLCompileOptions new]; - options.preprocessorMacros = prep; - - //[options setFastMathEnabled:false]; - - metal_library = [device newLibraryWithSource:src options:options error:&error]; - if (error) { - GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - -#if !__has_feature(objc_arc) - [options release]; -#endif - } - } - -#if GGML_METAL_EMBED_LIBRARY - [src release]; -#endif // GGML_METAL_EMBED_LIBRARY - - return metal_library; -} - -static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) { - GGML_LOG_INFO("%s: allocating\n", __func__); - -#if TARGET_OS_OSX && !GGML_METAL_NDEBUG - // Show all the Metal device instances in the system - NSArray * devices = MTLCopyAllDevices(); - for (id device in devices) { - GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]); - } - [devices release]; // since it was created by a *Copy* C method -#endif - - // init context - struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context)); - struct ggml_backend_metal_device_context * ctx_dev = dev->context; - - id device = ctx_dev->mtl_device; - - GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); - - ctx->device = device; - ctx->queue = [device newCommandQueue]; - if (ctx->queue == nil) { - GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); - return NULL; - } - - ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); - - // load library - { - [ctx_dev->mtl_lock lock]; - - if (ctx_dev->mtl_library == nil) { - ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat); - } - - [ctx_dev->mtl_lock unlock]; - } - - id metal_library = ctx_dev->mtl_library; - if (metal_library == nil) { - GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__); - return NULL; - } - - // print MTL GPU family: - GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[device name] UTF8String]); - - // determine max supported GPU family - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf - // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf - { - for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { - if ([device supportsFamily:i]) { - GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); - break; - } - } - - for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { - if ([device supportsFamily:i]) { - GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); - break; - } - } - - for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) { - if ([device supportsFamily:i]) { - GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i); - break; - } - } - } - - GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false"); - GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false"); - GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false"); - GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false"); - GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false"); - GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); - - ctx->capture_next_compute = false; - ctx->capture_started = false; - ctx->capture_scope = nil; - - ctx->gf = nil; - ctx->encode_async = nil; - for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { - ctx->cmd_bufs[i].obj = nil; - - ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init(); - ctx->cmd_bufs[i].mem_pool->device = device; - } - -#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) - if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6); - } -#endif - - // load kernels - { - NSError * error = nil; - - for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { - ctx->kernels[i].pipeline = nil; - } - -#define GGML_METAL_ADD_KERNEL(e, name, supported) \ - if (supported) { \ - struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ - id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ - kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \ - GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ - (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ - (int) kernel->pipeline.threadExecutionWidth); \ - [metal_function release]; \ - if (error) { \ - GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ - return NULL; \ - } \ - } else { \ - GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ - } - - const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm; - const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; - const bool use_bfloat = ctx_dev->use_bfloat; - - // simd_sum and simd_max requires MTLGPUFamilyApple7 - - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); - } - - return ctx; -} - -static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { - GGML_LOG_INFO("%s: deallocating\n", __func__); - - for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { - [ctx->kernels[i].pipeline release]; - } - - Block_release(ctx->encode_async); - - [ctx->queue release]; - - for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { - // ctx->cmd_bufs[i].obj is auto released - - ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool); - } - - dispatch_release(ctx->d_queue); - - free(ctx); -} - -// temporarily defined here for compatibility between ggml-backend and the old API - -struct ggml_backend_metal_buffer { - void * data; - size_t size; - - id metal; -}; - -struct ggml_backend_metal_buffer_context { - void * all_data; - size_t all_size; - bool owned; - - // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap - int n_buffers; - struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; - - // optional MTLResidencySet - id rset; -}; - -// rset init -static bool ggml_backend_metal_buffer_rset_init( - struct ggml_backend_metal_buffer_context * ctx, - struct ggml_backend_metal_device_context * ctx_dev, - id device) { - ctx->rset = nil; - - if (!ctx_dev->has_residency_sets) { - return true; - } - -#if defined(GGML_METAL_HAS_RESIDENCY_SETS) - if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { - MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; - desc.label = @"ggml_backend_metal"; - desc.initialCapacity = ctx->n_buffers; - - NSError * error; - ctx->rset = [device newResidencySetWithDescriptor:desc error:&error]; - if (error) { - GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - [desc release]; - return false; - } - - [desc release]; - - for (int i = 0; i < ctx->n_buffers; i++) { - [ctx->rset addAllocation:ctx->buffers[i].metal]; - } - - [ctx->rset commit]; - [ctx->rset requestResidency]; - - return true; - } -#else - GGML_UNUSED(ctx_dev); - GGML_UNUSED(device); -#endif - - return true; -} - -// rset free -static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) { -#if defined(GGML_METAL_HAS_RESIDENCY_SETS) - if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { - if (ctx->rset) { - [ctx->rset endResidency]; - [ctx->rset removeAllAllocations]; - [ctx->rset release]; - } - } -#else - GGML_UNUSED(ctx); -#endif -} - -// finds the Metal buffer that contains the tensor data on the GPU device -// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the -// Metal buffer based on the host memory pointer -// -static id ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) { - //GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); - - const int64_t tsize = ggml_nbytes(t); - - ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; - - struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context; - - // find the view that contains the tensor fully - for (int i = 0; i < buf_ctx->n_buffers; ++i) { - const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data; - - //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size); - if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) { - *offs = (size_t) ioffs; - - //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); - - return buf_ctx->buffers[i].metal; - } - } - - GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); - - return nil; -} - -static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) { - const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm; - const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; - const bool use_bfloat = ctx_dev->use_bfloat; - - if (!use_bfloat) { - if (op->type == GGML_TYPE_BF16) { - return false; - } - - for (size_t i = 0, n = 3; i < n; ++i) { - if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { - return false; - } - } - } - - switch (op->op) { - case GGML_OP_UNARY: - switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_TANH: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_ERF: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_ELU: - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_ABS: - case GGML_UNARY_OP_SGN: - case GGML_UNARY_OP_STEP: - case GGML_UNARY_OP_HARDSWISH: - case GGML_UNARY_OP_HARDSIGMOID: - case GGML_UNARY_OP_EXP: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; - default: - return false; - } - case GGML_OP_GLU: - switch (ggml_get_glu_op(op)) { - case GGML_GLU_OP_REGLU: - case GGML_GLU_OP_GEGLU: - case GGML_GLU_OP_SWIGLU: - case GGML_GLU_OP_SWIGLU_OAI: - case GGML_GLU_OP_GEGLU_ERF: - case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; - default: - return false; - } - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - case GGML_OP_CONCAT: - return true; - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_ADD_ID: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_ACC: - case GGML_OP_REPEAT: - case GGML_OP_SCALE: - case GGML_OP_CONV_TRANSPOSE_1D: - return true; - case GGML_OP_CLAMP: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_LOG: - return false; // TODO: implement - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_SOFT_MAX: - case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); - case GGML_OP_RMS_NORM: - case GGML_OP_L2_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); - case GGML_OP_ARGMAX: - return true; - case GGML_OP_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); - case GGML_OP_ROPE: - return true; - case GGML_OP_IM2COL: - return op->src[0]->type == GGML_TYPE_F16; - case GGML_OP_POOL_1D: - return false; - case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; - case GGML_OP_POOL_2D: - case GGML_OP_PAD: - case GGML_OP_PAD_REFLECT_1D: - case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_ARGSORT: - case GGML_OP_LEAKY_RELU: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_ARANGE: - return true; - case GGML_OP_FLASH_ATTN_EXT: - if (op->src[0]->ne[0] == 32) { - // head size == 32 (e.g. bert-bge-small) - // TODO: not sure if it is worth adding kernels for this size - return false; - } - if (op->src[0]->ne[0] == 576) { - // DeepSeek sizes - // TODO: disabled for now, until optmized - return false; - } - if (op->src[1]->type != op->src[2]->type) { - return false; - } - return has_simdgroup_mm; // TODO: over-restricted for vec-kernels - case GGML_OP_SSM_CONV: - case GGML_OP_SSM_SCAN: - case GGML_OP_RWKV_WKV6: - case GGML_OP_RWKV_WKV7: - return true; - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - return has_simdgroup_reduction && - (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32); - case GGML_OP_CPY: - case GGML_OP_DUP: - case GGML_OP_CONT: - { - switch (op->src[0]->type) { - case GGML_TYPE_F32: - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_IQ4_NL: - return true; - default: - return false; - } - case GGML_TYPE_F16: - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - return true; - default: - return false; - } - case GGML_TYPE_BF16: - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_BF16: - return true; - default: - return false; - } - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - return true; - default: - return false; - } - default: - return false; - }; - } - case GGML_OP_SET: - { - switch (op->src[0]->type) { - case GGML_TYPE_F32: - case GGML_TYPE_I32: - return true; - default: - return false; - }; - } - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_GET_ROWS: - { - return op->ne[3] == 1; - } - case GGML_OP_SET_ROWS: - { - if (op->src[0]->type != GGML_TYPE_F32) { - return false; - } - - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_IQ4_NL: - return true; - default: - return false; - }; - } - default: - return false; - } -} - -static int ggml_metal_encode_node( - ggml_backend_t backend, - int idx, - int idx_end, - id encoder, - struct ggml_metal_mem_pool * mem_pool) { - struct ggml_backend_metal_context * ctx = backend->context; - struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; - - struct ggml_cgraph * gf = ctx->gf; - - enum ggml_op ops[8]; - - struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx; - struct ggml_tensor * node = nodes[0]; - - //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); - - struct ggml_tensor * src0 = node->src[0]; - struct ggml_tensor * src1 = node->src[1]; - struct ggml_tensor * src2 = node->src[2]; - struct ggml_tensor * dst = node; - - if (ggml_is_empty(dst)) { - return 1; - } - - switch (dst->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - { - // noop -> next node - } return 1; - default: - { - } break; - } - - if (!ggml_metal_supports_op(ctx_dev, dst)) { - GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); - GGML_ABORT("unsupported op"); - } - - ggml_metal_mem_pool_clear(mem_pool); - - const int64_t ne00 = src0 ? src0->ne[0] : 0; - const int64_t ne01 = src0 ? src0->ne[1] : 0; - const int64_t ne02 = src0 ? src0->ne[2] : 0; - const int64_t ne03 = src0 ? src0->ne[3] : 0; - - const uint64_t nb00 = src0 ? src0->nb[0] : 0; - const uint64_t nb01 = src0 ? src0->nb[1] : 0; - const uint64_t nb02 = src0 ? src0->nb[2] : 0; - const uint64_t nb03 = src0 ? src0->nb[3] : 0; - - const int64_t ne10 = src1 ? src1->ne[0] : 0; - const int64_t ne11 = src1 ? src1->ne[1] : 0; - const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; - - const uint64_t nb10 = src1 ? src1->nb[0] : 0; - const uint64_t nb11 = src1 ? src1->nb[1] : 0; - const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; - - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); - const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23); - - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; - - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; - const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; - const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_src2 = 0; - size_t offs_dst = 0; - - id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; - id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; - id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; - id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; - - int n_fuse = 1; - -#if 0 - GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - if (src0) { - GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, - ggml_is_contiguous(src0), src0->name); - } - if (src1) { - GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, - ggml_is_contiguous(src1), src1->name); - } - if (dst) { - GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, - dst->name); - } -#endif - - id device = ctx_dev->mtl_device; - - switch (dst->op) { - case GGML_OP_CONCAT: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; - - const int32_t dim = ((const int32_t *) dst->op_params)[0]; - - ggml_metal_kargs_concat args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.dim =*/ dim, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous_rows(src0)); - GGML_ASSERT(ggml_is_contiguous_rows(src1)); - - const size_t offs = 0; - - bool bcast_row = false; - - id pipeline = nil; - - ggml_metal_kargs_bin args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.offs =*/ offs, - /*.o1 =*/ { offs_src1 }, - }; - - // c[0] = add(a, b[0]) - // c[1] = add(c[0], b[1]) - // c[2] = add(c[1], b[2]) - // ... - if (ctx_dev->use_fusion) { - ops[0] = GGML_OP_ADD; - ops[1] = GGML_OP_ADD; - ops[2] = GGML_OP_ADD; - ops[3] = GGML_OP_ADD; - ops[4] = GGML_OP_ADD; - ops[5] = GGML_OP_ADD; - ops[6] = GGML_OP_ADD; - ops[7] = GGML_OP_ADD; - - size_t offs_fuse; - id id_fuse; - - // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes - // across splits. idx_end indicates the last node in the current split - for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) { - if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { - break; - } - - if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) { - break; - } - - // b[0] === b[1] === ... - if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) { - break; - } - - // only fuse nodes if src1 is in the same Metal buffer - id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse); - if (id_fuse != id_src1) { - break; - } - - ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++; - - args.o1[n_fuse + 1] = offs_fuse; - } - - ++n_fuse; - - if (ctx_dev->debug_fusion > 1 && n_fuse > 1) { - GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse); - } - } - - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - switch (dst->op) { - case GGML_OP_ADD: - { - switch (n_fuse) { - case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break; - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break; - case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break; - case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break; - case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - bcast_row = true; - } else { - switch (dst->op) { - case GGML_OP_ADD: - { - switch (n_fuse) { - case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break; - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break; - case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break; - case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break; - case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } - - if (n_fuse > 1) { - id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:0 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - if (bcast_row) { - const int64_t n = ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } else { - int nth = 32; - - while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } break; - case GGML_OP_ADD_ID: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - GGML_ASSERT(src2t == GGML_TYPE_I32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous_rows(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline; - - ggml_metal_kargs_add_id args = { - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb11 =*/ nb11, - /*.nb21 =*/ nb21, - - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_REPEAT: - { - id pipeline; - - switch (src0t) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; - case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - ggml_metal_kargs_repeat args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ACC: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - const size_t pnb1 = ((const int32_t *) dst->op_params)[0]; - const size_t pnb2 = ((const int32_t *) dst->op_params)[1]; - const size_t pnb3 = ((const int32_t *) dst->op_params)[2]; - const size_t offs = ((const int32_t *) dst->op_params)[3]; - - const bool inplace = (bool) ((const int32_t *) dst->op_params)[4]; - - if (!inplace) { - // run a separete kernel to cpy src->dst - // not sure how to avoid this - // TODO: make a simpler cpy_bytes kernel - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; - - ggml_metal_kargs_cpy args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; - - ggml_metal_kargs_bin args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ pnb1, - /*.nb02 =*/ pnb2, - /*.nb03 =*/ pnb3, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ pnb1, - /*.nb2 =*/ pnb2, - /*.nb3 =*/ pnb3, - /*.offs =*/ offs, - /*.o1 =*/ { offs_src1}, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:0 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SCALE: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float scale; - float bias; - memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float)); - - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - [encoder setBytes:&bias length:sizeof(bias) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_CLAMP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; - - float min; - float max; - memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(node)) { - // we are not taking into account the strides, so for now require contiguous tensors - GGML_ASSERT(ggml_is_contiguous(src0)); - - case GGML_UNARY_OP_TANH: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_RELU: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SIGMOID: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU_ERF: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU_QUICK: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SILU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_ELU: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_NEG: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_ABS: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SGN: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_STEP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_HARDSWISH: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_HARDSIGMOID: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_EXP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - default: - { - GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_GLU: - { - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - if (src1) { - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - } - - id pipeline = nil; - - switch (ggml_get_glu_op(node)) { - case GGML_GLU_OP_REGLU: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline; - break; - case GGML_GLU_OP_GEGLU: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline; - break; - case GGML_GLU_OP_SWIGLU: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; - break; - case GGML_GLU_OP_SWIGLU_OAI: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline; - break; - case GGML_GLU_OP_GEGLU_ERF: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline; - break; - case GGML_GLU_OP_GEGLU_QUICK: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline; - break; - default: - GGML_ABORT("fatal error"); - } - - const int32_t swp = ggml_get_op_params_i32(dst, 1); - const float alpha = ggml_get_op_params_f32(dst, 2); - const float limit = ggml_get_op_params_f32(dst, 3); - - const int32_t i00 = swp ? ne0 : 0; - const int32_t i10 = swp ? 0 : ne0; - - ggml_metal_kargs_glu args = { - /*.ne00 =*/ ne00, - /*.nb01 =*/ nb01, - /*.ne10 =*/ src1 ? ne10 : ne00, - /*.nb11 =*/ src1 ? nb11 : nb01, - /*.ne0 =*/ ne0, - /*.nb1 =*/ nb1, - /*.i00 =*/ src1 ? 0 : i00, - /*.i10 =*/ src1 ? 0 : i10, - /*.alpha=*/ alpha, - /*.limit=*/ limit - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; - - const int64_t nrows = ggml_nrows(src0); - - const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SQR: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SQRT: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SIN: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_COS: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - { - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - - id pipeline = nil; - - switch (dst->op) { - case GGML_OP_SUM_ROWS: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; - break; - case GGML_OP_MEAN: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline; - break; - default: - GGML_ABORT("fatal error"); - } - - int nth = 32; // SIMD width - - while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); - nth = MIN(nth, ne00); - - ggml_metal_kargs_sum_rows args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SOFT_MAX: - { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - int nth = 32; // SIMD width - - id pipeline = nil; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; - } - } - - float scale; - float max_bias; - - memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); - - const uint32_t n_head = src0->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - -// use this branch to test the ggml_metal_mem_pool functionality -#if 0 - // cpy to tmp buffer in MTLHeap - - id h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0)); - if (!h_src0) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0)); - return 0; - } - - offs_src0 = 0; - - ggml_metal_kargs_cpy args_cpy = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne00, - /*.ne1 =*/ ne01, - /*.ne2 =*/ ne02, - /*.ne3 =*/ ne03, - /*.nb0 =*/ nb00, - /*.nb1 =*/ nb01, - /*.nb2 =*/ nb02, - /*.nb3 =*/ nb03, - }; - - if (src0->type == GGML_TYPE_F16) { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline]; - } else { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline]; - } - [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:h_src0 offset:0 atIndex:2]; - - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type)); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)]; - -#else - id h_src0 = id_src0; -#endif - // softmax - - ggml_metal_kargs_soft_max args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.scale =*/ scale, - /*.max_bias =*/ max_bias, - /*.m0 =*/ m0, - /*.m1 =*/ m1, - /*.n_head_log2 =*/ n_head_log2, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1]; - } - if (id_src2) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:h_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&args length:sizeof(args) atIndex:4]; - - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((const int32_t *)(dst->op_params))[0]; - - id pipeline = nil; - - if (ne00%8 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; - } - - ggml_metal_kargs_diag_mask_inf args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.n_past =*/ n_past, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - if (ne00%8 == 0) { - [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - else { - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - } break; - case GGML_OP_SSM_CONV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; - - ggml_metal_kargs_ssm_conv args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SSM_SCAN: - { - struct ggml_tensor * src3 = node->src[3]; - struct ggml_tensor * src4 = node->src[4]; - struct ggml_tensor * src5 = node->src[5]; - struct ggml_tensor * src6 = node->src[6]; - - GGML_ASSERT(src3); - GGML_ASSERT(src4); - GGML_ASSERT(src5); - GGML_ASSERT(src6); - - size_t offs_src3 = 0; - size_t offs_src4 = 0; - size_t offs_src5 = 0; - size_t offs_src6 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; - id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; - id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; - - const int64_t ne30 = src3->ne[0]; - const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - - const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30); - const uint64_t nb31 = src3->nb[1]; - - const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; - const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); - const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); - - const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40); - const uint64_t nb41 = src4->nb[1]; - const uint64_t nb42 = src4->nb[2]; - const uint64_t nb43 = src4->nb[3]; - - const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); - const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); - const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); - const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); - - const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50); - const uint64_t nb51 = src5->nb[1]; - const uint64_t nb52 = src5->nb[2]; - const uint64_t nb53 = src5->nb[3]; - - const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); - - const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); - - const int64_t d_state = ne00; - const int64_t d_inner = ne01; - const int64_t n_head = ne02; - const int64_t n_group = ne41; - const int64_t n_seq_tokens = ne12; - const int64_t n_seqs = ne13; - - id pipeline = nil; - - if (ne30 == 1) { - // Mamba-2 - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; - } - - ggml_metal_kargs_ssm_scan args = { - /*.d_state =*/ d_state, - /*.d_inner =*/ d_inner, - /*.n_head =*/ n_head, - /*.n_group =*/ n_group, - /*.n_seq_tokens =*/ n_seq_tokens, - /*.n_seqs =*/ n_seqs, - /*.s_off =*/ ggml_nelements(src1) * sizeof(float), - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb31 =*/ nb31, - /*.nb41 =*/ nb41, - /*.nb42 =*/ nb42, - /*.nb43 =*/ nb43, - /*.nb51 =*/ nb51, - /*.nb52 =*/ nb52, - /*.nb53 =*/ nb53, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; - [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; - [encoder setBytes:&args length:sizeof(args) atIndex:8]; - - // One shared memory bucket for each simd group in the threadgroup - // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes - // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength - if (d_state >= 32) { - GGML_ASSERT((int64_t)(d_state / 32) <= 32); - const int64_t shmem_size = 32; - GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); - [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; - } - - if (ne30 == 1) { - // Mamba-2 - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; - } else { - GGML_ASSERT(d_inner == 1); - [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; - } - } break; - case GGML_OP_RWKV_WKV6: - { - const int64_t B = dst->src[5]->ne[1]; - const int64_t T = dst->src[0]->ne[2]; - const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[1]; - - GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); - GGML_ASSERT(C % H == 0); - GGML_ASSERT(C / H == 64); - - size_t offs_src3 = 0; - size_t offs_src4 = 0; - size_t offs_src5 = 0; - - id id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil; - id id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil; - id id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; - [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - [encoder setBytes:&B length:sizeof(B) atIndex:7]; - [encoder setBytes:&T length:sizeof(T) atIndex:8]; - [encoder setBytes:&C length:sizeof(C) atIndex:9]; - [encoder setBytes:&H length:sizeof(H) atIndex:10]; - - [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)]; - } break; - case GGML_OP_RWKV_WKV7: - { - const int64_t B = dst->src[6]->ne[1]; - const int64_t T = dst->src[0]->ne[2]; - const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[1]; - - GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); - GGML_ASSERT(C % H == 0); - GGML_ASSERT(C / H == 64); - - size_t offs_src3 = 0; - size_t offs_src4 = 0; - size_t offs_src5 = 0; - size_t offs_src6 = 0; - - id id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil; - id id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil; - id id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil; - id id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; - [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; - - [encoder setBytes:&B length:sizeof(B) atIndex:8]; - [encoder setBytes:&T length:sizeof(T) atIndex:9]; - [encoder setBytes:&C length:sizeof(C) atIndex:10]; - [encoder setBytes:&H length:sizeof(H) atIndex:11]; - - [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)]; - } break; - case GGML_OP_MUL_MAT: - { - GGML_ASSERT(ne00 == ne10); - - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); - - const uint32_t r2 = ne12/ne02; - const uint32_t r3 = ne13/ne03; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - const int ne11_mm_min = 4; - - // first try to use small-batch mat-mv kernels - // these should be efficient for BS [2, ~8] - if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) && - ( - ( - ( - src0t == GGML_TYPE_F16 || // TODO: helper function - src0t == GGML_TYPE_Q4_0 || - src0t == GGML_TYPE_Q4_1 || - src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || - src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_MXFP4 || - src0t == GGML_TYPE_IQ4_NL || - false) && (ne11 >= 2 && ne11 <= 8) - ) || - ( - ( - src0t == GGML_TYPE_Q4_K || - src0t == GGML_TYPE_Q5_K || - src0t == GGML_TYPE_Q6_K || - false) && (ne11 >= 4 && ne11 <= 8) - ) - ) - ) { - // TODO: determine the optimal parameters based on grid utilization - // I still don't know why we should not always use the maximum available threads: - // - // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32 - // - // my current hypothesis is that the work grid is not evenly divisible for different nsg - // values and there can be some tail effects when nsg is high. need to confirm this - // - const int nsg = 2; // num simdgroups per threadgroup - const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup - const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time) - const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup - int r1ptg = 4; // num src1 rows per threadgroup - - // note: not sure how optimal are those across all different hardware. there might be someting cleverer - switch (ne11) { - case 2: - r1ptg = 2; break; - case 3: - case 6: - r1ptg = 3; break; - case 4: - case 7: - case 8: - r1ptg = 4; break; - case 5: - r1ptg = 5; break; - }; - - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F16: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q4_0: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q4_1: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q5_0: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q5_1: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q8_0: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_MXFP4: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q4_K: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q5_K: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q6_K: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_IQ4_NL: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - default: GGML_ABORT("not implemented"); - } - - ggml_metal_kargs_mul_mv_ext args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.r2 =*/ r2, - /*.r3 =*/ r3, - /*.nsg =*/ nsg, - /*.nxpsg =*/ nxpsg, - /*.r1ptg =*/ r1ptg, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg); - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } else - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([device supportsFamily:MTLGPUFamilyApple7] && - !ggml_is_transposed(src0) && - !ggml_is_transposed(src1) && - src1t == GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { - //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; - default: GGML_ABORT("MUL MAT-MAT not implemented"); - } - - ggml_metal_kargs_mul_mm args = { - /*.ne00 =*/ ne00, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.r2 =*/ r2, - /*.r3 =*/ r3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - id pipeline = nil; - - int nsg = 0; // number of simdgroups - int nr0 = 0; // number of src0 rows per simdgroup - int nr1 = 1; // number of src1 rows per threadgroup - - size_t smem = 0; // shared memory - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nsg = 1; - nr0 = 1; - nr1 = 4; - if (ne00 == 4) { - nr0 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; - } - } break; - case GGML_TYPE_F16: - { - nsg = 1; - nr0 = 1; - if (src1t == GGML_TYPE_F32) { - if (ne00 == 4) { - nr0 = 32; - nr1 = 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline; - } else if (ne11 * ne12 < 4) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; - nr1 = ne11; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; - nr1 = 4; - } - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; - nr1 = 4; - } - } break; - case GGML_TYPE_BF16: - { - nsg = 1; - nr0 = 1; - if (src1t == GGML_TYPE_F32) { - if (ne00 == 4) { - nr0 = 32; - nr1 = 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline; - } else if (ne11 * ne12 < 4) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; - nr1 = ne11; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; - nr1 = 4; - } - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; - nr1 = 4; - } - } break; - case GGML_TYPE_Q4_0: - { - nsg = N_SG_Q4_0; - nr0 = N_R0_Q4_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; - } break; - case GGML_TYPE_Q4_1: - { - nsg = N_SG_Q4_1; - nr0 = N_R0_Q4_1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; - } break; - case GGML_TYPE_Q5_0: - { - nsg = N_SG_Q5_0; - nr0 = N_R0_Q5_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; - } break; - case GGML_TYPE_Q5_1: - { - nsg = N_SG_Q5_1; - nr0 = N_R0_Q5_1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; - } break; - case GGML_TYPE_Q8_0: - { - nsg = N_SG_Q8_0; - nr0 = N_R0_Q8_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; - } break; - case GGML_TYPE_MXFP4: - { - nsg = N_SG_MXFP4; - nr0 = N_R0_MXFP4; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline; - } break; - case GGML_TYPE_Q2_K: - { - nsg = N_SG_Q2_K; - nr0 = N_R0_Q2_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; - } break; - case GGML_TYPE_Q3_K: - { - nsg = N_SG_Q3_K; - nr0 = N_R0_Q3_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; - } break; - case GGML_TYPE_Q4_K: - { - nsg = N_SG_Q4_K; - nr0 = N_R0_Q4_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; - } break; - case GGML_TYPE_Q5_K: - { - nsg = N_SG_Q5_K; - nr0 = N_R0_Q5_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; - } break; - case GGML_TYPE_Q6_K: - { - nsg = N_SG_Q6_K; - nr0 = N_R0_Q6_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XXS: - { - nsg = N_SG_IQ2_XXS; - nr0 = N_R0_IQ2_XXS; - smem = 256*8+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XS: - { - nsg = N_SG_IQ2_XS; - nr0 = N_R0_IQ2_XS; - smem = 512*8+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_XXS: - { - nsg = N_SG_IQ3_XXS; - nr0 = N_R0_IQ3_XXS; - smem = 256*4+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_S: - { - nsg = N_SG_IQ3_S; - nr0 = N_R0_IQ3_S; - smem = 512*4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; - } break; - case GGML_TYPE_IQ2_S: - { - nsg = N_SG_IQ2_S; - nr0 = N_R0_IQ2_S; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_S: - { - nsg = N_SG_IQ1_S; - nr0 = N_R0_IQ1_S; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_M: - { - nsg = N_SG_IQ1_M; - nr0 = N_R0_IQ1_M; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; - } break; - case GGML_TYPE_IQ4_NL: - { - nsg = N_SG_IQ4_NL; - nr0 = N_R0_IQ4_NL; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; - } break; - case GGML_TYPE_IQ4_XS: - { - nsg = N_SG_IQ4_XS; - nr0 = N_R0_IQ4_XS; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; - } break; - default: - { - GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); - GGML_ABORT("not implemented"); - } - }; - - ggml_metal_kargs_mul_mv args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.r2 =*/ r2, - /*.r3 =*/ r3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - if (smem > 0) { - [encoder setThreadgroupMemoryLength:smem atIndex:0]; - } - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_MUL_MAT_ID: - { - // src2 = ids - GGML_ASSERT(src2t == GGML_TYPE_I32); - - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - - GGML_ASSERT(src1t == GGML_TYPE_F32); - - GGML_ASSERT(ne03 == 1); - GGML_ASSERT(ne13 == 1); - - const uint32_t r2 = 1; - const uint32_t r3 = 1; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - // ne20 = n_used_experts - // ne21 = n_rows (batch size) - const int ne21_mm_id_min = 32; - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([device supportsFamily:MTLGPUFamilyApple7] && - ne00 % 32 == 0 && ne00 >= 64 && - (ne21 >= ne21_mm_id_min)) { - GGML_ASSERT(ne00 % 4 == 0); - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - const int64_t neh10 = ne10; // n_embd - const int64_t neh11 = ne21; // n_tokens - const int64_t neh12 = ne02; // n_expert - - const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16); - const uint64_t nbh11 = nbh10*neh10; - const uint64_t nbh12 = nbh11*neh11; - const uint64_t nbh13 = nbh12*neh12; - - const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12; - id h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1); - if (!h_src1) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1); - return 0; - } - - const int64_t neh0 = ne0; - const int64_t neh1 = ne21; - const int64_t neh2 = ne02; - - const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32); - const uint64_t nbh1 = nbh0*neh0; - const uint64_t nbh2 = nbh1*neh1; - //const uint64_t nbh3 = nbh2*neh2; - - const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2; - id h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst); - if (!h_dst) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst); - return 0; - } - - // tokens per expert - const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02; - id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); - if (!h_tpe) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe); - return 0; - } - - // id map - // [n_expert_used, n_tokens] - const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21; - id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); - if (!h_ids) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); - return 0; - } - - { - const int nth = MIN(1024, ne10/4); - - ggml_metal_kargs_mul_mm_id_map0 args = { - ne10, - ne11, // n_expert_used (bcast) - nb11, - nb12, - neh11, // n_tokens - nbh11, - ne20, // n_expert_used - nb21, - }; - - id pipeline = nil; - - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer: h_src1 offset:0 atIndex:3]; - [encoder setBuffer: h_tpe offset:0 atIndex:4]; - [encoder setBuffer: h_ids offset:0 atIndex:5]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - - { - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break; - case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); - } - - ggml_metal_kargs_mul_mm_id args = { - /*.ne00 =*/ ne00, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.neh12 =*/ neh12, - /*.nbh10 =*/ nbh10, - /*.nbh11 =*/ nbh11, - /*.nbh12 =*/ nbh12, - /*.nbh13 =*/ nbh13, - /*.neh0 =*/ neh0, - /*.neh1 =*/ neh1, - /*.r2 =*/ r2, - /*.r3 =*/ r3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer: h_src1 offset:0 atIndex:2]; - [encoder setBuffer: h_tpe offset:0 atIndex:3]; - [encoder setBuffer: h_dst offset:0 atIndex:4]; - - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } - - { - GGML_ASSERT(ne0 % 4 == 0); - - const int nth = MIN(1024, ne0/4); - - ggml_metal_kargs_mul_mm_id_map1 args = { - ne20, // n_expert_used - neh0, - neh1, - nbh1, - nbh2, - ne0, - nb1, - nb2, - }; - - id pipeline = nil; - - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer: h_dst offset:0 atIndex:1]; - [encoder setBuffer: h_ids offset:0 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } else { - id pipeline = nil; - - int nsg = 0; // number of simdgroups - int nr0 = 0; // number of src0 rows per simdgroup - int nr1 = 1; // number of src1 rows per threadgroup - - size_t smem = 0; // shared memory - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nsg = 1; - nr0 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nsg = 1; - nr0 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; - } break; - case GGML_TYPE_BF16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nsg = 1; - nr0 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; - } break; - case GGML_TYPE_Q4_0: - { - nsg = N_SG_Q4_0; - nr0 = N_R0_Q4_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; - } break; - case GGML_TYPE_Q4_1: - { - nsg = N_SG_Q4_1; - nr0 = N_R0_Q4_1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; - } break; - case GGML_TYPE_Q5_0: - { - nsg = N_SG_Q5_0; - nr0 = N_R0_Q5_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; - } break; - case GGML_TYPE_Q5_1: - { - nsg = N_SG_Q5_1; - nr0 = N_R0_Q5_1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; - } break; - case GGML_TYPE_Q8_0: - { - nsg = N_SG_Q8_0; - nr0 = N_R0_Q8_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; - } break; - case GGML_TYPE_MXFP4: - { - nsg = N_SG_MXFP4; - nr0 = N_R0_MXFP4; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline; - } break; - case GGML_TYPE_Q2_K: - { - nsg = N_SG_Q2_K; - nr0 = N_R0_Q2_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; - } break; - case GGML_TYPE_Q3_K: - { - nsg = N_SG_Q3_K; - nr0 = N_R0_Q3_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; - } break; - case GGML_TYPE_Q4_K: - { - nsg = N_SG_Q4_K; - nr0 = N_R0_Q4_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; - } break; - case GGML_TYPE_Q5_K: - { - nsg = N_SG_Q5_K; - nr0 = N_R0_Q5_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; - } break; - case GGML_TYPE_Q6_K: - { - nsg = N_SG_Q6_K; - nr0 = N_R0_Q6_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XXS: - { - nsg = N_SG_IQ2_XXS; - nr0 = N_R0_IQ2_XXS; - smem = 256*8+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XS: - { - nsg = N_SG_IQ2_XS; - nr0 = N_R0_IQ2_XS; - smem = 512*8+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_XXS: - { - nsg = N_SG_IQ3_XXS; - nr0 = N_R0_IQ3_XXS; - smem = 256*4+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_S: - { - nsg = N_SG_IQ3_S; - nr0 = N_R0_IQ3_S; - smem = 512*4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; - } break; - case GGML_TYPE_IQ2_S: - { - nsg = N_SG_IQ2_S; - nr0 = N_R0_IQ2_S; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_S: - { - nsg = N_SG_IQ1_S; - nr0 = N_R0_IQ1_S; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_M: - { - nsg = N_SG_IQ1_M; - nr0 = N_R0_IQ1_M; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; - } break; - case GGML_TYPE_IQ4_NL: - { - nsg = N_SG_IQ4_NL; - nr0 = N_R0_IQ4_NL; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; - } break; - case GGML_TYPE_IQ4_XS: - { - nsg = N_SG_IQ4_XS; - nr0 = N_R0_IQ4_XS; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; - } break; - default: - { - GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t); - GGML_ABORT("not implemented"); - } - }; - - if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nsg*nr0); - } - - ggml_metal_kargs_mul_mv_id args = { - /*.nei0 =*/ ne20, - /*.nei1 =*/ ne21, - /*.nbi1 =*/ nb21, - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.nb1 =*/ nb1, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; - - const int64_t _ne1 = 1; - const int64_t ne123 = ne20*ne21; - - if (smem > 0) { - [encoder setThreadgroupMemoryLength:smem atIndex:0]; - } - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_GET_ROWS: - { - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; - case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; - default: GGML_ABORT("not implemented"); - } - - ggml_metal_kargs_get_rows args = { - /*.ne00 =*/ ne00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne10 =*/ ne10, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; - } break; - case GGML_OP_SET_ROWS: - { - id pipeline = nil; - - switch (dst->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break; - default: GGML_ABORT("not implemented"); - } - - const int32_t nk0 = ne0/ggml_blck_size(dst->type); - - int nth = 32; // SIMD width - - while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - int nrptg = 1; - if (nth > nk0) { - nrptg = (nth + nk0 - 1)/nk0; - nth = nk0; - - if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) { - nrptg--; - } - } - - nth = MIN(nth, nk0); - - ggml_metal_kargs_set_rows args = { - /*.nk0 =*/ nk0, - /*.ne01 =*/ ne01, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)]; - } break; - case GGML_OP_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_rows(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - ggml_metal_kargs_rms_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.eps =*/ eps, - /*.nef1 =*/ { ne01 }, - /*.nef2 =*/ { ne02 }, - /*.nef3 =*/ { ne03 }, - /*.nbf1 =*/ { nb01 }, - /*.nbf2 =*/ { nb02 }, - /*.nbf3 =*/ { nb03 }, - }; - - size_t offs_fuse[2] = { 0, 0 }; - id id_fuse[2] = { id_src0, id_src0 }; - - // d[0] = rms_norm(a) - // d[1] = mul(d[0], b) - // d[2] = add(d[1], c) - if (ctx_dev->use_fusion) { - ops[0] = GGML_OP_RMS_NORM; - ops[1] = GGML_OP_MUL; - ops[2] = GGML_OP_ADD; - - for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) { - if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { - break; - } - - if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) { - break; - } - - if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) { - break; - } - - if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) { - break; - } - - if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) { - break; - } - - ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++; - - id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]); - - args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1]; - args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2]; - args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3]; - - args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1]; - args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2]; - args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3]; - } - - ++n_fuse; - - if (ctx_dev->debug_fusion > 1 && n_fuse > 1) { - if (n_fuse == 2) { - GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__); - } - if (n_fuse == 3) { - GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__); - } - } - } - - if (n_fuse > 1) { - id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst); - } - - id pipeline; - - switch (n_fuse) { - case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break; - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break; - default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse); - } - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); - nth = MIN(nth, ne00/4); - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2]; - [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_L2_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline; - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); - nth = MIN(nth, ne00/4); - - ggml_metal_kargs_l2_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, - /*.eps =*/ eps, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_GROUP_NORM: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float eps; - memcpy(&eps, dst->op_params + 1, sizeof(float)); - - const int32_t n_groups = ((const int32_t *) dst->op_params)[0]; - - int nth = 32; // SIMD width - - //while (nth < ne00/4 && nth < 1024) { - // nth *= 2; - //} - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - - ggml_metal_kargs_group_norm args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.n_groups =*/ n_groups, - /*.eps =*/ eps, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); - nth = MIN(nth, ne00/4); - - ggml_metal_kargs_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, - /*.eps =*/ eps, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ROPE: - { - - // make sure we have one or more position id(ne10) per token(ne02) - GGML_ASSERT(ne10 % ne02 == 0); - GGML_ASSERT(ne10 >= ne02); - - const int nth = MIN(1024, ne00); - - const int n_past = ((const int32_t *) dst->op_params)[0]; - const int n_dims = ((const int32_t *) dst->op_params)[1]; - const int mode = ((const int32_t *) dst->op_params)[2]; - // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_ctx_orig = ((const int32_t *) dst->op_params)[4]; - - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; - - memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); - - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; - const bool is_vision = mode == GGML_ROPE_TYPE_VISION; - - // mrope - const int sect_0 = ((const int32_t *) dst->op_params)[11]; - const int sect_1 = ((const int32_t *) dst->op_params)[12]; - const int sect_2 = ((const int32_t *) dst->op_params)[13]; - const int sect_3 = ((const int32_t *) dst->op_params)[14]; - - id pipeline = nil; - - if (is_neox) { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else if (is_mrope && !is_vision) { - GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else if (is_vision) { - GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } - - ggml_metal_kargs_rope args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.n_past =*/ n_past, - /*.n_dims =*/ n_dims, - /*.n_ctx_orig =*/ n_ctx_orig, - /*.freq_base =*/ freq_base, - /*.freq_scale =*/ freq_scale, - /*.ext_factor =*/ ext_factor, - /*.attn_factor =*/ attn_factor, - /*.beta_fast =*/ beta_fast, - /*.beta_slow =*/ beta_slow, - /* sect_0 =*/ sect_0, - /* sect_1 =*/ sect_1, - /* sect_2 =*/ sect_2, - /* sect_3 =*/ sect_3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_IM2COL: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int32_t N = src1->ne[is_2D ? 3 : 2]; - const int32_t IC = src1->ne[is_2D ? 2 : 1]; - const int32_t IH = is_2D ? src1->ne[1] : 1; - const int32_t IW = src1->ne[0]; - - const int32_t KH = is_2D ? src0->ne[1] : 1; - const int32_t KW = src0->ne[0]; - - const int32_t OH = is_2D ? dst->ne[2] : 1; - const int32_t OW = dst->ne[1]; - - const int32_t CHW = IC * KH * KW; - - const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; - - const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup; - - switch (dst->type) { - case GGML_TYPE_F32: { - pipeline = (is_gt_mttpt ? - ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline - : - ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline); - } break; - case GGML_TYPE_F16: { - pipeline = (is_gt_mttpt ? - ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline - : - ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline); - } break; - default: GGML_ABORT("fatal error"); - }; - - ggml_metal_kargs_im2col args = { - /*.ofs0 =*/ ofs0, - /*.ofs1 =*/ ofs1, - /*.IW =*/ IW, - /*.IH =*/ IH, - /*.CHW =*/ CHW, - /*.s0 =*/ s0, - /*.s1 =*/ s1, - /*.p0 =*/ p0, - /*.p1 =*/ p1, - /*.d0 =*/ d0, - /*.d1 =*/ d1, - /*.N =*/ N, - /*.KH =*/ KH, - /*.KW =*/ KW, - /*.KHW =*/ KH * KW, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - if (is_gt_mttpt) { - const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N); - - const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); - - [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; - } else { - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; - } - } break; - case GGML_OP_CONV_TRANSPOSE_1D: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - - const int32_t IC = src1->ne[1]; - const int32_t IL = src1->ne[0]; - - const int32_t K = src0->ne[0]; - - const int32_t OL = dst->ne[0]; - const int32_t OC = dst->ne[1]; - - id pipeline; - - switch (src0->type) { - case GGML_TYPE_F32: { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline; - } break; - case GGML_TYPE_F16: { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline; - } break; - default: GGML_ABORT("fatal error"); - }; - - ggml_metal_kargs_conv_transpose_1d args = { - /*.IC =*/ IC, - /*.IL =*/ IL, - /*.K =*/ K, - /*.s0 =*/ s0, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_UPSCALE: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; - - ggml_metal_kargs_upscale args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.sf0 =*/ sf0, - /*.sf1 =*/ sf1, - /*.sf2 =*/ sf2, - /*.sf3 =*/ sf3 - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_PAD: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; - - ggml_metal_kargs_pad args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3 - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_PAD_REFLECT_1D: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const int32_t p0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[1]; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline; - - ggml_metal_kargs_pad_reflect_1d args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.p0 =*/ p0, - /*.p1 =*/ p1 - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARANGE: - { - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - float start; - float step; - - memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; - - ggml_metal_kargs_arange args = { - /*.ne0 =*/ ne0, - /*.start =*/ start, - /*.step =*/ step - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; - [encoder setBytes:&args length:sizeof(args) atIndex:1]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_TIMESTEP_EMBEDDING: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const int dim = dst->op_params[0]; - const int max_period = dst->op_params[1]; - - const int half = dim / 2; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; - - ggml_metal_kargs_timestep_embedding args = { - /*.nb1 =*/ nb1, - /*.dim =*/ dim, - /*.max_period =*/ max_period - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - const int nth = MIN(1024, half); - - [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARGSORT: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_I32); - - const int nrows = ggml_nrows(src0); - - enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - - // bitonic sort requires the number of elements to be power of 2 - int64_t ne00_padded = 1; - while (ne00_padded < ne00) { - ne00_padded *= 2; - } - - // Metal kernels require the buffer size to be multiple of 16 bytes - // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength - const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); - - id pipeline = nil; - - switch (order) { - case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; - case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - ggml_metal_kargs_argsort args = { - /*.ncols =*/ ne00, - /*.ncols_pad =*/ ne00_padded - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; - } break; - case GGML_OP_LEAKY_RELU: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - float slope; - memcpy(&slope, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; - - ggml_metal_kargs_leaky_relu args = { - /*.slope =*/ slope - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_FLASH_ATTN_EXT: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == src2->type); - - //GGML_ASSERT(ggml_are_same_shape (src1, src2)); - GGML_ASSERT(ne11 == ne21); - GGML_ASSERT(ne12 == ne22); - - struct ggml_tensor * src3 = node->src[3]; // mask - struct ggml_tensor * src4 = node->src[4]; // sinks - - size_t offs_src3 = 0; - size_t offs_src4 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; - - GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); - GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - //const int64_t ne31 = src3 ? src3->ne[1] : 0; - const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); - const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); - - const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); - const uint64_t nb31 = src3 ? src3->nb[1] : 0; - const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); - const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); - - float scale; - float max_bias; - float logit_softcap; - memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); - memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); - - if (logit_softcap != 0.0f) { - scale /= logit_softcap; - } - - const uint32_t n_head = src0->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - id pipeline = nil; - - bool use_vec_kernel = false; - - // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) - // for now avoiding mainly to keep the number of templates/kernels a bit lower - // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 - if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { - switch (src1->type) { - case GGML_TYPE_F16: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_BF16: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q4_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q4_1: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q5_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q5_1: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q8_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - use_vec_kernel = true; - - switch (ne00) { - case 64: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 96: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 128: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 192: - { - if (ne20 == 128) { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } - } break; - case 256: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 576: - { - if (ne20 == 512) { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - GGML_LOG_ERROR("unsupported size: %lld\n", ne20); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - - ggml_metal_kargs_flash_attn_ext args = { - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne11 =*/ ne11, - /*.ne_12_2 =*/ ne12, - /*.ne_12_3 =*/ ne13, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb23 =*/ nb23, - /*.ne32 =*/ ne32, - /*.ne33 =*/ ne33, - /*.nb31 =*/ nb31, - /*.nb32 =*/ nb32, - /*.nb33 =*/ nb33, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.scale =*/ scale, - /*.max_bias =*/ max_bias, - /*.m0 =*/ m0, - /*.m1 =*/ m1, - /*.n_head_log2 =*/ n_head_log2, - /*.logit_softcap =*/ logit_softcap, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; - } - if (id_src4) { - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - if (!use_vec_kernel) { - // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - const int is_q = ggml_is_quantized(src1->type) ? 1 : 0; - - // 2*(2*ncpsg + nqptg)*(nsg) - // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) - // - // 16*32*(nsg) - // the shared memory needed for the simdgroups to load the KV cache - // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG - // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) - - int64_t nsgmax = 2; - - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - if (smem > device.maxThreadgroupMemoryLength) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - - const size_t smem = FATTN_SMEM(nsg); - - //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); - GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; -#undef FATTN_SMEM - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } else { - // half4x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 1 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - // ne00 + 2*ncpsg*(nsg) - // for each query, we load it as f16 in shared memory (ne00) - // and store the soft_max values and the mask - // - // ne00*(nsg) - // each simdgroup has a full f32 head vector in shared mem to accumulate results - // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16)) - - int64_t nsgmax = 2; - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - if (smem > device.maxThreadgroupMemoryLength) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); - - int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; - - const size_t smem = FATTN_SMEM(nsg); - - //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); - GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; -#undef FATTN_SMEM - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - { - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); - - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_F16: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_BF16: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_Q4_0: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_Q4_1: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_Q5_0: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_Q5_1: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_Q8_0: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - default: GGML_ABORT("not implemented"); - } - - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - - // TODO: support - //const int32_t nk00 = ne00/ggml_blck_size(dst->type); - const int32_t nk00 = ne00; - - int nth = 32; // SIMD width - - while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); - - // when rows are small, we can batch them together in a single threadgroup - int nrptg = 1; - - // TODO: relax this constraint in the future - if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) { - if (nth > nk00) { - nrptg = (nth + nk00 - 1)/nk00; - nth = nk00; - - if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) { - nrptg--; - } - } - } - - nth = MIN(nth, nk00); - - ggml_metal_kargs_cpy args = { - /*.ne00 =*/ nk00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)]; - } break; - case GGML_OP_SET: - { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - - // src0 and dst as viewed during set - const size_t dst_nb0 = ggml_element_size(src0); - - const size_t dst_nb1 = ((int32_t *) dst->op_params)[0]; - const size_t dst_nb2 = ((int32_t *) dst->op_params)[1]; - const size_t dst_nb3 = ((int32_t *) dst->op_params)[2]; - const size_t offset = ((int32_t *) dst->op_params)[3]; - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst)); - } - - const int im0 = (ne10 == 0 ? 0 : ne10-1); - const int im1 = (ne11 == 0 ? 0 : ne11-1); - const int im2 = (ne12 == 0 ? 0 : ne12-1); - const int im3 = (ne13 == 0 ? 0 : ne13-1); - - GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst)); - - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - GGML_ASSERT(nb10 == sizeof(float)); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break; - case GGML_TYPE_I32: - GGML_ASSERT(nb10 == sizeof(int32_t)); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - ggml_metal_kargs_set args = { - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb1 =*/ dst_nb1, - /*.nb2 =*/ dst_nb2, - /*.nb3 =*/ dst_nb3, - /*.offs =*/ offset, - /*.inplace =*/ inplace, - }; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10); - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_POOL_2D: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt); - - const int32_t * opts = dst->op_params; - enum ggml_op_pool op = opts[0]; - - id pipeline = nil; - switch (src0t) { - case GGML_TYPE_F32: { - switch(op) { - case GGML_OP_POOL_AVG: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break; - case GGML_OP_POOL_MAX: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break; - default: GGML_ASSERT(false && "not implemented"); - } - } break; - default: GGML_ASSERT(false && "not implemented"); - } - - const int32_t k0 = opts[1]; - const int32_t k1 = opts[2]; - const int32_t s0 = opts[3]; - const int32_t s1 = opts[4]; - const int32_t p0 = opts[5]; - const int32_t p1 = opts[6]; - - const int64_t IH = src0->ne[1]; - const int64_t IW = src0->ne[0]; - - const int64_t N = dst->ne[3]; - const int64_t OC = dst->ne[2]; - const int64_t OH = dst->ne[1]; - const int64_t OW = dst->ne[0]; - - const int64_t parallel_elements = N * OC * OH * OW; - const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); - const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; - - ggml_metal_kargs_pool_2d args_pool_2d = { - /* .k0 = */ k0, - /* .k1 = */ k1, - /* .s0 = */ s0, - /* .s1 = */ s1, - /* .p0 = */ p0, - /* .p1 = */ p1, - /* .IH = */ IH, - /* .IW = */ IW, - /* .OH = */ OH, - /* .OW = */ OW, - /* .parallel_elements = */ parallel_elements - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; - } break; - case GGML_OP_ARGMAX: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - GGML_ASSERT(nb00 == ggml_type_size(src0->type)); - - const int64_t nrows = ggml_nrows(src0); - - int nth = 32; // SIMD width - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - { - GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } - - return n_fuse; -} - -static enum ggml_status ggml_metal_graph_compute( - ggml_backend_t backend, - struct ggml_cgraph * gf) { - struct ggml_backend_metal_context * ctx = backend->context; - struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; - - // number of nodes encoded by the main thread (empirically determined) - const int n_main = 128; - - // number of threads in addition to the main thread - const int n_cb = ctx->n_cb; - - // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them - // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread - // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes - // each thread creates it's own command buffer and enqueues the ops in parallel - // - // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2 - - @autoreleasepool { - ctx->gf = gf; - - ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); - ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; - - ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; - - const bool should_capture = ctx->capture_next_compute; - if (should_capture) { - ctx->capture_next_compute = false; - - if (!ctx->capture_started) { - // create capture scope - ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device]; - - MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; - descriptor.captureObject = ctx->capture_scope; - descriptor.destination = MTLCaptureDestinationGPUTraceDocument; - descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; - - NSError * error = nil; - if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { - GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); - } else { - [ctx->capture_scope beginScope]; - ctx->capture_started = true; - } - } - } - - // the main thread commits the first few commands immediately - // cmd_buf[n_cb] - { - id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; - ctx->cmd_bufs[n_cb].obj = cmd_buf; - - [cmd_buf enqueue]; - ctx->encode_async(n_cb); - } - - // prepare the rest of the command buffers asynchronously - // cmd_buf[0.. n_cb) - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; - ctx->cmd_bufs[cb_idx].obj = cmd_buf; - - // always enqueue the first two command buffers - // enqueue all of the command buffers if we don't need to abort - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [cmd_buf enqueue]; - } - } - - dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); - - // wait for completion and check status of each command buffer - // needed to detect if the device ran out-of-memory for example (#1881) - { - id cmd_buf = ctx->cmd_bufs[n_cb].obj; - [cmd_buf waitUntilCompleted]; - - MTLCommandBufferStatus status = [cmd_buf status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); - if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); - } - - return GGML_STATUS_FAILED; - } - } - - for (int i = 0; i < n_cb; ++i) { - id cmd_buf = ctx->cmd_bufs[i].obj; - [cmd_buf waitUntilCompleted]; - - MTLCommandBufferStatus status = [cmd_buf status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); - } - - return GGML_STATUS_FAILED; - } - - id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); - if (!next_buffer) { - continue; - } - - const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); - if (next_queued) { - continue; - } - - if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { - GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); - return GGML_STATUS_ABORTED; - } - - [next_buffer commit]; - } - - if (!should_capture && ctx->capture_started) { - [ctx->capture_scope endScope]; - [[MTLCaptureManager sharedCaptureManager] stopCapture]; - } - } - - return GGML_STATUS_SUCCESS; -} - -//////////////////////////////////////////////////////////////////////////////// - -// backend interface - -static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) { - struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - - for (int i = 0; i < ctx->n_buffers; i++) { - [ctx->buffers[i].metal release]; - } - - ggml_backend_metal_buffer_rset_free(ctx); - - if (ctx->owned) { -#if TARGET_OS_OSX - vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size); -#else - free(ctx->all_data); -#endif - } - - free(ctx); - free(buffer); -} - -static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { - struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - - return ctx->all_data; -} - -static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - memset((char *)tensor->data + offset, value, size); - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - memcpy((char *)tensor->data + offset, data, size); - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - memcpy(data, (const char *)tensor->data + offset, size); - - GGML_UNUSED(buffer); -} - -static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { - if (ggml_backend_buffer_is_host(src->buffer)) { - memcpy(dst->data, src->data, ggml_nbytes(src)); - return true; - } - return false; - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - - memset(ctx->all_data, value, ctx->all_size); -} - -static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, - /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_clear, - /* .reset = */ NULL, -}; - -// default buffer type - -static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "Metal"; - - GGML_UNUSED(buft); -} - -static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { -#ifndef GGML_METAL_NDEBUG -#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) - if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n", - __func__, - size_aligned / 1024.0 / 1024.0, - device.currentAllocatedSize / 1024.0 / 1024.0, - device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); - - if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { - GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); - } - } else { - GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", - __func__, - size_aligned / 1024.0 / 1024.0, - device.currentAllocatedSize / 1024.0 / 1024.0); - } -#endif -#endif - GGML_UNUSED(device); - GGML_UNUSED(size_aligned); -} - -static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); - - const size_t size_page = sysconf(_SC_PAGESIZE); - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } - - struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context; - - GGML_ASSERT(ctx_dev->mtl_device != nil); - - id device = ctx_dev->mtl_device; - - ctx->all_data = ggml_metal_host_malloc(size_aligned); - ctx->all_size = size_aligned; - ctx->owned = true; - ctx->n_buffers = 1; - - if (ctx->all_data != NULL) { - ctx->buffers[0].data = ctx->all_data; - ctx->buffers[0].size = size; - ctx->buffers[0].metal = nil; - - if (size_aligned > 0) { - ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data - length:size_aligned - options:MTLResourceStorageModeShared - deallocator:nil]; - } - } - - if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - free(ctx); - return NULL; - } - - if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { - GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); - free(ctx); - return NULL; - } - - //ggml_backend_metal_log_allocated_size(device, size_aligned); - - return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); -} - -static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 32; - - GGML_UNUSED(buft); -} - -static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { - const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size; - - return max_size; -} - -static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return true; - - GGML_UNUSED(buft); -} - -ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_metal_buffer_type_is_host, - }, - /* .device = */ &g_ggml_backend_metal_device, - /* .context = */ NULL, - }; - - return &ggml_backend_buffer_type_metal; -} - -static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Mapped"; - - GGML_UNUSED(buft); -} - -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) { - static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_from_ptr_type_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_metal_buffer_type_is_host, - }, - /* .device = */ &g_ggml_backend_metal_device, - /* .context = */ NULL, - }; - - return &ggml_backend_buffer_from_ptr_type_metal; -} - -// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr -ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { - struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); - - ctx->all_data = data; - ctx->all_size = size; - ctx->owned = false; - ctx->n_buffers = 0; - - const size_t size_page = sysconf(_SC_PAGESIZE); - - // page-align the data ptr - { - const uintptr_t offs = (uintptr_t) data % size_page; - data = (void *) ((char *) data - offs); - size += offs; - } - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } - - struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main; - - GGML_ASSERT(ctx_dev->mtl_device != nil); - - id device = ctx_dev->mtl_device; - - // the buffer fits into the max buffer size allowed by the device - if (size_aligned <= device.maxBufferLength) { - ctx->buffers[ctx->n_buffers].data = data; - ctx->buffers[ctx->n_buffers].size = size; - ctx->buffers[ctx->n_buffers].metal = nil; - - if (size_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - return false; - } - } - - ggml_backend_metal_log_allocated_size(device, size_aligned); - - ++ctx->n_buffers; - } else { - // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into - // one of the views - const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case - const size_t size_step = device.maxBufferLength - size_ovlp; - const size_t size_view = device.maxBufferLength; - - for (size_t i = 0; i < size; i += size_step) { - const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); - - ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); - ctx->buffers[ctx->n_buffers].size = size_step_aligned; - ctx->buffers[ctx->n_buffers].metal = nil; - - if (size_step_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); - return false; - } - } - - ggml_backend_metal_log_allocated_size(device, size_step_aligned); - - if (i + size_step < size) { - GGML_LOG_INFO("\n"); - } - - ++ctx->n_buffers; - } - } - - if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { - GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); - free(ctx); - return NULL; - } - - return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); -} - -// backend - -static const char * ggml_backend_metal_name(ggml_backend_t backend) { - return "Metal"; - - GGML_UNUSED(backend); -} - -static void ggml_backend_metal_free(ggml_backend_t backend) { - struct ggml_backend_metal_context * ctx = backend->context; - - ggml_metal_free(ctx); - - free(backend); -} - -static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - return ggml_metal_graph_compute(backend, cgraph); -} - -static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - - if (ctx->n_cb != n_cb) { - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); - - if (ctx->n_cb > 2) { - GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); - } - } - - if (ctx->encode_async) { - Block_release(ctx->encode_async); - } - - ctx->encode_async = Block_copy(^(size_t iter) { - const int cb_idx = iter; - const int n_cb_l = ctx->n_cb; - - const int n_nodes_0 = ctx->n_nodes_0; - const int n_nodes_1 = ctx->n_nodes_1; - - const int n_nodes_per_cb = ctx->n_nodes_per_cb; - - id cmd_buf = ctx->cmd_bufs[cb_idx].obj; - - id encoder = [cmd_buf computeCommandEncoder]; - - int node_start = 0; - int node_end = n_nodes_0; - - if (cb_idx < n_cb_l) { - node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); - node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); - } - - const bool should_capture = ctx->capture_next_compute; - - struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; - ggml_metal_mem_pool_reset(mem_pool); - - for (int idx = node_start; idx < node_end;) { - if (should_capture) { - [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; - } - - const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool); - if (idx + res > node_end) { - GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s", - "https://github.com/ggml-org/llama.cpp/pull/14849"); - } - - if (should_capture) { - [encoder popDebugGroup]; - } - - if (res == 0) { - break; - } - - idx += res; - } - - [encoder endEncoding]; - - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [cmd_buf commit]; - } - }); -} - -static struct ggml_backend_i ggml_backend_metal_i = { - /* .get_name = */ ggml_backend_metal_name, - /* .free = */ ggml_backend_metal_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_metal_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, -}; - -static ggml_guid_t ggml_backend_metal_guid(void) { - static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 }; - return &guid; -} - -// TODO: remove in the future -ggml_backend_t ggml_backend_metal_init(void) { - ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0); - - struct ggml_backend_metal_context * ctx = ggml_metal_init(dev); - if (ctx == NULL) { - GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); - return NULL; - } - - ggml_backend_t backend = malloc(sizeof(struct ggml_backend)); - - *backend = (struct ggml_backend) { - /* .guid = */ ggml_backend_metal_guid(), - /* .interface = */ ggml_backend_metal_i, - /* .device = */ dev, - /* .context = */ ctx, - }; - - ggml_backend_metal_set_n_cb(backend, 1); - - return backend; -} - -bool ggml_backend_is_metal(ggml_backend_t backend) { - return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid()); -} - -void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - - ctx->abort_callback = abort_callback; - ctx->abort_callback_data = user_data; -} - -bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; - - GGML_ASSERT(ctx_dev->mtl_device != nil); - - return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; -} - -void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - ctx->capture_next_compute = true; -} - -// backend device - -static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { - return "Metal"; - - GGML_UNUSED(dev); -} - -static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { - struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; - - return ctx_dev->name; -} - -static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - if (@available(macOS 10.12, iOS 16.0, *)) { - struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; - id device = ctx_dev->mtl_device; - - *total = device.recommendedMaxWorkingSetSize; - *free = *total - device.currentAllocatedSize; - } else { - *free = 1; - *total = 1; - } -} - -static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) { - return GGML_BACKEND_DEVICE_TYPE_GPU; - - GGML_UNUSED(dev); -} - -static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { - props->name = ggml_backend_metal_device_get_name(dev); - props->description = ggml_backend_metal_device_get_description(dev); - props->id = "0"; - props->type = ggml_backend_metal_device_get_type(dev); - ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); - props->caps = (struct ggml_backend_dev_caps) { - /* .async = */ false, - /* .host_buffer = */ false, - /* .buffer_from_host_ptr = */ true, - /* .events = */ false, - }; -} - -static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) { - struct ggml_backend_metal_context * ctx = ggml_metal_init(dev); - if (ctx == NULL) { - GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); - return NULL; - } - - ggml_backend_t backend = malloc(sizeof(struct ggml_backend)); - - *backend = (struct ggml_backend) { - /* .guid = */ ggml_backend_metal_guid(), - /* .interface = */ ggml_backend_metal_i, - /* .device = */ dev, - /* .context = */ ctx, - }; - - ggml_backend_metal_set_n_cb(backend, 1); - - return backend; - - GGML_UNUSED(params); -} - -static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) { - return ggml_backend_metal_buffer_type(); - - GGML_UNUSED(dev); -} - -static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { - struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); - - ctx->all_data = ptr; - ctx->all_size = size; - ctx->owned = false; - ctx->n_buffers = 0; - - const size_t size_page = sysconf(_SC_PAGESIZE); - - // page-align the data ptr - { - const uintptr_t offs = (uintptr_t) ptr % size_page; - ptr = (void *) ((char *) ptr - offs); - size += offs; - } - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } - - struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; - - GGML_ASSERT(ctx_dev->mtl_device != nil); - - id device = ctx_dev->mtl_device; - - // the buffer fits into the max buffer size allowed by the device - if (size_aligned <= device.maxBufferLength) { - ctx->buffers[ctx->n_buffers].data = ptr; - ctx->buffers[ctx->n_buffers].size = size; - ctx->buffers[ctx->n_buffers].metal = nil; - - if (size_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - return false; - } - } - - ggml_backend_metal_log_allocated_size(device, size_aligned); - - ++ctx->n_buffers; - } else { - // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into - // one of the views - const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case - const size_t size_step = device.maxBufferLength - size_ovlp; - const size_t size_view = device.maxBufferLength; - - for (size_t i = 0; i < size; i += size_step) { - const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); - - ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) ptr + i); - ctx->buffers[ctx->n_buffers].size = size_step_aligned; - ctx->buffers[ctx->n_buffers].metal = nil; - - if (size_step_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); - return false; - } - } - - ggml_backend_metal_log_allocated_size(device, size_step_aligned); - - if (i + size_step < size) { - GGML_LOG_INFO("\n"); - } - - ++ctx->n_buffers; - } - } - - if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { - GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); - free(ctx); - return NULL; - } - - return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); -} - -static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - struct ggml_backend_metal_device_context * ctx_dev = dev->context; - - return ggml_metal_supports_op(ctx_dev, op); -} - -static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - return - buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || - buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; - - GGML_UNUSED(dev); -} - -static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - return false; - - GGML_UNUSED(dev); - GGML_UNUSED(op); -} - -static struct ggml_backend_device_i ggml_backend_metal_device_i = { - /* .get_name = */ ggml_backend_metal_device_get_name, - /* .get_description = */ ggml_backend_metal_device_get_description, - /* .get_memory = */ ggml_backend_metal_device_get_memory, - /* .get_type = */ ggml_backend_metal_device_get_type, - /* .get_props = */ ggml_backend_metal_device_get_props, - /* .init_backend = */ ggml_backend_metal_device_init, - /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, - /* .get_host_buffer_type = */ NULL, - /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr, - /* .supports_op = */ ggml_backend_metal_device_supports_op, - /* .supports_buft = */ ggml_backend_metal_device_supports_buft, - /* .offload_op = */ ggml_backend_metal_device_offload_op, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, -}; - -// backend registry - -static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { - return "Metal"; - - GGML_UNUSED(reg); -} - -static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { - return 1; - - GGML_UNUSED(reg); -} - -static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { - GGML_ASSERT(index == 0); - - return &g_ggml_backend_metal_device; - - GGML_UNUSED(reg); - GGML_UNUSED(index); -} - -static struct ggml_backend_feature g_ggml_backend_metal_features[] = { -#if defined(GGML_METAL_EMBED_LIBRARY) - { "EMBED_LIBRARY", "1" }, -#endif -#if defined(GGML_METAL_USE_BF16) - { "BF16", "1" }, -#endif - { nil, nil }, -}; - -static struct ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) { - return g_ggml_backend_metal_features; - - GGML_UNUSED(reg); -} - -static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) { - if (strcmp(name, "ggml_backend_get_features") == 0) { - return (void *)ggml_backend_metal_get_features; - } - - return NULL; - - GGML_UNUSED(reg); -} -static struct ggml_backend_reg_i ggml_backend_metal_reg_i = { - /* .get_name = */ ggml_backend_metal_reg_get_name, - /* .device_count = */ ggml_backend_metal_reg_device_count, - /* .device_get = */ ggml_backend_metal_reg_device_get, - /* .get_proc_address = */ ggml_backend_metal_get_proc_address, -}; - -// called upon program exit -static void ggml_metal_cleanup(void) { - ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main); -} - -// TODO: make thread-safe -ggml_backend_reg_t ggml_backend_metal_reg(void) { - ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main); - - // register cleanup callback - // TODO: not ideal, but not sure if there is a better way to do this in Objective-C - atexit(ggml_metal_cleanup); - - { - g_ggml_backend_metal_reg = (struct ggml_backend_reg) { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_metal_reg_i, - /* .context = */ NULL, - }; - - g_ggml_backend_metal_device = (struct ggml_backend_device) { - /* .iface = */ ggml_backend_metal_device_i, - /* .reg = */ &g_ggml_backend_metal_reg, - /* .context = */ &g_ggml_ctx_dev_main, - }; - } - - return &g_ggml_backend_metal_reg; -} - -GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg) diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index b35a3bbd..375a0c7f 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -15,6 +15,10 @@ using namespace metal; #define MIN(x, y) ((x) < (y) ? (x) : (y)) #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } +#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1)) + +#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x) + #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf @@ -23,12 +27,13 @@ using namespace metal; // .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/ggml-metal.metal // .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal // -#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16) -#undef GGML_METAL_USE_BF16 +#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16) +#undef GGML_METAL_HAS_BF16 #endif -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) typedef matrix bfloat4x4; +typedef matrix bfloat2x4; #endif constexpr constant static float kvalues_iq4nl_f[16] = { @@ -62,12 +67,21 @@ static inline float e8m0_to_fp32(uint8_t x) { return as_type(bits); } +static inline float dot(float x, float y) { + return x*y; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); } +template +void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) { + reg = (type4)(*src); +} + template void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); @@ -78,7 +92,7 @@ void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { reg = (type4)(*(src)); } -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) template void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); @@ -919,7 +933,7 @@ kernel void kernel_add_fuse_impl( typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; -template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; +template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; @@ -928,7 +942,7 @@ template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_ template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; -kernel void kernel_sub( +kernel void kernel_sub_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -954,7 +968,7 @@ kernel void kernel_sub( } } -kernel void kernel_mul( +kernel void kernel_mul_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -974,13 +988,20 @@ kernel void kernel_mul( device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + if (args.ne10 == 1) { + const float x = *((device float *)(src1_ptr)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + } } } -kernel void kernel_div( +kernel void kernel_div_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -1000,9 +1021,16 @@ kernel void kernel_div( device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + if (args.ne10 == 1) { + const float x = 1.0f / *((device float *)(src1_ptr)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + } } } @@ -1073,23 +1101,17 @@ kernel void kernel_add_row_c4_fuse_impl( device const char * src1, device char * dst, uint tpig[[thread_position_in_grid]]) { - const uint nb = args.ne00/4; const uint i = tpig % nb; device const float4 * src0_row = (device const float4 *) (src0); device float4 * dst_row = (device float4 *) (dst); - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - float4 res = src0_row[tpig]; #pragma unroll(F) for (short j = 0; j < F; ++j) { - res += src1_row[j][i]; + res += ((device const float4 *) (src1 + args.o1[j]))[i]; } dst_row[tpig] = res; @@ -1097,7 +1119,7 @@ kernel void kernel_add_row_c4_fuse_impl( typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; -template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; +template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; @@ -1137,7 +1159,7 @@ kernel void kernel_sub_row_c4_fuse_impl( typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; -template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; +template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; template kernel void kernel_mul_row_c4_fuse_impl( @@ -1170,7 +1192,7 @@ kernel void kernel_mul_row_c4_fuse_impl( typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; -template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; +template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; template kernel void kernel_div_row_c4_fuse_impl( @@ -1203,55 +1225,80 @@ kernel void kernel_div_row_c4_fuse_impl( typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; -template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; +template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; -kernel void kernel_scale( +kernel void kernel_scale_f32( + constant ggml_metal_kargs_scale & args, device const float * src0, device float * dst, - constant float & scale, - constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale + bias; + dst[tpig] = src0[tpig] * args.scale + args.bias; } -kernel void kernel_scale_4( +kernel void kernel_scale_f32_4( + constant ggml_metal_kargs_scale & args, device const float4 * src0, device float4 * dst, - constant float & scale, - constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale + bias; + dst[tpig] = src0[tpig] * args.scale + args.bias; } -kernel void kernel_clamp( +kernel void kernel_clamp_f32( + constant ggml_metal_kargs_clamp & args, device const float * src0, device float * dst, - constant float & min, - constant float & max, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); + dst[tpig] = clamp(src0[tpig], args.min, args.max); } -kernel void kernel_relu( +kernel void kernel_clamp_f32_4( + constant ggml_metal_kargs_clamp & args, + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = clamp(src0[tpig], args.min, args.max); +} + +kernel void kernel_relu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = max(0.0f, src0[tpig]); } -kernel void kernel_sigmoid( +kernel void kernel_relu_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_sigmoid_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); } -kernel void kernel_tanh( +kernel void kernel_sigmoid_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + +kernel void kernel_tanh_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = precise::tanh(x); + dst[tpig] = precise::tanh(src0[tpig]); +} + +kernel void kernel_tanh_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = precise::tanh(src0[tpig]); } constant float GELU_COEF_A = 0.044715f; @@ -1259,7 +1306,7 @@ constant float GELU_QUICK_COEF = -1.702f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; constant float SQRT_2_INV = 0.70710678118654752440084436210484f; -kernel void kernel_gelu( +kernel void kernel_gelu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -1268,7 +1315,7 @@ kernel void kernel_gelu( dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } -kernel void kernel_gelu_4( +kernel void kernel_gelu_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -1281,7 +1328,7 @@ kernel void kernel_gelu_4( dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } -kernel void kernel_gelu_quick( +kernel void kernel_gelu_quick_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -1290,7 +1337,7 @@ kernel void kernel_gelu_quick( dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); } -kernel void kernel_gelu_quick_4( +kernel void kernel_gelu_quick_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -1317,7 +1364,7 @@ T erf_approx(T x) { return sign_x * y; } -kernel void kernel_gelu_erf( +kernel void kernel_gelu_erf_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -1326,7 +1373,7 @@ kernel void kernel_gelu_erf( dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); } -kernel void kernel_gelu_erf_4( +kernel void kernel_gelu_erf_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -1335,7 +1382,7 @@ kernel void kernel_gelu_erf_4( dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); } -kernel void kernel_silu( +kernel void kernel_silu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -1343,7 +1390,7 @@ kernel void kernel_silu( dst[tpig] = x / (1.0f + exp(-x)); } -kernel void kernel_silu_4( +kernel void kernel_silu_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -1351,99 +1398,202 @@ kernel void kernel_silu_4( dst[tpig] = x / (1.0f + exp(-x)); } -kernel void kernel_elu( +kernel void kernel_elu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + const float x = src0[tpig]; dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); } -kernel void kernel_sqr( +kernel void kernel_elu_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; + dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); + dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); + dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); + dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); +} + +kernel void kernel_sqr_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * src0[tpig]; } -kernel void kernel_sqrt( +kernel void kernel_sqr_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + +kernel void kernel_sqrt_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = sqrt(src0[tpig]); } -kernel void kernel_sin( +kernel void kernel_sqrt_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sqrt(src0[tpig]); +} + +kernel void kernel_sin_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = sin(src0[tpig]); } -kernel void kernel_cos( +kernel void kernel_sin_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sin(src0[tpig]); +} + +kernel void kernel_cos_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = cos(src0[tpig]); } -kernel void kernel_neg( +kernel void kernel_cos_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = cos(src0[tpig]); +} + +kernel void kernel_log_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = log(src0[tpig]); +} + +kernel void kernel_log_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = log(src0[tpig]); +} + +kernel void kernel_neg_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = -src0[tpig]; } -kernel void kernel_abs( +kernel void kernel_neg_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = -src0[tpig]; +} + +kernel void kernel_abs_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = fabs(src0[tpig]); } -kernel void kernel_sgn( - device const float * src0, - device float * dst, +kernel void kernel_abs_f32_4( + device const float4 * src0, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f); + dst[tpig] = fabs(src0[tpig]); } -kernel void kernel_step( +kernel void kernel_sgn_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f; + dst[tpig] = sign(src0[tpig]); } -kernel void kernel_hardswish( +kernel void kernel_sgn_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sign(src0[tpig]); +} + +kernel void kernel_step_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + dst[tpig] = step(0.0f, src0[tpig]); +} + +kernel void kernel_step_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = step(0.0f, src0[tpig]); +} + +kernel void kernel_hardswish_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + const float x = src0[tpig]; dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); } -kernel void kernel_hardsigmoid( +kernel void kernel_hardswish_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; + dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_hardsigmoid_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + const float x = src0[tpig]; dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); } -kernel void kernel_exp( +kernel void kernel_hardsigmoid_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; + dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_exp_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = exp(src0[tpig]); } -kernel void kernel_reglu( +kernel void kernel_exp_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]); +} + +kernel void kernel_reglu_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1459,11 +1609,11 @@ kernel void kernel_reglu( } } -kernel void kernel_geglu( +kernel void kernel_geglu_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1481,11 +1631,11 @@ kernel void kernel_geglu( } } -kernel void kernel_swiglu( +kernel void kernel_swiglu_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1503,11 +1653,11 @@ kernel void kernel_swiglu( } } -kernel void kernel_swiglu_oai( +kernel void kernel_swiglu_oai_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1529,11 +1679,11 @@ kernel void kernel_swiglu_oai( } } -kernel void kernel_geglu_erf( +kernel void kernel_geglu_erf_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1551,11 +1701,11 @@ kernel void kernel_geglu_erf( } } -kernel void kernel_geglu_quick( +kernel void kernel_geglu_quick_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1573,6 +1723,24 @@ kernel void kernel_geglu_quick( } } +kernel void kernel_op_sum_f32( + constant ggml_metal_kargs_sum & args, + device const float * src0, + device float * dst, + ushort tiitg[[thread_index_in_threadgroup]]) { + + if (tiitg != 0) { + return; + } + + float acc = 0.0f; + for (ulong i = 0; i < args.np; ++i) { + acc += src0[i]; + } + + dst[0] = acc; +} + template kernel void kernel_sum_rows( constant ggml_metal_kargs_sum_rows & args, @@ -1625,16 +1793,16 @@ kernel void kernel_sum_rows( typedef decltype(kernel_sum_rows) kernel_sum_rows_t; -template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows; -template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; template kernel void kernel_soft_max( + constant ggml_metal_kargs_soft_max & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -1736,11 +1904,11 @@ kernel void kernel_soft_max( template kernel void kernel_soft_max_4( + constant ggml_metal_kargs_soft_max & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -1850,53 +2018,12 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; -kernel void kernel_diag_mask_inf( - device const float * src0, - device float * dst, - constant ggml_metal_kargs_diag_mask_inf & args, - uint3 tpig[[thread_position_in_grid]]) { - const int64_t i02 = tpig[2]; - const int64_t i01 = tpig[1]; - const int64_t i00 = tpig[0]; - - if (i00 > args.n_past + i01) { - dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY; - } else { - dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - device const float4 * src0, - device float4 * dst, - constant ggml_metal_kargs_diag_mask_inf & args, - uint3 tpig[[thread_position_in_grid]]) { - - const int64_t i = 2*tpig[0]; - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int64_t i4 = 4*i; - const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01; - const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00; - const int64_t i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= args.n_past + i01) { - break; - } - dst[i+1][k] = -INFINITY; - if (i00 + k > args.n_past + i01) { - dst[i][k] = -INFINITY; - } - } -} - // ref: ggml.c:ggml_compute_forward_ssm_conv_f32 -kernel void kernel_ssm_conv_f32( +kernel void kernel_ssm_conv_f32_f32( + constant ggml_metal_kargs_ssm_conv & args, device const void * src0, device const void * src1, device float * dst, - constant ggml_metal_kargs_ssm_conv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -1923,123 +2050,40 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part -kernel void kernel_ssm_scan_f32( - device const void * src0, - device const void * src1, - device const void * src2, - device const void * src3, - device const void * src4, - device const void * src5, - device const void * src6, - device float * dst, - threadgroup float * shared [[threadgroup(0)]], - constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { +kernel void kernel_ssm_conv_f32_f32_4( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; - const int64_t i0 = tpitg.x; - const int64_t i1 = 0; - const int64_t ir = tgpig.x; // current head - const int64_t i3 = tgpig.y; // current seq + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); + device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; + float sumf = 0.0f; - const int64_t s_off = args.s_off; - - device const int32_t * ids = (device const int32_t *) src6; - - device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - float s0 = s0_buff[i]; - float s = s_buff[i]; - - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); - - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} - - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; - - const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - s = state; - - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - if (sgptg > 1) { - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. - threadgroup_barrier(mem_flags::mem_threadgroup); - - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); - if (tiisg == 0) { - y[0] = sumf; - } - } - } else if (tiisg == 0) { - y[0] = sumf; - } - - // recurse - s0 = s; + for (int64_t i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); } - // Assign the final state to the output buffer - s_buff[i] = s; + x[0] = sumf; } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -kernel void kernel_ssm_scan_f32_group( +kernel void kernel_ssm_scan_f32( + constant ggml_metal_kargs_ssm_scan & args, device const void * src0, device const void * src1, device const void * src2, @@ -2049,103 +2093,88 @@ kernel void kernel_ssm_scan_f32_group( device const void * src6, device float * dst, threadgroup float * shared [[threadgroup(0)]], - constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + constexpr short NW = N_SIMDWIDTH; - const int64_t i0 = tpitg.x; - const int64_t i1 = tgpig.x; - const int64_t ir = tgpig.y; // current head - const int64_t i3 = tgpig.z; // current seq + shared[tpitg.x] = 0.0f; - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); + const int32_t i0 = tpitg.x; + const int32_t i1 = tgpig.x; + const int32_t ir = tgpig.y; // current head + const int32_t i3 = tgpig.z; // current seq - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; + const int32_t nc = args.d_state; + const int32_t nr = args.d_inner; + const int32_t nh = args.n_head; + const int32_t ng = args.n_group; + const int32_t n_t = args.n_seq_tokens; - const int64_t s_off = args.s_off; + const int32_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; + + const int32_t i = i0 + i1*nc; + const int32_t g = ir / (nh / ng); // repeat_interleave + float s0 = s0_buff[i]; - float s = s_buff[i]; + float s = 0.0f; - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh} - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} + const float A0 = A[i0%args.ne30]; - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; - const float dA = exp(dt_soft_plus * A[0]); + device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns} + device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns} - const float state = (s0 * dA) + (B[i0] * x_dt); - s = state; + device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns} - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. + for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float dt0 = dt[0]; + const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; + const float x_dt = x[0] * dtsp; + const float dA = exp(dtsp * A0); + + s = (s0 * dA) + (B[i0] * x_dt); + + const float sumf = simd_sum(s * C[i0]); + if (tiisg == 0) { - y[0] = sumf; + shared[t*NW + sgitg] = sumf; } + + // recurse + s0 = s; + + x += args.ns12; + dt += args.ns21; + B += args.ns42; + C += args.ns52; } - // recurse - s0 = s; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + + if (tiisg == 0 && i2 + sgitg < n_t) { + y[sgitg*nh*nr] = sumf; + } + + y += sgptg*nh*nr; } - // Assign the final state to the output buffer s_buff[i] = s; } @@ -2327,24 +2356,22 @@ kernel void kernel_rwkv_wkv7_f32( } } -kernel void kernel_argmax( - device const void * x, - device int32_t * dst, - constant int64_t & ncols, - constant uint64_t & nb01, - threadgroup float * shared_maxval [[threadgroup(0)]], - threadgroup int32_t * shared_argmax [[threadgroup(1)]], +kernel void kernel_argmax_f32( + constant ggml_metal_kargs_argmax & args, + device const char * src0, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01); + device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01); float lmax = -INFINITY; int32_t larg = -1; - for (int i00 = tpitg; i00 < ncols; i00 += ntg) { + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { if (x_row[i00] > lmax) { lmax = x_row[i00]; larg = i00; @@ -2355,6 +2382,11 @@ kernel void kernel_argmax( float max_val = simd_max(lmax); int32_t arg_val = simd_max(select(-1, larg, lmax == max_val)); + device int32_t * dst_i32 = (device int32_t *) dst; + + threadgroup float * shared_maxval = (threadgroup float *) shmem; + threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH; + if (ntg > N_SIMDWIDTH) { if (sgitg == 0) { shared_maxval[tiisg] = -INFINITY; @@ -2376,38 +2408,51 @@ kernel void kernel_argmax( float max_val_reduced = simd_max(max_val); int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced)); - dst[tgpig] = arg_val_reduced; + dst_i32[tgpig] = arg_val_reduced; return; } - dst[tgpig] = arg_val; + dst_i32[tgpig] = arg_val; } -kernel void kernel_norm( +// F == 1 : norm (no fuse) +// F == 2 : norm + mul +// F == 3 : norm + mul + add +template +kernel void kernel_norm_fuse_impl( constant ggml_metal_kargs_norm & args, device const char * src0, + device const char * src1_0, + device const char * src1_1, device char * dst, threadgroup float * shmem_f32 [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - ushort tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { if (sgitg == 0) { shmem_f32[tiisg] = 0.0f; } - device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + const int i01 = tgpig.x; + const int i02 = tgpig.y; + const int i03 = tgpig.z; - float4 sumf4(0.0f); + device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); + + device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); + device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); + + T sumft(0.0f); float sumf = 0.0f; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { - sumf4 += x[i00]; + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { + sumft += x[i00]; } - sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; + sumf = dot(sumft, T(1.0f)); sumf = simd_sum(sumf); threadgroup_barrier(mem_flags::mem_threadgroup); @@ -2423,10 +2468,10 @@ kernel void kernel_norm( const float mean = sumf/args.ne00; - device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); sumf = 0.0f; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { y[i00] = x[i00] - mean; sumf += dot(y[i00], y[i00]); } @@ -2446,17 +2491,35 @@ kernel void kernel_norm( const float variance = sumf/args.ne00; const float scale = 1.0f/sqrt(variance + args.eps); - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { - y[i00] = y[i00] * scale; + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { + if (F == 1) { + y[i00] = (y[i00]*scale); + } + if (F == 2) { + y[i00] = (y[i00]*scale)*f0[i00]; + } + if (F == 3) { + y[i00] = (y[i00]*scale)*f0[i00] + f1[i00]; + } } } +typedef decltype(kernel_norm_fuse_impl) kernel_norm_fuse_t; + +template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; + +template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; +template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl; + // F == 1 : rms_norm (no fuse) // F == 2 : rms_norm + mul // F == 3 : rms_norm + mul + add -template +template kernel void kernel_rms_norm_fuse_impl( - constant ggml_metal_kargs_rms_norm & args, + constant ggml_metal_kargs_norm & args, device const char * src0, device const char * src1_0, device const char * src1_1, @@ -2475,15 +2538,15 @@ kernel void kernel_rms_norm_fuse_impl( const int i02 = tgpig.y; const int i03 = tgpig.z; - device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); + device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); - device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); - device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); + device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); + device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); float sumf = 0.0f; // parallel sum - for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { sumf += dot(x[i00], x[i00]); } sumf = simd_sum(sumf); @@ -2502,8 +2565,8 @@ kernel void kernel_rms_norm_fuse_impl( const float mean = sumf/args.ne00; const float scale = 1.0f/sqrt(mean + args.eps); - device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); - for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); + for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { if (F == 1) { y[i00] = (x[i00]*scale); } @@ -2516,13 +2579,17 @@ kernel void kernel_rms_norm_fuse_impl( } } -typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t; +typedef decltype(kernel_rms_norm_fuse_impl) kernel_rms_norm_fuse_t; -template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>; -template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>; -template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>; +template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; -kernel void kernel_l2_norm( +template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; +template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; + +kernel void kernel_l2_norm_f32( constant ggml_metal_kargs_l2_norm & args, device const char * src0, device char * dst, @@ -2565,10 +2632,10 @@ kernel void kernel_l2_norm( } } -kernel void kernel_group_norm( +kernel void kernel_group_norm_f32( + constant ggml_metal_kargs_group_norm & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_group_norm & args, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -2576,7 +2643,7 @@ kernel void kernel_group_norm( uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { const int64_t ne = args.ne00*args.ne01*args.ne02; - const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups); + const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp); int start = tgpig * gs; int end = start + gs; @@ -2734,7 +2801,52 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } -template +template +static inline void helper_mv_reduce_and_write( + device float * dst_f32, + float sumf[NR0], + const int r0, + const int ne01, + ushort tiisg, + ushort sgitg, + threadgroup char * shmem) { + constexpr short NW = N_SIMDWIDTH; + + threadgroup float * shmem_f32[NR0]; + + for (short row = 0; row < NR0; ++row) { + shmem_f32[row] = (threadgroup float *) shmem + NW*row; + + if (sgitg == 0) { + shmem_f32[row][tiisg] = 0.0f; + } + + sumf[row] = simd_sum(sumf[row]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0; ++row) { + if (tiisg == 0) { + shmem_f32[row][sgitg] = sumf[row]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0 && r0 + row < ne01; ++row) { + float tot = simd_sum(shmem_f32[row][tiisg]); + + if (tiisg == 0 && sgitg == 0) { + dst_f32[r0 + row] = tot; + } + } +} + +constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; +constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; + +template void mul_vec_q_n_f32_impl( args_t args, device const char * src0, @@ -2744,45 +2856,54 @@ void mul_vec_q_n_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NQ = 16; + const int nb = args.ne00/QK4_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr0; + const int r0 = (tgpig.x*NSG + sgitg)*NR0; + //const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q_type * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q_type * ax[NR0]; + FOR_UNROLL (int row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } + float sumf[NR0] = {0.f}; + + const short ix = (tiisg/(NW/NQ)); + const short il = (tiisg%(NW/NQ))*8; + + //const int ib0 = sgitg*NQ + ix; + const int ib0 = ix; + float yl[16]; // src1 vector cache - float sumf[nr0] = {0.f}; - const short ix = (tiisg/2); - const short il = (tiisg%2)*8; - - device const float * yb = y + ix*QK4_0 + il; + //device const float * yb = y + ix*QK4_0 + il; + device const float * yb = y + ib0*QK4_0 + il; // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { + //for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (int ib = ib0; ib < nb; ib += NQ) { float sumy[2] = { 0.f, 0.f }; -#pragma unroll - for (short i = 0; i < 8; i += 2) { + FOR_UNROLL (short i = 0; i < 8; i += 2) { sumy[0] += yb[i + 0] + yb[i + 1]; yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; @@ -2792,21 +2913,23 @@ void mul_vec_q_n_f32_impl( yl[i + 9] = yb[i + 17]/4096.f; } -#pragma unroll - for (short row = 0; row < nr0; row++) { + FOR_UNROLL (short row = 0; row < NR0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } yb += QK4_0 * 16; + //yb += NSG*NQ*QK4_0; } device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - for (int row = 0; row < nr0; ++row) { + //helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); + + for (int row = 0; row < NR0; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; + if (tiisg == 0 && r0 + row < args.ne01) { + dst_f32[r0 + row] = tot; } } } @@ -2816,10 +2939,11 @@ kernel void kernel_mul_mv_q4_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -2827,10 +2951,11 @@ kernel void kernel_mul_mv_q4_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -2838,10 +2963,11 @@ kernel void kernel_mul_mv_q5_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -2849,15 +2975,14 @@ kernel void kernel_mul_mv_q5_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -#define NB_Q8_0 8 - -template +template void kernel_mul_mv_q8_0_f32_impl( args_t args, device const char * src0, @@ -2867,66 +2992,68 @@ void kernel_mul_mv_q8_0_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NQ = 8; + const int nb = args.ne00/QK8_0; - const int r0 = tgpig.x; + const int r0 = tgpig.x*NR0; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; - const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q8_0 * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q8_0 * ax[NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } - float yl[NB_Q8_0]; - float sumf[nr0] = { 0.f }; + float sumf[NR0] = { 0.f }; - const short ix = tiisg/4; - const short il = tiisg%4; + const short ix = tiisg/(NW/NQ); + const short il = tiisg%(NW/NQ); - device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; + const int ib0 = sgitg*NQ + ix; - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (short i = 0; i < NB_Q8_0; ++i) { + float yl[NQ]; + + device const float * yb = y + ib0*QK8_0 + il*NQ; + + // each thread in a SIMD group deals with NQ quants at a time + for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (short i = 0; i < NQ; ++i) { yl[i] = yb[i]; } - for (short row = 0; row < nr0; row++) { - device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; + for (short row = 0; row < NR0; row++) { + device const int8_t * qs = ax[row][ib].qs + il*NQ; + float sumq = 0.f; - for (short iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; + FOR_UNROLL (short i = 0; i < NQ; ++i) { + sumq += qs[i] * yl[i]; } + sumf[row] += sumq*ax[row][ib].d; } - yb += nw*NB_Q8_0; + yb += NSG*NQ*QK8_0; } device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0; ++row) { - const float tot = simd_sum(sumf[row]); - - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; - } - } + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } [[host_name("kernel_mul_mv_q8_0_f32")]] @@ -2935,15 +3062,16 @@ kernel void kernel_mul_mv_q8_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } // mat-vec kernel processing in chunks of float4 // chpb - chunks per quantization block -template +template void kernel_mul_mv_ext_q4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, @@ -2952,6 +3080,9 @@ void kernel_mul_mv_ext_q4_f32_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short NSG = FC_mul_mv_nsg; + const short nxpsg = FC_mul_mv_nxpsg; + const short chpt = 4; // chunks per thread //const short nxpsg = (32); @@ -2960,7 +3091,7 @@ void kernel_mul_mv_ext_q4_f32_impl( const short tx = tiisg%nxpsg; const short ty = tiisg/nxpsg; - const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty; const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; @@ -3001,7 +3132,6 @@ void kernel_mul_mv_ext_q4_f32_impl( #pragma unroll(r1ptg) for (short ir1 = 0; ir1 < r1ptg; ++ir1) { sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]); - } } @@ -3044,7 +3174,7 @@ void kernel_mul_mv_ext_q4_f32_impl( } // mat-vec kernel processing in chunks of float4x4 -template +template void kernel_mul_mv_ext_q4x4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, @@ -3053,6 +3183,9 @@ void kernel_mul_mv_ext_q4x4_f32_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short NSG = FC_mul_mv_nsg; + const short nxpsg = FC_mul_mv_nxpsg; + const short chpt = 1; //const short nxpsg = (32); @@ -3061,7 +3194,7 @@ void kernel_mul_mv_ext_q4x4_f32_impl( const short tx = tiisg%nxpsg; const short ty = tiisg/nxpsg; - const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty; const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; @@ -3158,12 +3291,7 @@ kernel void kernel_mul_mv_ext_q4_f32_disp( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - } + kernel_mul_mv_ext_q4_f32_impl(args, src0, src1, dst, tgpig, tiisg, sgitg); } template @@ -3175,17 +3303,17 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - } + kernel_mul_mv_ext_q4x4_f32_impl(args, src0, src1, dst, tgpig, tiisg, sgitg); } typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t; typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>; + template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>; template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>; template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; @@ -3241,106 +3369,253 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; -#define N_MV_T_T 4 - -template -void kernel_mul_mv_impl( +template +void kernel_mul_mv_t_t_impl( args_t args, device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem, uint3 tgpig, - ushort tiisg) { - const int r0 = tgpig.x; - const int rb = tgpig.y*N_MV_T_T; + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NB = 32; + constexpr short NF = 8; + + const int nb = args.ne00/NB; + + const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const T0 * x = (device const T0 *) (src0 + offset0); + //device const T0 * x = (device const T0 *) (src0 + offset0); + device const T1 * y = (device const T1 *) (src1 + offset1); - device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + // pointers to src0 rows + device const T0 * ax [NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - if (args.ne00 < 128) { - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; - } + ax[row] = (device const T0 *) ((device char *) src0 + offset0); + } - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + float sumf[NR0] = { 0.f }; - device const T1 * y = (device const T1 *) (src1 + offset1); + const short ix = tiisg/(NW/NF); + const short il = tiisg%(NW/NF); - float sumf = 0; - for (int i = tiisg; i < args.ne00; i += 32) { - sumf += (T0) x[i] * (T1) y[i]; - } + const int ib0 = sgitg*NF + ix; - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; - } + T1 yl[NF]; + + device const T1 * yb = y + (ib0*NB + il*NF); + + for (int ib = ib0; ib < nb; ib += NSG*NF) { + for (short i = 0; i < NF; ++i) { + yl[i] = yb[i]; } - } else { - device const T04 * x4 = (device const T04 *) x; - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; + + for (short row = 0; row < NR0; row++) { + device const T0 * xb = ax[row] + (ib*NB + il*NF); + + float sumq = 0.f; + FOR_UNROLL (short i = 0; i < NF; ++i) { + sumq += xb[i] * yl[i]; } - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - - device const T1 * y = (device const T1 *) (src1 + offset1); - device const T14 * y4 = (device const T14 *) y; - - float sumf = 0; - for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((float4) x4[i], (float4) y4[i]); - } - - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); - dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; - } + sumf[row] += sumq; } + + yb += NSG*NF*NW; + } + + for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) { + for (short row = 0; row < NR0; row++) { + sumf[row] += ax[row][i] * y[i]; + } + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); +} + +template +void kernel_mul_mv_t_t_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; } } -template -kernel void kernel_mul_mv( +template +kernel void kernel_mul_mv_t_t( constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_impl( - args, - src0, - src1, - dst, - tgpig, - tiisg); + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_t_t_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv) mul_mv_t; +typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; -template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; #endif -template -void kernel_mul_mv_c4_impl( +template +void kernel_mul_mv_t_t_4_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NB = 32; + constexpr short NF = 16; + constexpr short NF4 = NF/4; + + const int nb = args.ne00/NB; + + const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); + device const T14 * y4 = (device const T14 *) (src1 + offset1); + + // pointers to src0 rows + device const T0 * ax [NR0]; + device const T04 * ax4[NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax [row] = (device const T0 *) ((device char *) src0 + offset0); + ax4[row] = (device const T04 *) ((device char *) src0 + offset0); + } + + float sumf[NR0] = { 0.f }; + + const short ix = tiisg/(NW/NF); + const short il = tiisg%(NW/NF); + + const int ib0 = sgitg*NF + ix; + + T14 yl4[NF4]; + + device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4; + + for (int ib = ib0; ib < nb; ib += NSG*NF) { + for (short i = 0; i < NF4; ++i) { + yl4[i] = yb4[i]; + } + + for (short row = 0; row < NR0; row++) { + device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4; + + float sumq = 0.f; + FOR_UNROLL (short i = 0; i < NF4; ++i) { + sumq += dot(float4(xb4[i]), float4(yl4[i])); + } + + sumf[row] += sumq; + } + + yb4 += NSG*NF*NW/4; + } + + for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) { + for (short row = 0; row < NR0; row++) { + sumf[row] += ax[row][i] * y[i]; + } + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); +} + +template +void kernel_mul_mv_t_t_4_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + }; +} + +template +kernel void kernel_mul_mv_t_t_4( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_t_t_4_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; + +template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +#endif + +template +void kernel_mul_mv_t_t_short_impl( args_t args, device const char * src0, device const char * src1, @@ -3348,7 +3623,7 @@ void kernel_mul_mv_c4_impl( uint3 tgpig, ushort tiisg) { const int r0 = tgpig.x*32 + tiisg; - const int rb = tgpig.y*N_MV_T_T; + const int r1 = tgpig.y; const int im = tgpig.z; if (r0 >= args.ne01) { @@ -3360,33 +3635,32 @@ void kernel_mul_mv_c4_impl( const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - device const T04 * x = (device const T04 *) (src0 + offset0); + device const T0 * x = (device const T0 *) (src0 + offset0); device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; - } + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + device const T1 * y = (device const T1 *) (src1 + offset1); - device const T14 * y = (device const T14 *) (src1 + offset1); + float res = 0.0f; - dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]); + for (int i = 0; i < args.ne00; ++i) { + res += (float) x[i] * (float) y[i]; } + + dst_f32[(uint64_t)r1*args.ne0 + r0] = res; } -template -kernel void kernel_mul_mv_c4( +template +kernel void kernel_mul_mv_t_t_short( constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_c4_impl( + kernel_mul_mv_t_t_short_impl( args, src0, src1, @@ -3395,116 +3669,14 @@ kernel void kernel_mul_mv_c4( tiisg); } -typedef decltype(kernel_mul_mv_c4) mul_mv_c4_t; +typedef decltype(kernel_mul_mv_t_t_short) mul_mv_t_t_short_t; -template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -#endif - -template -kernel void kernel_mul_mv_1row( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]]) { - - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; - - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - - device const T * x = (device const T *) (src0 + offset0); - device const float * y = (device const float *) (src1 + offset1); - - device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - - float sumf = 0; - if (args.ne00 < 128) { - for (int i = tiisg; i < args.ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[r0] = sum_all; - } - } else { - device const T4 * x4 = (device const T4 *) x; - device const float4 * y4 = (device const float4 *) y; - - for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((float4) x4[i], y4[i]); - } - - float sum_all = simd_sum(sumf); - - if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); - dst_f32[r0] = sum_all; - } - } -} - -typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; - -template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; -#endif - -// Assumes row size (ne00) is a multiple of 4 -template -kernel void kernel_mul_mv_l4( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]]) { - - const int nrows = args.ne11; - const int r0 = tgpig.x; - const int im = tgpig.z; - - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; - - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - - device const T4 * x4 = (device const T4 *) (src0 + offset0); - - device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; - - for (int r1 = 0; r1 < nrows; ++r1) { - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - - device const float4 * y4 = (device const float4 *) (src1 + offset1); - - float sumf = 0; - for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((float4) x4[i], y4[i]); - } - - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; - } - } -} - -typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; - -template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #endif static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -3807,9 +3979,9 @@ template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t ker template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; typedef void (im2col_t)( + constant ggml_metal_kargs_im2col & args, device const float * x, device char * dst, - constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -3817,9 +3989,9 @@ typedef void (im2col_t)( template kernel void kernel_im2col( + constant ggml_metal_kargs_im2col & args, device const float * x, device char * dst, - constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -3828,11 +4000,10 @@ kernel void kernel_im2col( const int64_t OH = tgpg[1]; const int64_t OW = tgpg[2]; -// const int64_t N = ntg[0]; const int64_t KH = ntg[1]; const int64_t KW = ntg[2]; - const int64_t in = tpitg[0]; + int64_t in = tpitg[0]; const int64_t ikh = tpitg[1]; const int64_t ikw = tpitg[2]; @@ -3843,88 +4014,102 @@ kernel void kernel_im2col( const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; - const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); + int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); device T * pdst = (device T *) (dst); if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { - pdst[offset_dst] = 0.0f; + while (in < args.N) { + pdst[offset_dst] = 0.0f; + offset_dst += ntg[0]*args.CHW*OH*OW; + + in += ntg[0]; + } } else { - const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; - pdst[offset_dst] = x[offset_src]; + int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; + + while (in < args.N) { + pdst[offset_dst] = x[offset_src]; + + offset_dst += ntg[0]*args.CHW*OH*OW; + offset_src += ntg[0]*args.ofs0; + + in += ntg[0]; + } } } template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; -typedef void (im2col_ext_t)( - device const float * x, - device char * dst, - constant ggml_metal_kargs_im2col & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]); - -template -kernel void kernel_im2col_ext( - device const float * x, - device char * dst, - constant ggml_metal_kargs_im2col & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] - const int64_t KHW = (int64_t)args.KHW; - - const int64_t d = tgpig[0] / args.CHW; - const int64_t chw = tgpig[0] % args.CHW; - const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) - const int64_t HW = tgpig[0] % KHW; - - const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; - if (tpitg_0 >= args.N) { - return; - } - - const int64_t tpitg_1 = HW / args.KW; - const int64_t tpitg_2 = HW % args.KW; - - const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; - const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; - - const int64_t offset_dst = - (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + - (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); - - device T * pdst = (device T *) (dst); - - if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { - pdst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; - pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; - } -} - -template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; -template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; +// TODO: obolete -- remove +//typedef void (im2col_ext_t)( +// constant ggml_metal_kargs_im2col & args, +// device const float * x, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// uint3 tgpg[[threadgroups_per_grid]], +// uint3 tpitg[[thread_position_in_threadgroup]], +// uint3 ntg[[threads_per_threadgroup]]); +// +//template +//kernel void kernel_im2col_ext( +// constant ggml_metal_kargs_im2col & args, +// device const float * x, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW +// uint3 tpitg[[thread_position_in_threadgroup]], +// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] +// const int64_t KHW = (int64_t)args.KHW; +// +// const int64_t d = tgpig[0] / args.CHW; +// const int64_t chw = tgpig[0] % args.CHW; +// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) +// const int64_t HW = tgpig[0] % KHW; +// +// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; +// if (tpitg_0 >= args.N) { +// return; +// } +// +// const int64_t tpitg_1 = HW / args.KW; +// const int64_t tpitg_2 = HW % args.KW; +// +// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; +// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; +// +// const int64_t offset_dst = +// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + +// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); +// +// device T * pdst = (device T *) (dst); +// +// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { +// pdst[offset_dst] = 0.0f; +// } else { +// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; +// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; +// } +//} +// +//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; typedef void (conv_transpose_1d_t)( + constant ggml_metal_kargs_conv_transpose_1d & args, device const float * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); template kernel void kernel_conv_transpose_1d( + constant ggml_metal_kargs_conv_transpose_1d & args, device const T * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]) { @@ -3948,26 +4133,26 @@ kernel void kernel_conv_transpose_1d( template [[host_name("kernel_conv_transpose_1d_f32_f32")]] kernel void kernel_conv_transpose_1d( + constant ggml_metal_kargs_conv_transpose_1d & args, device const float * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); template [[host_name("kernel_conv_transpose_1d_f16_f32")]] kernel void kernel_conv_transpose_1d( + constant ggml_metal_kargs_conv_transpose_1d & args, device const half * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); kernel void kernel_upscale_f32( + constant ggml_metal_kargs_upscale & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_upscale & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -3991,9 +4176,9 @@ kernel void kernel_upscale_f32( } kernel void kernel_pad_f32( + constant ggml_metal_kargs_pad & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_pad & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -4027,9 +4212,9 @@ kernel void kernel_pad_f32( } kernel void kernel_pad_reflect_1d_f32( + constant ggml_metal_kargs_pad_reflect_1d & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_pad_reflect_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -4060,8 +4245,8 @@ kernel void kernel_pad_reflect_1d_f32( } kernel void kernel_arange_f32( - device char * dst, constant ggml_metal_kargs_arange & args, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -4074,9 +4259,9 @@ kernel void kernel_arange_f32( } kernel void kernel_timestep_embedding_f32( + constant ggml_metal_kargs_timestep_embedding & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_timestep_embedding & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -4094,25 +4279,25 @@ kernel void kernel_timestep_embedding_f32( } if (args.dim % 2 != 0 && tpitg.x == 0) { - embed_data[args.dim] = 0.f; + embed_data[2 * half_] = 0.f; } } // bitonic sort implementation following the CUDA kernels as reference typedef void (argsort_t)( - device const float * x, - device int32_t * dst, constant ggml_metal_kargs_argsort & args, + device const float * x, + device int32_t * dst, threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]); template kernel void kernel_argsort_f32_i32( - device const float * x, - device int32_t * dst, constant ggml_metal_kargs_argsort & args, - threadgroup int32_t * shared_values [[threadgroup(0)]], + device const float * x, + device int32_t * dst, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]) { // bitonic sort @@ -4161,17 +4346,236 @@ kernel void kernel_argsort_f32_i32( } } +typedef void (i32_argsort_t)( + constant ggml_metal_kargs_argsort & args, + device const int32_t * x, + device int32_t * dst, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_i32_i32( + constant ggml_metal_kargs_argsort & args, + device const int32_t * x, + device int32_t * dst, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + // bitonic sort + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= args.ncols_pad) return; + + device const int32_t * x_row = x + row * args.ncols; + threadgroup int32_t * dst_row = shared_values; + + // initialize indices + dst_row[col] = col; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= args.ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= args.ncols || + (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= args.ncols || + (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // copy the result to dst without the padding + if (col < args.ncols) { + dst[row * args.ncols + col] = dst_row[col]; + } +} + template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_i32_i32_asc")]] kernel i32_argsort_t kernel_argsort_i32_i32; +template [[host_name("kernel_argsort_i32_i32_desc")]] kernel i32_argsort_t kernel_argsort_i32_i32; kernel void kernel_leaky_relu_f32( + constant ggml_metal_kargs_leaky_relu & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_leaky_relu & args, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope; + const float x = src0[tpig]; + dst[tpig] = x > 0.0f ? x : x * args.slope; } +kernel void kernel_leaky_relu_f32_4( + constant ggml_metal_kargs_leaky_relu & args, + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; + dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); +} + +constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; + +constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; + +// pad the last chunk of C elements of k and v into a an extra pad buffer +kernel void kernel_flash_attn_ext_pad( + constant ggml_metal_kargs_flash_attn_ext_pad & args, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int32_t C = FC_flash_attn_ext_pad_ncpsg; + + device char * k_pad = dst; + device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3; + device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const int32_t icp = args.ne11 % C; + const int32_t ic0 = args.ne11 - icp; + + const int32_t i1 = tgpig[0]; + const int32_t i2 = tgpig[1]; + const int32_t i3 = tgpig[2]; + + if (i2 < args.ne_12_2 && i3 < args.ne_12_3) { + device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3; + device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3; + + device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3; + device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3; + + if (i1 >= icp) { + // here it is not important the exact value that will be used as we rely on masking out the scores in the attention + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = 0; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = 0; + } + } else { + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = k_src[i]; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = v_src[i]; + } + } + } + + if (FC_flash_attn_ext_pad_has_mask) { + if (i2 < args.ne32 && i3 < args.ne33) { + for (int ib = i1; ib < args.ne31; ib += C) { + device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0; + device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3; + + for (int i = tiitg; i < C; i += ntg.x) { + if (i >= icp) { + mask_dst[i] = -MAXHALF; + } else { + mask_dst[i] = mask_src[i]; + } + } + } + } + } +} + +constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]]; +constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]]; + +// scan the blocks of the mask that are not masked +// 0 - masked (i.e. full of -INF, skip) +// 1 - not masked (i.e. at least one element of the mask is not -INF) +kernel void kernel_flash_attn_ext_blk( + constant ggml_metal_kargs_flash_attn_ext_blk & args, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + // block size C x Q + const int32_t Q = FC_flash_attn_ext_blk_nqptg; + const int32_t C = FC_flash_attn_ext_blk_ncpsg; + + constexpr short NW = N_SIMDWIDTH; + + const int32_t i3 = tgpig[2]/args.ne32; + const int32_t i2 = tgpig[2]%args.ne32; + const int32_t i1 = tgpig[1]; + const int32_t i0 = tgpig[0]; + + char res = i0*C + C > args.ne30 ? 1 : 0; + + device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; + + // fast route + if (res == 0) { + if (simd_max(*mask_src) > -MAXHALF/2) { + res = 1; + } + } + + // detailed check of the elements of the block + if ((C > NW || Q > 1) && res == 0) { + half m = -MAXHALF; + + FOR_UNROLL (short j = 0; j < Q; ++j) { + FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { + m = max(m, mask_src[ii*NW]); + } + + mask_src += args.nb31/2; + } + + if (simd_max(m) > -MAXHALF/2) { + res = 1; + } + } + + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne30 + C - 1)/C); + + if (tiisg == 0) { + dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res; + } +} + +constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; +constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; +constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; +constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; +constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]]; + +constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; + +//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; +//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; +//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]]; + +constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]]; +constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]]; +constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]]; + // ref: https://arxiv.org/pdf/2307.08691.pdf template< typename q_t, // query types in shared memory @@ -4186,6 +4590,7 @@ template< typename qk_t, // Q*K types typename qk8x8_t, typename s_t, // soft-max types + typename s2_t, typename s8x8_t, typename o_t, // attention accumulation types typename o4_t, @@ -4196,59 +4601,107 @@ template< typename vd4x4_t, // value type in device memory short nl_v, void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), - short DK, // K head size - short DV, // V head size - short Q = 8, // queries per threadgroup - short KV = 8, // key/value processed per each simdgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext( + short DK, // K head size + short DV, // V head size + short Q, // queries per threadgroup + short C, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_impl( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, device const char * v, device const char * mask, device const char * sinks, + device const char * pad, + device const char * blk, device char * dst, - threadgroup half * shmem_f16 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups + threadgroup half * shmem_f16, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const ushort iq3 = tgpig[2]; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]*Q; - const int iq3 = tgpig[2]; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]*Q; +#define NS10 (FC_flash_attn_ext_ns10) +#define NS20 (FC_flash_attn_ext_ns20) + + // note: I had some concerns that using this instead of the ugly macros above was affecting performance + // need to re-check carefully and if no regressions are observerd - remove the macros + // the concerns is that maybe using const variables requires extra registers? but not sure if the compiler + // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC + //const short NS10 = FC_flash_attn_ext_ns10; + //const short NS20 = FC_flash_attn_ext_ns20; + + constexpr short KV = 8; constexpr short DK4 = DK/4; constexpr short DK8 = DK/8; constexpr short DK16 = DK/16; constexpr short DV4 = DV/4; - constexpr short DV8 = DV/8; + //constexpr short DV8 = DV/8; constexpr short DV16 = DV/16; + constexpr short PV = PAD2(DV, 64); + constexpr short PV4 = PV/4; + constexpr short PV8 = PV/8; + //constexpr short PV16 = PV/16; + constexpr short NW = N_SIMDWIDTH; - constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + constexpr short NQ = Q/NSG; + constexpr short SH = 2*C; // shared memory per simdgroup (s_t == float) - const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = 2*DK + 2*TS; // shared memory size per query in (half) + constexpr short TS = 2*SH; + constexpr short T = DK + 2*PV; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper) + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK); + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix + threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t - threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t - threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o8x8_t lo[DV8]; + // mask storage in shared mem + threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C); + + // per-query mask pointers + device const half2 * pm2[NQ]; + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + } + + { + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne11 + C - 1)/C); + + blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0; + } + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01); for (short i = tiisg; i < DK4; i += NW) { if (iq1 + j < args.ne01) { @@ -4259,43 +4712,30 @@ kernel void kernel_flash_attn_ext( } } - // zero out lo - for (short i = 0; i < DV8; ++i) { - lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); - } + // zero out + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] = 0; + } - // zero out shared memory SH - for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*TS + i] = 0.0f; + ss[j*SH + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); + float S[NQ] = { [0 ... NQ-1] = 0.0f }; + { - float S[Q] = { [0 ... Q-1] = 0.0f }; - float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 }; - - // thread indices inside the simdgroup - // TODO: see if we can utilize quad-group functions for better performance - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) - const short tx = tiisg%4; - const short ty = tiisg/4; - - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); - - const bool has_mask = mask != q; + float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 }; float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -4306,177 +4746,354 @@ kernel void kernel_flash_attn_ext( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; + for (int ic0 = 0; ; ++ic0) { + int ic = ic0*C; if (ic >= args.ne11) { break; } - if (has_mask) { - // used to detect blocks full of -INF - float smax = -INFINITY; + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; - // load the mask in shared memory - #pragma unroll(Q) - for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); - const float m = pm[ic + tiisg]; + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; - ss[j*TS + C + tiisg] = m; - smax = max(smax, m); + if (!FC_flash_attn_ext_has_mask) { + threadgroup half * sm = (threadgroup half *) (sm2); + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < C; i += NW) { + if (ic + i >= args.ne11) { + sm[2*j*SH + i] = -MAXHALF; + } + } + } + } else { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const half *) mask + + (iq1 + j)*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32)); + } } - smax = simd_max(smax); + ic = 0; + } + + // read the mask into shared mem + if (FC_flash_attn_ext_has_mask) { + if (blk[ic0] == 0) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } - if (smax == -INFINITY) { continue; } + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } + + pm2[jj] += NW; + } + +#if 0 + // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // used to detect blocks full of -INF + // skip only when the entire threadgroup is masked + half2 smax2(-MAXHALF/2, -MAXHALF/2); + + FOR_UNROLL (short j = 0; j < Q; ++j) { + smax2 = max(smax2, sm2[j*SH + tiisg]); + } + + smax2 = simd_max(smax2); + + if (max(smax2[0], smax2[1]) <= -MAXHALF/2) { + // this barrier is important + threadgroup_barrier(mem_flags::mem_threadgroup); + + continue; + } +#endif } // Q*K^T - { - for (short cc = 0; cc < C/8; ++cc) { + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) (k + ic*args.nb11); + threadgroup const q_t * pq = sq; + threadgroup s_t * ps = ss; + + pk += sgitg*(8*NS10); + ps += sgitg*(8*1); + + static_assert((C/8) % NSG == 0, ""); + + constexpr short NC = (C/8)/NSG; + + // note: do not unroll for large heads + #pragma unroll (DK <= 64 ? NC : 1) + for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - // this is compile-time check, so it does not have runtime overhead - if (is_same::value) { - // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + if (DK % 16 != 0) { + k8x8_t mk; + q8x8_t mq; - #pragma unroll(DK8) - for (short i = 0; i < DK8; ++i) { - k8x8_t mk; - simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + FOR_UNROLL (short i = 0; i < DK8; ++i) { + simdgroup_barrier(mem_flags::mem_none); + + simdgroup_load(mk, pk + 8*i, NS10, 0, true); + simdgroup_load(mq, pq + 8*i, DK); + + simdgroup_barrier(mem_flags::mem_none); - q8x8_t mq; - simdgroup_load(mq, sq + i*8, DK); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } else { - for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + k8x8_t mk[2]; + q8x8_t mq[2]; - if (DK16%4 == 0) { - // the head is evenly divisible by 4*16 = 64, so no need for bound checks - { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + simdgroup_barrier(mem_flags::mem_none); - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); + simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - k8x8_t mk; - q8x8_t mq; + simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } - } else { - if (ii + tx < DK16) { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); + simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); + } + } - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_store(mqk, ps, SH, 0, false); - for (short k = 0; k < 4 && ii + k < DK16; ++k) { - k8x8_t mk; - q8x8_t mq; + pk += 8*(NSG*NS10); + ps += 8*(NSG); + } + } else { + // TODO: this is the quantized K cache branch - not optimized yet + for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) { + const short cc = ccc*NSG + sgitg; - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + const short tx = tiisg%4; + const short ty = tiisg/4; - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11)); + + if (DK16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short k = 0; k < 4; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + if (ii + tx < DK16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DK16; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } } - // cast qk_t -> s_t - //s8x8_t mqks(1.0f); - //simdgroup_multiply(mqks, mqk, mqks); - //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); - - simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + simdgroup_store(mqk, ss + 8*cc, SH, 0, false); } } + threadgroup_barrier(mem_flags::mem_threadgroup); + // online softmax - { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - // scale and apply the logitcap / mask - float s = ss[j*TS + tiisg]*args.scale; + const float m = M[jj]; - if (args.logit_softcap != 0.0f) { - s = args.logit_softcap*precise::tanh(s); + // scale and apply the logitcap / mask + float2 s2 = ss2[j*SH/2 + tiisg]*args.scale; + + if (FC_flash_attn_ext_has_scap) { + s2 = args.logit_softcap*precise::tanh(s2); + } + + // mqk = mqk + slope*mask + if (FC_flash_attn_ext_has_bias) { + s2 += s2_t(sm2[j*SH + tiisg])*slope; + } else { + s2 += s2_t(sm2[j*SH + tiisg]); + } + + M[jj] = simd_max(max(M[jj], max(s2[0], s2[1]))); + + const float ms = exp(m - M[jj]); + const float2 vs2 = exp(s2 - M[jj]); + + S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]); + + // the P matrix from the paper (Q rows, C columns) + ss2[j*SH/2 + tiisg] = vs2; + + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; + + so4[j*PV4 + i] *= ms; } - - // mqk = mqk + mask*slope - s += slope*ss[j*TS + C + tiisg]; - - M[j] = simd_max(max(M[j], s)); - - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); - - S[j] = S[j]*ms + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[j*TS + tiisg] = vs; - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; + } else { + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); - } - } + threadgroup_barrier(mem_flags::mem_threadgroup); // O = O + (Q*K^T)*V { - for (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, ss + 8*cc, TS, 0, false); + // we can read directly from global memory + if (is_same::value) { + static_assert(PV8 % NSG == 0, ""); - if (is_same::value) { - // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + constexpr short NO = PV8/NSG; - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - v8x8_t mv; - simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 + o8x8_t lo[NO]; - simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); + { + auto sot = so + 8*sgitg; + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_load(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; } - } else { - for (short ii = 0; ii < DV16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + } + + { + device const v_t * pv = (device const v_t *) (v + ic*args.nb21); + + pv += 8*sgitg; + + if (DV <= 64) { + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[2]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); + + simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + } + + pv += 8*NS20; + } + } else { + FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + s8x8_t vs[2]; + + simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); + simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[4]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); + } + + pv += 2*8*NS20; + } + } + } + + { + auto sot = so + 8*sgitg; + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_store(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; + } + } + } else { + // TODO: this is the quantized V cache branch - not optimized yet + + const short tx = tiisg%4; + const short ty = tiisg/4; + + for (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); + + for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -4488,15 +5105,20 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - v8x8_t mv; + FOR_UNROLL (short k = 0; k < 4; ++k) { + v8x8_t mv[2]; + o8x8_t lo[2]; - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); + + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } else { if (ii + tx < DV16) { @@ -4508,236 +5130,252 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); for (short k = 0; k < 4 && ii + k < DV16; ++k) { - v8x8_t mv; + v8x8_t mv[2]; + o8x8_t lo[2]; - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); + + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } } } } } + + threadgroup_barrier(mem_flags::mem_threadgroup); } - if (sinks != q && sgitg == 0) { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; + if (FC_flash_attn_ext_has_sinks) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + const float m = M[jj]; const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; - M[j] = simd_max(max(M[j], s)); + M[jj] = simd_max(max(M[jj], s)); - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); + const float ms = exp(m - M[jj]); + const float vs = exp(s - M[jj]); - S[j] = S[j]*ms + simd_sum(vs); + S[jj] = S[jj]*ms + simd_sum(vs); - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; - } - } - - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = tiisg; j < Q; j += NW) { - ss[j*TS + 0] = S[j]; - ss[j*TS + 1] = M[j]; - } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation - threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); - - // store result to shared memory in F32 - if (sgitg == 0) { - for (short i = 0; i < DV8; ++i) { - //simdgroup_store(lo[i], so + i*8, DV, 0, false); - simdgroup_float8x8 t(1.0f); - simdgroup_multiply(t, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); + // store to global memory + for (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + if (iq1 + j >= args.ne01) { + break; } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // reduce the warps sequentially - for (ushort sg = 1; sg < nsg; ++sg) { - if (sgitg == sg) { - for (short j = tiisg; j < Q; j += NW) { - const float S0 = ss[j*TS - 1*SH + 0]; - const float S1 = ss[j*TS + 0]; - - const float M0 = ss[j*TS - 1*SH + 1]; - const float M1 = ss[j*TS + 1]; - - const float M = max(M0, M1); - - float ms0 = exp(M0 - M); - float ms1 = exp(M1 - M); - - const float S = S0*ms0 + S1*ms1; - - ss[j*TS + 0] = S; - ss[j*TS + 1] = M; - - ss[j*TS + 2*C + j - 1*SH] = ms0; - ss[j*TS + 2*C + j ] = ms1; - } - - //simdgroup_barrier(mem_flags::mem_threadgroup); - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { - s8x8_t ms0; - s8x8_t ms1; - - simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); - simdgroup_load(ms1, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_float8x8 t; - - simdgroup_load (t, so + i*8, DV, 0, false); - simdgroup_multiply(t, ms0, t); - - simdgroup_multiply_accumulate(t, ms1, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); - - // final rescale with 1/S and store to global memory - for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { - const float S = 1.0f/sf[j*TS + 0]; device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - for (short i = tiisg; i < DV4; i += NW) { - dst4[i] = (float4) so4[j*DV4 + i]*S; + const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj]; + + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; + + dst4[i] = (float4) so4[j*PV4 + i]*scale; + } + } else { + for (short i = tiisg; i < DV4; i += NW) { + dst4[i] = (float4) so4[j*PV4 + i]*scale; + } } } + +#undef NS10 +#undef NS20 +} + +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s2_t, + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // value type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short DK, // K head size + short DV, // V head size + short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device const char * pad, + device const char * blk, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_nsg) { + // note: disabled cases to reduce library load time + //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; + //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; + } +#undef FWD_TMPL +#undef FWD_ARGS } // TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as // template to be able to explore different combinations // #define FA_TYPES \ - float, float4, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ - half, half4, simdgroup_half8x8 - //float, float4, simdgroup_float8x8 + float, float2, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 #define FA_TYPES_BF \ bfloat, bfloat4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ + float, float2, simdgroup_float8x8, \ half, half4, simdgroup_half8x8 //float, float4, simdgroup_float8x8 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif -template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES #undef FA_TYPES_BF +constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]]; +constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; +constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; +constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; +constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]]; + +//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; +//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; +//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]]; + +constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]]; +constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]]; +constant int32_t FC_flash_attn_ext_vec_nsg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]]; +constant int32_t FC_flash_attn_ext_vec_nwg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]]; + template< typename q4_t, // query types in shared memory typename k4_t, // key types in shared memory @@ -4754,60 +5392,89 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext_vec( - constant ggml_metal_kargs_flash_attn_ext & args, + short NE, // head elements per thread + short Q, // queries per threadgroup + short C, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_vec_impl( + constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, device const char * k, device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups + static_assert(DK % 32 == 0, "DK must be divisible by 32"); + static_assert(DV % 32 == 0, "DV must be divisible by 32"); - const int iq3 = tgpig[2]; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]; +#define NWG (FC_flash_attn_ext_vec_nwg) + +#define NS10 (FC_flash_attn_ext_vec_ns10) +#define NS20 (FC_flash_attn_ext_vec_ns20) + + const short iwg = tgpig[2]%NWG; + + const ushort iq3 = tgpig[2]/NWG; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]; constexpr short DK4 = DK/4; constexpr short DV4 = DV/4; + + constexpr short PK = PAD2(DK, 128); + constexpr short PK4 = PK/4; + + constexpr short PV = PAD2(DV, 128); + constexpr short PV4 = PV/4; + constexpr short NW = N_SIMDWIDTH; constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads constexpr short SH = 4*C; // shared memory per simdgroup - const short T = DK + nsg*SH; // shared memory size per query in (half) + static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); + static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results + const short T = PK + NSG*SH; // shared memory size per query in (half) - // store the result for all queries in local memory (the O matrix from the paper) - o4_t lo[DV4/NL]; + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results + + // store the result for all queries in shared memory (the O matrix from the paper) + so4 += tiisg; + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q); - for (short i = tiisg; i < DK4; i += NW) { - if (iq1 < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (iq1 < args.ne01 && i < DK4) { sq4[i] = (q4_t) q4[i]; } else { sq4[i] = (q4_t) 0.0f; } } - // zero out lo + // zero out so for (short i = 0; i < DV4/NL; ++i) { - lo[i] = (o4_t) 0.0f; + so4[i*NL] = (o4_t) 0.0f; } // zero out shared memory SH @@ -4819,28 +5486,19 @@ kernel void kernel_flash_attn_ext_vec( { float S = 0.0f; - float M = -__FLT_MAX__/2; + float M = -FLT_MAX/2; // thread indices inside the simdgroup const short tx = tiisg%NL; const short ty = tiisg/NL; - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); - - const bool has_mask = mask != q; - // pointer to the mask device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_vec_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -4851,13 +5509,39 @@ kernel void kernel_flash_attn_ext_vec( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; + for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) { + int ic = ic0*C; if (ic >= args.ne11) { break; } - if (has_mask) { + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_vec_has_mask) { + if (ic + tiisg >= args.ne11) { + sm[tiisg] = -MAXHALF; + } + } else { + pm = (device const half *) (mask) + + iq1*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32); + } + + ic = 0; + } + + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -4868,70 +5552,82 @@ kernel void kernel_flash_attn_ext_vec( // Q*K^T { - // each simdgroup processes 1 query and NE (NW/NL) head elements - for (short cc = 0; cc < C/NE; ++cc) { - qk_t mqk = 0.0f; + device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11); + threadgroup const q4_t * pq4 = sq4; - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + pk4 += ty*NS10/4 + tx; + pq4 += tx; - #pragma unroll(DK4/NL) - for (short ii = 0; ii < DK4; ii += NL) { - const short i = ii + tx; + qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f }; + + // each simdgroup processes 1 query and NE (NW/NL) cache elements + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + if (is_same::value) { + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); + } + } else { + device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; - deq_k_t4(pk + i/nl_k, i%nl_k, mk); - // note: this is less precise than the version below - //mqka[0] += dot(mq[0], mk[0]); - //mqka[1] += dot(mq[1], mk[1]); - //mqka[2] += dot(mq[2], mk[2]); - //mqka[3] += dot(mq[3], mk[3]); + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + const short i = ii*NL + tx; - //q4x4_t mq = sq4x4[i]; - //mqka[0] += dot((float4) mq[0], (float4) mk[0]); - //mqka[1] += dot((float4) mq[1], (float4) mk[1]); - //mqka[2] += dot((float4) mq[2], (float4) mk[2]); - //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + deq_k_t4(pk + i/nl_k, i%nl_k, mk); - mqk += dot((float4) mk, (float4) sq4[i]); + mqk[cc] += dot((float4) mk, (float4) sq4[i]); + } } - static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails - - // simdgroup reduce (NE = 4) - // [ 0 .. 7] -> [ 0] - // [ 8 .. 15] -> [ 8] - // [16 .. 23] -> [16] - // [24 .. 31] -> [24] - if (NE <= 1) { - mqk += simd_shuffle_down(mqk, 16); - } - if (NE <= 2) { - mqk += simd_shuffle_down(mqk, 8); - } - if (NE <= 4) { - mqk += simd_shuffle_down(mqk, 4); - } - if (NE <= 8) { - mqk += simd_shuffle_down(mqk, 2); - } - if (NE <= 16) { - mqk += simd_shuffle_down(mqk, 1); - } - - // mqk = mqk*scale + mask*slope - if (tx == 0) { - mqk *= args.scale; - - if (args.logit_softcap != 0.0f) { - mqk = args.logit_softcap*precise::tanh(mqk); + if (NE == 1) { + mqk[cc] = simd_sum(mqk[cc]); + } else { + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + if (NE <= 1) { + mqk[cc] += simd_shuffle_down(mqk[cc], 16); + } + if (NE <= 2) { + mqk[cc] += simd_shuffle_down(mqk[cc], 8); + } + if (NE <= 4) { + mqk[cc] += simd_shuffle_down(mqk[cc], 4); + } + if (NE <= 8) { + mqk[cc] += simd_shuffle_down(mqk[cc], 2); + } + if (NE <= 16) { + mqk[cc] += simd_shuffle_down(mqk[cc], 1); } - mqk += sm[NE*cc + ty]*slope; - - ss[NE*cc + ty] = mqk; + // broadcast + mqk[cc] = simd_shuffle(mqk[cc], NL*ty); } } + + if (FC_flash_attn_ext_vec_has_mask && + !FC_flash_attn_ext_vec_has_scap && + !FC_flash_attn_ext_vec_has_bias) { + ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]); + } else { + mqk[tx] *= args.scale; + + if (FC_flash_attn_ext_vec_has_scap) { + mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]); + } + + if (FC_flash_attn_ext_vec_has_bias) { + mqk[tx] += (qk_t) sm[NE*tx + ty]*slope; + } else { + mqk[tx] += (qk_t) sm[NE*tx + ty]; + } + + ss[NE*tx + ty] = mqk[tx]; + } } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -4952,9 +5648,10 @@ kernel void kernel_flash_attn_ext_vec( ss[tiisg] = vs; // O = diag(ms)*O - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -4962,26 +5659,84 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { - //#pragma unroll(C/NE) - for (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + o4_t lo[DV4/NL]; + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] = 0.0f; + } - const s4_t ms(ss[NE*cc + ty]); + if (is_same::value) { + device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21); - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - const short i = ii + tx; + pv4 += ty*NS20/4 + tx; - v4_t mv; - deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + const auto sst = ss + ty; - lo[ii/NL] += o4_t(float4(mv)*float4(ms)); + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE])); + } + } + } else { + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21)); + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + const short i = ii*NL + tx; + + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty])); + } + } + } + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + if (NE > 1) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 16); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 16); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 16); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 16); + } + + if (NE > 2) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 8); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 8); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 8); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 8); + } + + if (NE > 4) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 4); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 4); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 4); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 4); + } + + if (NE > 8) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 2); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 2); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 2); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 2); + } + + if (NE > 16) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 1); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 1); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 1); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 1); + } + } + + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] += lo[ii]; } } } } - if (sinks != q && sgitg == 0) { + if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) { const float m = M; const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; @@ -4992,9 +5747,10 @@ kernel void kernel_flash_attn_ext_vec( S = S*ms + simd_sum(vs); -#pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -5005,63 +5761,12 @@ kernel void kernel_flash_attn_ext_vec( } } - // simdgroup reduce (NE = 4) - // [ 0, 8, 16, 24] -> [ 0] - // [ 1, 9, 17, 25] -> [ 1] - // [ 2, 10, 18, 26] -> [ 2] - // [ 3, 11, 19, 27] -> [ 3] - // [ 4, 12, 20, 28] -> [ 4] - // [ 5, 13, 21, 29] -> [ 5] - // [ 6, 14, 22, 30] -> [ 6] - // [ 7, 15, 23, 31] -> [ 7] - for (short ii = 0; ii < DV4; ii += NL) { - if (NE > 1) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); - } - - if (NE > 2) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); - } - - if (NE > 4) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); - } - - if (NE > 8) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); - } - - if (NE > 16) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // store results to shared memory - for (short i = tiisg; i < DV4; i += NL) { - sr4[i] = lo[i/NL]; - } + so4 -= tiisg; threadgroup_barrier(mem_flags::mem_threadgroup); // parallel reduce - for (short r = nsg/2; r > 0; r >>= 1) { + for (short r = NSG/2; r > 0; r >>= 1) { if (sgitg < r) { const float S0 = ss[ 0]; const float S1 = ss[r*(SH/2) + 0]; @@ -5083,23 +5788,87 @@ kernel void kernel_flash_attn_ext_vec( // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 for (short i = tiisg; i < DV4; i += NW) { - sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; + so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1; } } threadgroup_barrier(mem_flags::mem_threadgroup); } - device float4 * dst4 = (device float4 *) dst; - // final rescale with 1/S and store to global memory if (sgitg == 0) { - const float S = ss[0]; + const int64_t nrows = args.ne3*args.ne2*args.ne1; + const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1; + device float4 * dst4 = (device float4 *) dst; + device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results + + const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f; + + // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { - dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S; + dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S; + } + + // store S and M + if (NWG > 1) { + if (tiisg == 0) { + dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0]; + dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1]; + } } } + +#undef NWG +#undef NS10 +#undef NS20 +} + +template< + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory + short nl_k, + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // value type in device memory + short nl_v, + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext_vec & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device const char * pad, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_vec_nsg) { + // note: disabled cases to reduce library load time + case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 2: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 8: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + } +#undef FWD_TMPL +#undef FWD_ARGS } // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem @@ -5115,126 +5884,135 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES -template -kernel void kernel_set( - constant ggml_metal_kargs_set & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i13 = tgpig[2]; - const int i12 = tgpig[1]; - const int i11 = tgpig[0]; +constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]]; +constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]]; - const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10; +kernel void kernel_flash_attn_ext_vec_reduce( + constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args, + device const char * htmp, + device char * dst, + uint tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define NWG (FC_flash_attn_ext_vec_reduce_NWG) +#define DV (FC_flash_attn_ext_vec_reduce_DV) - const int64_t i3 = n / (args.ne12*args.ne11*args.ne10); - const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10); - const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10; + const uint64_t rid = tgpig; - device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs); + const short iwg = tiisg; - for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) { - device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10); - dst_data[i10] = (T) src[0]; + device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG; + + float S = ss[rid*(2*NWG) + 2*iwg + 0]; + float M = ss[rid*(2*NWG) + 2*iwg + 1]; + + const float m = simd_max(M); + const float ms = exp(M - m); + + S = simd_sum(S*ms); + S = S == 0.0f ? 0.0f : 1.0f/S; + + const short DV4 = DV/4; + + device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG; + device float4 * dst4 = (device float4 *) dst + rid*DV4; + + for (short i = sgitg; i < DV4; i += NWG) { + const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms); + + if (iwg == 0) { + dst4[i] = v*S; + } } + +#undef NWG +#undef DV } -typedef decltype(kernel_set) kernel_set_t; - -template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set; -template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set; - template -kernel void kernel_cpy( +kernel void kernel_cpy_t_t( constant ggml_metal_kargs_cpy & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 tptg[[threads_per_threadgroup]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x; - - if (i01 >= args.ne01) { - return; - } + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -5245,188 +6023,70 @@ kernel void kernel_cpy( device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; + break; } } -typedef decltype(kernel_cpy) kernel_cpy_t; +typedef decltype(kernel_cpy_t_t) kernel_cpy_t; -template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -// TODO: templetify these kernels -kernel void kernel_cpy_f32_q8_0( +template +kernel void kernel_cpy_f32_q( constant ggml_metal_kargs_cpy & args, device const char * src0, - device char * dst, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; - device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); - quantize_q8_0(src, dst_data[i00/QK8_0]); + quantize_func(src, dst_data[i00]); + + break; } } -kernel void kernel_cpy_f32_q4_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; +typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_0(src, dst_data[i00/QK4_0]); - } -} - -kernel void kernel_cpy_f32_q4_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_1(src, dst_data[i00/QK4_1]); - } -} - -kernel void kernel_cpy_f32_q5_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_0(src, dst_data[i00/QK5_0]); - } -} - -kernel void kernel_cpy_f32_q5_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_1(src, dst_data[i00/QK5_1]); - } -} - -kernel void kernel_cpy_f32_iq4_nl( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_iq4_nl(src, dst_data[i00/QK4_NL]); - } -} +template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q; template kernel void kernel_cpy_q_f32( @@ -5434,11 +6094,12 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -5450,10 +6111,12 @@ kernel void kernel_cpy_q_f32( device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; + + break; } } @@ -5502,7 +6165,7 @@ kernel void kernel_concat( } } -template +template void kernel_mul_mv_q2_K_f32_impl( args_t args, device const char * src0, @@ -5512,13 +6175,15 @@ void kernel_mul_mv_q2_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5602,10 +6267,10 @@ kernel void kernel_mul_mv_q2_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q3_K_f32_impl( args_t args, device const char * src0, @@ -5615,6 +6280,7 @@ void kernel_mul_mv_q3_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; @@ -5622,7 +6288,7 @@ void kernel_mul_mv_q3_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5766,10 +6432,10 @@ kernel void kernel_mul_mv_q3_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q4_K_f32_impl( args_t args, device const char * src0, @@ -5779,9 +6445,11 @@ void kernel_mul_mv_q4_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; + const short NSG = FC_mul_mv_nsg; + + constexpr uint16_t kmask1 = 0x3f3f; + constexpr uint16_t kmask2 = 0x0f0f; + constexpr uint16_t kmask3 = 0xc0c0; const short ix = tiisg/8; // 0...3 const short it = tiisg%8; // 0...7 @@ -5794,7 +6462,7 @@ void kernel_mul_mv_q4_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5840,7 +6508,7 @@ void kernel_mul_mv_q4_K_f32_impl( float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (short i = 0; i < 4; ++i) { + FOR_UNROLL (short i = 0; i < 4; ++i) { acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F); acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00); acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0); @@ -5851,14 +6519,11 @@ void kernel_mul_mv_q4_K_f32_impl( acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000); } - float dall = dh[0]; - float dmin = dh[1]; - - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); q1 += args.nb01/2; sc += args.nb01/2; @@ -5888,10 +6553,10 @@ kernel void kernel_mul_mv_q4_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q5_K_f32_impl( args_t args, device const char * src0, @@ -5901,6 +6566,7 @@ void kernel_mul_mv_q5_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; @@ -5908,7 +6574,7 @@ void kernel_mul_mv_q5_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5923,9 +6589,9 @@ void kernel_mul_mv_q5_K_f32_impl( float yl[16], yh[16]; - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; + constexpr uint16_t kmask1 = 0x3f3f; + constexpr uint16_t kmask2 = 0x0f0f; + constexpr uint16_t kmask3 = 0xc0c0; const short tid = tiisg/4; const short ix = tiisg%4; @@ -5971,7 +6637,7 @@ void kernel_mul_mv_q5_K_f32_impl( float4 acc1 = {0.f}; float4 acc2 = {0.f}; - for (short l = 0; l < 8; ++l) { + FOR_UNROLL (short l = 0; l < 8; ++l) { uint8_t h = qh[l]; acc1[0] += yl[l+0] * (q1[l] & 0x0F); acc1[1] += yl[l+8] * (q1[l] & 0xF0); @@ -5982,13 +6648,12 @@ void kernel_mul_mv_q5_K_f32_impl( acc2[2] += h & hm3 ? yh[l+0] : 0.f; acc2[3] += h & hm4 ? yh[l+8] : 0.f; } - const float dall = dh[0]; - const float dmin = dh[1]; - sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + - sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + - sc8[4] * (acc1[2] + 16.f*acc2[2]) + - sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - + dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); q1 += args.nb01; qh += args.nb01; @@ -6019,10 +6684,10 @@ kernel void kernel_mul_mv_q5_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q6_K_f32_impl( args_t args, device const char * src0, @@ -6032,11 +6697,12 @@ void kernel_mul_mv_q6_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; + constexpr uint8_t kmask1 = 0x03; + constexpr uint8_t kmask2 = 0x0C; + constexpr uint8_t kmask3 = 0x30; + constexpr uint8_t kmask4 = 0xC0; const int nb = args.ne00/QK_K; @@ -6044,7 +6710,7 @@ void kernel_mul_mv_q6_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6087,18 +6753,16 @@ void kernel_mul_mv_q6_K_f32_impl( } for (short row = 0; row < nr0; ++row) { - const float dall = dh[0]; - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (short l = 0; l < 4; ++l) { + FOR_UNROLL (short l = 0; l < 4; ++l) { sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); } - sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); q1 += args.nb01; q2 += args.nb01; @@ -6128,12 +6792,12 @@ kernel void kernel_mul_mv_q6_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit -template +template void kernel_mul_mv_iq2_xxs_f32_impl( args_t args, device const char * src0, @@ -6143,13 +6807,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6236,10 +6902,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_xs_f32_impl( args_t args, device const char * src0, @@ -6249,13 +6915,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6353,10 +7021,10 @@ kernel void kernel_mul_mv_iq2_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_xxs_f32_impl( args_t args, device const char * src0, @@ -6366,13 +7034,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6463,10 +7133,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_s_f32_impl( args_t args, device const char * src0, @@ -6476,13 +7146,15 @@ void kernel_mul_mv_iq3_s_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6573,10 +7245,10 @@ kernel void kernel_mul_mv_iq3_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_s_f32_impl( args_t args, device const char * src0, @@ -6586,13 +7258,15 @@ void kernel_mul_mv_iq2_s_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6684,10 +7358,10 @@ kernel void kernel_mul_mv_iq2_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_s_f32_impl( args_t args, device const char * src0, @@ -6697,13 +7371,15 @@ void kernel_mul_mv_iq1_s_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6781,10 +7457,10 @@ kernel void kernel_mul_mv_iq1_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_m_f32_impl( args_t args, device const char * src0, @@ -6794,6 +7470,7 @@ void kernel_mul_mv_iq1_m_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; @@ -6801,7 +7478,7 @@ void kernel_mul_mv_iq1_m_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6889,10 +7566,10 @@ kernel void kernel_mul_mv_iq1_m_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -6902,15 +7579,15 @@ void kernel_mul_mv_iq4_nl_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6921,6 +7598,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK4_NL; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -6928,24 +7608,25 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK4_NL + it * 8; + device const float * yb = y + ix*QK4_NL + it*8; uint32_t aux32[2]; thread const uint8_t * q8 = (thread const uint8_t *)aux32; float4 qf1, qf2; - for (int ib = ix; ib < nb; ib += 16) { + // [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; row++) { - device const block_iq4_nl & xb = x[row*nb + ib]; + for (short row = 0; row < NR0; row++) { + device const block_iq4_nl & xb = x[row*ns01 + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -6976,7 +7657,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -6995,10 +7676,10 @@ kernel void kernel_mul_mv_iq4_nl_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -7008,13 +7689,14 @@ void kernel_mul_mv_iq4_xs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7025,6 +7707,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_K; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/16; // 0 or 1 const short it = tiisg%16; // 0...15 const short ib = it/2; @@ -7034,7 +7719,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -7043,15 +7728,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; - for (int ibl = ix; ibl < nb; ibl += 2) { + // [TAG_MUL_MV_WEIRD] + for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; + for (short row = 0; row < NR0; ++row) { + device const block_iq4_xs & xb = x[row*ns01 + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7081,7 +7767,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7100,10 +7786,10 @@ kernel void kernel_mul_mv_iq4_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_mxfp4_f32_impl( args_t args, device const char * src0, @@ -7113,15 +7799,15 @@ void kernel_mul_mv_mxfp4_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_MXFP4; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7132,6 +7818,9 @@ void kernel_mul_mv_mxfp4_f32_impl( device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_MXFP4; + const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7139,20 +7828,22 @@ void kernel_mul_mv_mxfp4_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK_MXFP4 + it * 8; + device const float * yb = y + ix*QK_MXFP4 + it*8; + + // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster + // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { + device const float4 * y4 = (device const float4 *) yb; - for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; -#pragma unroll(nr0) - for (short row = 0; row < nr0; row++) { - device const block_mxfp4 & xb = x[row*nb + ib]; + FOR_UNROLL (short row = 0; row < NR0; row++) { + device const block_mxfp4 & xb = x[row*ns01 + ib]; device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it); float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]); @@ -7170,7 +7861,7 @@ void kernel_mul_mv_mxfp4_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7189,76 +7880,70 @@ kernel void kernel_mul_mv_mxfp4_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { + auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; + dequantize_func(psrc + ind/nl, ind%nl, temp); + pdst[ind] = temp; + + break; } } -template +template kernel void kernel_get_rows_f( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; + auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { + pdst[ind] = psrc[ind]; + + break; } } -kernel void kernel_get_rows_i32( - constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device int32_t * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; - } -} - -template +template kernel void kernel_set_rows_q32( constant ggml_metal_kargs_set_rows & args, device const void * src0, @@ -7279,7 +7964,7 @@ kernel void kernel_set_rows_q32( } const int32_t i10 = i01; - const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; + const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); @@ -7289,7 +7974,7 @@ kernel void kernel_set_rows_q32( } } -template +template kernel void kernel_set_rows_f( constant ggml_metal_kargs_set_rows & args, device const void * src0, @@ -7310,9 +7995,9 @@ kernel void kernel_set_rows_f( } const int32_t i10 = i01; - const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; + const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; - device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { @@ -7320,6 +8005,9 @@ kernel void kernel_set_rows_f( } } +constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; +constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; + #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_K 32 @@ -7332,7 +8020,7 @@ kernel void kernel_set_rows_f( #define SG_MAT_ROW 8 // each block_q contains 16*nl weights -template +template kernel void kernel_mul_mm( constant ggml_metal_kargs_mul_mm & args, device const char * src0, @@ -7343,8 +8031,8 @@ kernel void kernel_mul_mm( ushort tiitg[[thread_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup T * sa = (threadgroup T *)(shmem); - threadgroup float * sb = (threadgroup float *)(shmem + 4096); + threadgroup S0 * sa = (threadgroup S0 *)(shmem); + threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; @@ -7358,8 +8046,9 @@ kernel void kernel_mul_mm( const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_T8x8 ma[4]; - simdgroup_float8x8 mb[2]; + S0_8x8 ma[4]; + S1_8x8 mb[2]; + simdgroup_float8x8 mc[8]; for (short i = 0; i < 8; i++){ @@ -7377,27 +8066,45 @@ kernel void kernel_mul_mm( device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - device const float * y = (device const float *)(src1 + const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)); + + device const T1 * y = (device const T1 *)(src1 + args.nb13*i13 + args.nb12*i12 + args.nb11*(r1*BLOCK_SIZE_N + thread_col) - + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + args.nb10*iy); for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - T4x4 temp_a; - dequantize_func(x, il, temp_a); + if (is_same::value && FC_mul_mm_bc_inp) { + threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup_barrier(mem_flags::mem_threadgroup); + // no need for dequantization + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0; + } + } else { + S0_4x4 temp_a; + dequantize_func(x, il, temp_a); - #pragma unroll(16) - for (short i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ - + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ - + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } } - *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y); + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0; + } + } else { + *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y)); + } il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -7406,23 +8113,25 @@ kernel void kernel_mul_mm( threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); - threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); #pragma unroll(4) for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(4) for (short i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) for (short i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(8) for (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); @@ -7433,7 +8142,8 @@ kernel void kernel_mul_mm( } } - if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) { + if (!FC_mul_mm_bc_out || ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1)) { + // if no bounds checks on the output are needed, we can directly write to device memory device float * C = (device float *) dst + (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; @@ -7474,124 +8184,111 @@ kernel void kernel_mul_mm( } } -template +template // n_expert_used kernel void kernel_mul_mm_id_map0( constant ggml_metal_kargs_mul_mm_id_map0 & args, - device const char * src1, device const char * src2, - device char * hsrc1, device char * htpe, device char * hids, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int ide = tgpig[0]; // expert id + threadgroup char * shmem [[threadgroup(0)]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort ntg[[threads_per_threadgroup]]) { + const short ide = tpitg; // expert id - int n_all = 0; + uint32_t n_all = 0; - device int32_t * ids_i32 = (device int32_t *) (hids); + device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21; - for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens - device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21); + for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens + if (i21 + tpitg < args.ne21) { + device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21); - for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used - if (src2_i32[i20] != ide) { - continue; + threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20; + + #pragma unroll(ne20) + for (short i20 = 0; i20 < ne20; i20++) { + sids[i20] = src2_i32[i20]; } - - device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11); - device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11); - - for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) { - hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]); - } - - if (tpitg.x == 0) { - ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all; - } - - ++n_all; } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short t = 0; t < ntg; t++) { + if (i21 + t >= args.ne21) { + break; + } + + threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20; + + short sel = 0; + #pragma unroll(ne20) + for (short i20 = 0; i20 < ne20; i20++) { + sel += (sids[i20] == ide)*(i20 + 1); + } + + ids_i32[n_all] = (i21 + t)*ne20 + sel - 1; + + n_all += sel > 0; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); } - if (tpitg.x == 0) { - device int32_t * tpe_i32 = (device int32_t *) (htpe); - tpe_i32[ide] = n_all; - } + device uint32_t * tpe_u32 = (device uint32_t *) (htpe); + tpe_u32[ide] = n_all; } -typedef decltype(kernel_mul_mm_id_map0) kernel_mul_mm_id_map0_t; +typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t; -template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0; +template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>; +template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>; +template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; +template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; +template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; +template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; +template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; -template -kernel void kernel_mul_mm_id_map1( - constant ggml_metal_kargs_mul_mm_id_map1 & args, - device const char * hdst, - device const char * hids, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i20 = tgpig[0]; // used expert - const int i21 = tgpig[1]; // token - - device const int32_t * ids_i32 = (device const int32_t *) (hids); - device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2); - - const int id = ids_i32[i21*args.ne20 + i20]; - - const int ide = id / args.neh1; - const int idt = id % args.neh1; - - device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2); - - for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) { - dst_f32x4[i0] = hdst_f32x4[i0]; - } -} - -typedef decltype(kernel_mul_mm_id_map1) kernel_mul_mm_id_map1_t; - -template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1; - -template +template kernel void kernel_mul_mm_id( constant ggml_metal_kargs_mul_mm_id & args, device const char * src0, device const char * src1, - device const char * tpe, + device const char * htpe, + device const char * hids, device char * dst, threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup T * sa = (threadgroup T *)(shmem); - threadgroup half * sb = (threadgroup half *)(shmem + 4096); + threadgroup S0 * sa = (threadgroup S0 *)(shmem); + threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; - const int im = tgpig.z; + const int im = tgpig.z; // expert - device const int32_t * tpe_i32 = (device const int32_t *) (tpe); + device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe); + device const int32_t * ids_i32 = (device const int32_t *) (hids); - const int neh1 = tpe_i32[im]; + const int32_t neh1 = tpe_u32[im]; if (r1*BLOCK_SIZE_N >= neh1) { return; } // if this block is of 64x32 shape or smaller - const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; - const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_T8x8 ma[4]; - simdgroup_half8x8 mb[2]; + S0_8x8 ma[4]; + S1_8x8 mb[2]; + simdgroup_float8x8 mc[8]; for (short i = 0; i < 8; i++){ @@ -7600,36 +8297,57 @@ kernel void kernel_mul_mm_id( short il = (tiitg % THREAD_PER_ROW); - const int i12 = im%args.neh12; - const int i13 = im/args.neh12; + const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col]; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short i11 = (id % args.ne20) % args.ne11; + const short i12 = (id / args.ne20); + const short i13 = 0; + + const uint64_t offset0 = im*args.nb02 + i13*args.nb03; const short offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - device const half * y = (device const half *)(src1 - + args.nbh13*i13 - + args.nbh12*i12 - + args.nbh11*(r1*BLOCK_SIZE_N + thread_col) - + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)); + + device const T1 * y = (device const T1 *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*i11 + + args.nb10*iy); for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - T4x4 temp_a; - dequantize_func(x, il, temp_a); + if (is_same::value && FC_mul_mm_bc_inp) { + threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup_barrier(mem_flags::mem_threadgroup); + // no need for dequantization + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0; + } + } else { + S0_4x4 temp_a; + dequantize_func(x, il, temp_a); - #pragma unroll(16) - for (short i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ - + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ - + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } } - *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y); + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0; + } + } else { + *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y)); + } il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -7638,8 +8356,8 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); - threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); #pragma unroll(4) for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { @@ -7665,43 +8383,38 @@ kernel void kernel_mul_mm_id( } } - if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) { - device float * C = (device float *) dst + - (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ - (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shmem) \ - + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; - for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; + + #pragma unroll(8) + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short j = sgitg; j < n_cols; j += 4) { + const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j]; + + const short ide = id % args.ne20; + const short idt = id / args.ne20; + + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = tiisg; + for (; i < n_rows/4; i += 32) { + *(D4 + i) = *(C4 + i); } - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (sgitg == 0) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0; - device float4 * D4 = (device float4 *) D; - - threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); - threadgroup float4 * C4 = (threadgroup float4 *) C; - - int i = 0; - for (; i < n_rows/4; i++) { - *(D4 + i) = *(C4 + i); - } - - i *= 4; - for (; i < n_rows; i++) { - *(D + i) = *(C + i); - } - } + i = (4*(n_rows/4)) + tiisg; + for (; i < n_rows; i += 32) { + *(D + i) = *(C + i); } } } @@ -7712,12 +8425,13 @@ kernel void kernel_mul_mm_id( // get rows // -typedef decltype(kernel_get_rows_f) get_rows_f_t; +typedef decltype(kernel_get_rows_f) get_rows_f_t; -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; #endif typedef decltype(kernel_get_rows_q) get_rows_q_t; @@ -7747,93 +8461,153 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get // set rows // -typedef decltype(kernel_set_rows_f) set_rows_f_t; +typedef decltype(kernel_set_rows_f) set_rows_f_t; -template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f; -template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f; #endif -typedef decltype(kernel_set_rows_q32) set_rows_q32_t; +typedef decltype(kernel_set_rows_q32) set_rows_q32_t; -template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32; -template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32; -template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32; -template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32; -template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32; -template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32; // // matrix-matrix multiplication // -typedef decltype(kernel_mul_mm) mul_mm_t; +typedef decltype(kernel_mul_mm) mul_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; + +template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm; +#endif +template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication // -typedef decltype(kernel_mul_mm_id) mul_mm_id; +typedef decltype(kernel_mul_mm_id) mul_mm_id; -template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id; #endif -template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +#endif +template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; // // matrix-vector multiplication // -typedef void (kernel_mul_mv_impl_t)( +typedef void (kernel_mul_mv_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -7841,7 +8615,7 @@ typedef void (kernel_mul_mv_impl_t)( uint3 tgpig, ushort tiisg); -typedef void (kernel_mul_mv2_impl_t)( +typedef void (kernel_mul_mv2_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -7851,7 +8625,7 @@ typedef void (kernel_mul_mv2_impl_t)( ushort tiisg, ushort sgitg); -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -7862,10 +8636,10 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, tgpig, tiisg); + disp_fn(args, src0, src1, dst, tgpig, tiisg); } -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -7876,12 +8650,12 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_disp_fn_t; -template +template kernel void kernel_mul_mv_id( constant ggml_metal_kargs_mul_mv_id & args, device const char * src0s, @@ -7928,11 +8702,12 @@ kernel void kernel_mul_mv_id( /*.nb13 =*/ args.nb12, // ne12 == 1 /*.ne0 =*/ args.ne0, /*.ne1 =*/ 1, // args.ne1, + /*.nr0 =*/ args.nr0, /*.r2 =*/ 1, /*.r3 =*/ 1, }; - impl_fn( + disp_fn( args0, /* src0 */ src0_cur, /* src1 */ src1_cur, @@ -7944,44 +8719,52 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; + +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#endif +template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; #endif -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; kernel void kernel_pool_2d_max_f32( + constant ggml_metal_kargs_pool_2d & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_pool_2d & args, uint gid[[thread_position_in_grid]]) { - if (gid >= args.parallel_elements) { + if (gid >= args.np) { return; } @@ -8014,12 +8797,12 @@ kernel void kernel_pool_2d_max_f32( } kernel void kernel_pool_2d_avg_f32( + constant ggml_metal_kargs_pool_2d & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_pool_2d & args, uint gid[[thread_position_in_grid]]) { - if (gid >= args.parallel_elements) { + if (gid >= args.np) { return; } @@ -8053,3 +8836,51 @@ kernel void kernel_pool_2d_avg_f32( o_ptr[cur_oh * args.OW + cur_ow] = res; } + +kernel void kernel_opt_step_adamw_f32( + constant ggml_metal_kargs_opt_step_adamw & args, + device float * x, + device const float * g, + device float * g_m, + device float * g_v, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + const float alpha = pars[0]; + const float beta1 = pars[1]; + const float beta2 = pars[2]; + const float eps = pars[3]; + const float wd = pars[4]; + const float beta1h = pars[5]; + const float beta2h = pars[6]; + + const float gi = g[gid]; + const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1); + const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2); + + g_m[gid] = gmi; + g_v[gid] = gvi; + + const float mh = gmi * beta1h; + const float vh = sqrt(gvi * beta2h) + eps; + + x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh; +} + +kernel void kernel_opt_step_sgd_f32( + constant ggml_metal_kargs_opt_step_sgd & args, + device float * x, + device const float * g, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/metal.go b/ml/backend/ggml/ggml/src/ggml-metal/metal.go index bf20ab7f..48fa75b9 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/metal.go +++ b/ml/backend/ggml/ggml/src/ggml-metal/metal.go @@ -4,6 +4,7 @@ package metal //go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal" -// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -DGGML_METAL_USE_BF16 -I.. -I../../include +// #cgo CXXFLAGS: -std=c++17 +// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -DGGML_METAL_HAS_BF16 -I.. -I../../include // #cgo LDFLAGS: -framework Metal -framework MetalKit import "C" diff --git a/ml/backend/ggml/ggml/src/ggml-opt.cpp b/ml/backend/ggml/ggml/src/ggml-opt.cpp index a3c82d67..e078ad14 100644 --- a/ml/backend/ggml/ggml/src/ggml-opt.cpp +++ b/ml/backend/ggml/ggml/src/ggml-opt.cpp @@ -64,9 +64,11 @@ struct ggml_opt_context { int32_t opt_i = 0; bool loss_per_datapoint = false; - ggml_opt_get_optimizer_params get_opt_pars = nullptr; - void * get_opt_pars_ud = nullptr; - struct ggml_tensor * adamw_params = nullptr; + ggml_opt_get_optimizer_params get_opt_pars = nullptr; + void * get_opt_pars_ud = nullptr; + struct ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars. + + enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW; }; struct ggml_opt_result { @@ -229,9 +231,13 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us result.adamw.eps = 1e-8f; result.adamw.wd = 0.0f; + result.sgd.alpha = 1e-3f; + result.sgd.wd = 0.0f; + return result; } + struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) { return *((struct ggml_opt_optimizer_params *) userdata); } @@ -249,6 +255,7 @@ struct ggml_opt_params ggml_opt_default_params( /*opt_period =*/ 1, /*get_opt_pars =*/ ggml_opt_get_default_optimizer_params, /*get_opt_pars_ud =*/ nullptr, + /*optimizer =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW, }; } @@ -316,9 +323,14 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) { GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc"); GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically"); + const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer; + const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD && !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1); + const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && + opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW; + ggml_set_input(opt_ctx->inputs); ggml_set_output(opt_ctx->outputs); @@ -340,8 +352,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) { // - pred (if using static graphs) // - ncorrect (if using static graphs, 2 tensors). constexpr size_t n_loss = 1; - const size_t tensors_per_param = (accumulate ? 1 : 0) + - (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); + const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0); const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0; const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead(); struct ggml_init_params params = { @@ -458,7 +469,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) { } } - if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) { + if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) { opt_ctx->grad_m.resize(n_nodes); opt_ctx->grad_v.resize(n_nodes); for (int i = 0; i < n_nodes; ++i) { @@ -492,23 +503,36 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) { // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true); - opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7); - ggml_set_input(opt_ctx->adamw_params); - ggml_set_name(opt_ctx->adamw_params, "adamw_params"); - + opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2); + ggml_tensor * adamw_params = opt_ctx->opt_step_params; + ggml_set_input(adamw_params); + const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer); + ggml_format_name(adamw_params, "%s_params", optimizer_name); for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) { struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i]; struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node); if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) { - struct ggml_tensor * m = opt_ctx->grad_m[i]; - struct ggml_tensor * v = opt_ctx->grad_v[i]; - struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params); - - ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str()); - ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str()); - ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str()); - + struct ggml_tensor * m = nullptr; + struct ggml_tensor * v = nullptr; + if (need_momenta) { + m = opt_ctx->grad_m[i]; + v = opt_ctx->grad_v[i]; + ggml_format_name(m, "AdamW m for %s", node->name); + ggml_format_name(v, "AdamW v for %s", node->name); + } + struct ggml_tensor * opt_step; + switch (optimizer) { + case GGML_OPT_OPTIMIZER_TYPE_ADAMW: + opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params); + break; + case GGML_OPT_OPTIMIZER_TYPE_SGD: + opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params); + break; + default: + GGML_ABORT("fatal error"); + } + ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name); ggml_build_forward_expand(opt_ctx->gb_opt, opt_step); } } @@ -534,6 +558,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { result->opt_period = params.opt_period; result->get_opt_pars = params.get_opt_pars; result->get_opt_pars_ud = params.get_opt_pars_ud; + result->optimizer = params.optimizer; GGML_ASSERT(result->opt_period >= 1); @@ -756,29 +781,43 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) { void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) { GGML_ASSERT(opt_ctx->eval_ready); if (opt_ctx->allocated_graph == opt_ctx->gb_opt) { - struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); + const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); - GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); - GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f); - GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f); - GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f); - GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f); - GGML_ASSERT(opt_pars.adamw.eps >= 0.0f); - GGML_ASSERT(opt_pars.adamw.wd >= 0.0f); - GGML_ASSERT(opt_pars.adamw.wd <= 1.0f); + switch (opt_ctx->optimizer) { + case GGML_OPT_OPTIMIZER_TYPE_ADAMW: { + GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); + GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f); + GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f); + GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f); + GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f); + GGML_ASSERT(opt_pars.adamw.eps >= 0.0f); + GGML_ASSERT(opt_pars.adamw.wd >= 0.0f); + GGML_ASSERT(opt_pars.adamw.wd <= 1.0f); - // beta1, beta2 after applying warmup - const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter)); - const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter)); + // beta1, beta2 after applying warmup + const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter)); + const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter)); - float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params); - adamw_par_data[0] = opt_pars.adamw.alpha; - adamw_par_data[1] = opt_pars.adamw.beta1; - adamw_par_data[2] = opt_pars.adamw.beta2; - adamw_par_data[3] = opt_pars.adamw.eps; - adamw_par_data[4] = opt_pars.adamw.wd; - adamw_par_data[5] = beta1h; - adamw_par_data[6] = beta2h; + float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params); + adamw_par_data[0] = opt_pars.adamw.alpha; + adamw_par_data[1] = opt_pars.adamw.beta1; + adamw_par_data[2] = opt_pars.adamw.beta2; + adamw_par_data[3] = opt_pars.adamw.eps; + adamw_par_data[4] = opt_pars.adamw.wd; + adamw_par_data[5] = beta1h; + adamw_par_data[6] = beta2h; + } break; + case GGML_OPT_OPTIMIZER_TYPE_SGD: { + GGML_ASSERT(opt_pars.sgd.alpha > 0.0f); + GGML_ASSERT(opt_pars.sgd.wd >= 0.0f); + GGML_ASSERT(opt_pars.sgd.wd <= 1.0f); + float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params); + sgd[0] = opt_pars.sgd.alpha; + sgd[1] = opt_pars.sgd.wd; + } break; + default: + GGML_ABORT("fatal error"); + } } ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); @@ -963,6 +1002,7 @@ void ggml_opt_fit( ggml_tensor * outputs, ggml_opt_dataset_t dataset, enum ggml_opt_loss_type loss_type, + enum ggml_opt_optimizer_type optimizer, ggml_opt_get_optimizer_params get_opt_pars, int64_t nepoch, int64_t nbatch_logical, @@ -993,6 +1033,7 @@ void ggml_opt_fit( params.opt_period = opt_period; params.get_opt_pars = get_opt_pars; params.get_opt_pars_ud = &epoch; + params.optimizer = optimizer; ggml_opt_context_t opt_ctx = ggml_opt_init(params); // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch. @@ -1035,3 +1076,18 @@ void ggml_opt_fit( ggml_opt_result_free(result_train); ggml_opt_result_free(result_val); } + +enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) { + return c->optimizer; +} + +GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) { + switch (o) { + case GGML_OPT_OPTIMIZER_TYPE_ADAMW: + return "adamw"; + case GGML_OPT_OPTIMIZER_TYPE_SGD: + return "sgd"; + default: + return "undefined"; + }; +} diff --git a/ml/backend/ggml/ggml/src/ggml-quants.c b/ml/backend/ggml/ggml/src/ggml-quants.c index 94f6405c..de5cbd75 100644 --- a/ml/backend/ggml/ggml/src/ggml-quants.c +++ b/ml/backend/ggml/ggml/src/ggml-quants.c @@ -566,7 +566,7 @@ static float make_q3_quants(int n, int nmax, const float * GGML_RESTRICT x, int8 for (int i = 0; i < n; ++i) { L[i] += nmax; } - return sumlx / suml2; + return suml2 > 0.0f ? sumlx / suml2 : 0.0f; } for (int i = 0; i < n; ++i) { int l = nearest_int(iscale * x[i]); @@ -901,7 +901,7 @@ static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint for (int i = 0; i < n; ++i) { max = MAX(max, x[i]); } - if (!max) { // all zero + if (max < GROUP_MAX_EPS) { // all zero for (int i = 0; i < n; ++i) { L[i] = 0; } return 0.f; } @@ -966,7 +966,7 @@ static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint break; } } - return sumlx/suml2; + return suml2 > 0.0f ? sumlx / suml2 : 0.0f; } static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k, const float * GGML_RESTRICT quant_weights) { @@ -3721,6 +3721,7 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT } float best = 0; float scale = max/(2*kMaxQ-1); + for (int k = 0; k < 8; ++k) is_on_grid[k] = true; for (int is = -15; is <= 15; ++is) { float id = (2*kMaxQ-1+is*0.2f)/max; float this_scale = 1/id; @@ -4266,7 +4267,7 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R sumw[j+1] = sumw[j] + weight[i]; } } - float best_score = -FLT_MIN, scale = max; + float best_score = -FLT_MAX, scale = max; int besti1 = -1, besti2 = -1, best_shift = 0; for (int i1 = 0; i1 <= block_size; ++i1) { for (int i2 = i1; i2 <= block_size; ++i2) { @@ -4442,7 +4443,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R idx[2*j] = j; } qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); - float best_score = -FLT_MIN, scale = max; + float best_score = -FLT_MAX, scale = max; int besti1 = -1, besti2 = -1, best_k = -1; // 0: +, + // 1: +, - diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt new file mode 100644 index 00000000..83a83887 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt @@ -0,0 +1,211 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) +cmake_policy(SET CMP0116 NEW) + +find_package(Vulkan COMPONENTS glslc REQUIRED) + +function(detect_host_compiler) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + else() + find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + endif() + set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE) + set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) +endfunction() + +# Function to test shader extension support +# Parameters: +# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product") +# TEST_SHADER_FILE - Path to the test shader file +# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result +function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE) + execute_process( + COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error + ) + + if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*") + message(STATUS "${EXTENSION_NAME} not supported by glslc") + set(${RESULT_VARIABLE} OFF PARENT_SCOPE) + else() + message(STATUS "${EXTENSION_NAME} supported by glslc") + set(${RESULT_VARIABLE} ON PARENT_SCOPE) + add_compile_definitions(${RESULT_VARIABLE}) + + # Ensure the extension support is forwarded to vulkan-shaders-gen + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON) + set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE) + endif() +endfunction() + +if (Vulkan_FOUND) + message(STATUS "Vulkan found") + + ggml_add_backend_library(ggml-vulkan + ggml-vulkan.cpp + ../../include/ggml-vulkan.h + ) + + set(VULKAN_SHADER_GEN_CMAKE_ARGS "") + + # Test all shader extensions + test_shader_extension_support( + "GL_KHR_cooperative_matrix" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp" + "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" + ) + + test_shader_extension_support( + "GL_NV_cooperative_matrix2" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp" + "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" + ) + + test_shader_extension_support( + "GL_EXT_integer_dot_product" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp" + "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" + ) + + test_shader_extension_support( + "GL_EXT_bfloat16" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp" + "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" + ) + + target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) + target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + + # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build + # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector + if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) + endif() + + if (GGML_VULKAN_CHECK_RESULTS) + add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) + endif() + + if (GGML_VULKAN_DEBUG) + add_compile_definitions(GGML_VULKAN_DEBUG) + endif() + + if (GGML_VULKAN_MEMORY_DEBUG) + add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG) + endif() + + if (GGML_VULKAN_SHADER_DEBUG_INFO) + add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_SHADER_DEBUG_INFO=ON) + endif() + + if (GGML_VULKAN_VALIDATE) + add_compile_definitions(GGML_VULKAN_VALIDATE) + endif() + + if (GGML_VULKAN_RUN_TESTS) + add_compile_definitions(GGML_VULKAN_RUN_TESTS) + endif() + + # Set up toolchain for host compilation whether cross-compiling or not + if (CMAKE_CROSSCOMPILING) + if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) + set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) + else() + detect_host_compiler() + if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER) + message(FATAL_ERROR "Host compiler not found") + else() + message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}") + endif() + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) + set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) + endif() + else() + # For non-cross-compiling, use empty toolchain (use host compiler) + set(HOST_CMAKE_TOOLCHAIN_FILE "") + endif() + + include(ExternalProject) + + if (CMAKE_CROSSCOMPILING) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}) + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") + endif() + + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$ + -DCMAKE_INSTALL_BINDIR=. + -DCMAKE_BUILD_TYPE=$ + ${VULKAN_SHADER_GEN_CMAKE_ARGS} + + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $ + BUILD_ALWAYS TRUE + + # NOTE: When DESTDIR is set using Makefile generators and + # "make install" triggers the build step, vulkan-shaders-gen + # would be installed into the DESTDIR prefix, so it is unset + # to ensure that does not happen. + + INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR + ${CMAKE_COMMAND} --install . --config $ + ) + + set (_ggml_vk_host_suffix $,.exe,>) + set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$") + set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}") + set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp") + set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders") + set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv") + + file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp") + + # Because external projects do not provide source-level tracking, + # the vulkan-shaders-gen sources need to be explicitly added to + # ensure that changes will cascade into shader re-generation. + + file(GLOB _ggml_vk_shaders_gen_sources + CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.cpp" + "${_ggml_vk_input_dir}/*.h") + + add_custom_command( + OUTPUT ${_ggml_vk_header} + COMMAND ${_ggml_vk_genshaders_cmd} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + DEPENDS ${_ggml_vk_shaders_gen_sources} + vulkan-shaders-gen + COMMENT "Generate vulkan shaders header" + ) + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header}) + + foreach (file_full ${_ggml_vk_shader_files}) + get_filename_component(file ${file_full} NAME) + set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp") + + add_custom_command( + OUTPUT ${_ggml_vk_target_cpp} + DEPFILE ${_ggml_vk_target_cpp}.d + COMMAND ${_ggml_vk_genshaders_cmd} + --glslc ${Vulkan_GLSLC_EXECUTABLE} + --source ${file_full} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + --target-cpp ${_ggml_vk_target_cpp} + DEPENDS ${file_full} + ${_ggml_vk_shaders_gen_sources} + vulkan-shaders-gen + COMMENT "Generate vulkan shaders for ${file}" + ) + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp}) + endforeach() + +else() + message(WARNING "Vulkan not found") +endif() diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp new file mode 100644 index 00000000..564bc4a7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -0,0 +1,13904 @@ +#include "ggml-vulkan.h" +#include +#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS) +#include +#include "ggml-cpu.h" +#endif + +// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- +#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 +// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE +// to avoid conflicts with applications or other libraries who might use it. +#if VK_HEADER_VERSION >= 301 +namespace vk::detail { class DispatchLoaderDynamic; } +using vk::detail::DispatchLoaderDynamic; +#else +namespace vk { class DispatchLoaderDynamic; } +using vk::DispatchLoaderDynamic; +#endif +DispatchLoaderDynamic & ggml_vk_default_dispatcher(); +#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +# define NOMINMAX 1 +# include +# define YIELD() YieldProcessor() +#elif defined(__clang__) || defined(__GNUC__) +# if defined(__x86_64__) ||defined(__i386__) +# include +# define YIELD() _mm_pause() +# elif defined(__arm__) || defined(__aarch64__) +# if defined(__clang__) +# include +# define YIELD() __yield() +# else +# define YIELD() asm volatile("yield") +# endif +# endif +#endif + +#if !defined(YIELD) +#define YIELD() +#endif + +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-vulkan-shaders.hpp" + +// remove this once it's more widely available in the SDK +#if !defined(VK_KHR_shader_bfloat16) + +#define VK_KHR_shader_bfloat16 1 +#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1 +#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000) +#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000) + +typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { + VkStructureType sType; + void* pNext; + VkBool32 shaderBFloat16Type; + VkBool32 shaderBFloat16DotProduct; + VkBool32 shaderBFloat16CooperativeMatrix; +} VkPhysicalDeviceShaderBfloat16FeaturesKHR; +#endif + +#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } + +#define VK_VENDOR_ID_AMD 0x1002 +#define VK_VENDOR_ID_APPLE 0x106b +#define VK_VENDOR_ID_INTEL 0x8086 +#define VK_VENDOR_ID_NVIDIA 0x10de + +#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 + +#define GGML_VK_MAX_NODES 8192 + +#define MAX_VK_BUFFERS 256 + +#define VK_CHECK(err, msg) \ + do { \ + vk::Result err_ = (err); \ + if (err_ != vk::Result::eSuccess) { \ + fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \ + #err, to_string(err_).c_str(), __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + +#ifdef GGML_VULKAN_DEBUG +#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl +#else +#define VK_LOG_DEBUG(msg) ((void) 0) +#endif // GGML_VULKAN_DEBUG + +struct ggml_backend_vk_context; + +#define MAX_PARAMETER_COUNT 12 +// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT. +#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3) + +struct vk_pipeline_struct { + std::string name; + vk::ShaderModule shader_module; + vk::PipelineLayout layout; + vk::Pipeline pipeline; + uint32_t push_constant_size; + uint32_t parameter_count; + std::array wg_denoms; + uint32_t align; + // true if fields have been set by ggml_vk_create_pipeline + bool initialized {}; + // set to true to request the pipeline is compiled after the dryrun + bool needed {}; + // set to true when the shader has been compiled + bool compiled {}; + // number of registers used, extracted from pipeline executable properties + uint32_t register_count {}; +}; + +typedef std::shared_ptr vk_pipeline; +typedef std::weak_ptr vk_pipeline_ref; + +static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); + +struct vk_matmul_pipeline_struct { + vk_pipeline l, m, s; + vk_pipeline a_l, a_m, a_s; +}; + +typedef std::shared_ptr vk_matmul_pipeline; + +struct vk_matmul_pipeline2 { + vk_matmul_pipeline2() { + f16acc = std::make_shared(); + f32acc = std::make_shared(); + } + vk_matmul_pipeline f32acc; + vk_matmul_pipeline f16acc; +}; + +struct vk_device_struct; +typedef std::shared_ptr vk_device; +typedef std::weak_ptr vk_device_ref; + +struct vk_buffer_struct; +typedef std::shared_ptr vk_buffer; +typedef std::weak_ptr vk_buffer_ref; + +struct ggml_backend_vk_buffer_type_context { + std::string name; + vk_device device; +}; + +struct vk_queue; + +// Stores command pool/buffers. There's an instance of this +// for each (context,queue) pair and for each (device,queue) pair. +struct vk_command_pool { + void init(vk_device& device, vk_queue *q_); + void destroy(vk::Device& device); + + vk::CommandPool pool; + uint32_t cmd_buffer_idx; + std::vector cmd_buffers; + + vk_queue *q; +}; + +// Prevent simultaneous submissions to the same queue. +// This could be per vk_queue if we stopped having two vk_queue structures +// sharing the same vk::Queue. +static std::mutex queue_mutex; + +struct vk_queue { + uint32_t queue_family_index; + vk::Queue queue; + + vk_command_pool cmd_pool; + + vk::PipelineStageFlags stage_flags; + + bool transfer_only; + + // copy everything except the cmd_pool + void copyFrom(vk_queue &other) { + queue_family_index = other.queue_family_index; + queue = other.queue; + stage_flags = other.stage_flags; + transfer_only = other.transfer_only; + } +}; + +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); +static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { + /* .get_name = */ ggml_backend_vk_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +#ifdef GGML_VULKAN_MEMORY_DEBUG +class vk_memory_logger; +#endif +class vk_perf_logger; +static void ggml_vk_destroy_buffer(vk_buffer& buf); + +static constexpr uint32_t mul_mat_vec_max_cols = 8; +static constexpr uint32_t p021_max_gqa_ratio = 8; + +enum vk_device_architecture { + OTHER, + AMD_GCN, + AMD_RDNA1, + AMD_RDNA2, + AMD_RDNA3, + INTEL_XE2, + NVIDIA_PRE_TURING, +}; + +static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { + vk::PhysicalDeviceProperties props = device.getProperties(); + + if (props.vendorID == VK_VENDOR_ID_AMD) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool amd_shader_core_properties = false; + bool integer_dot_product = false; + bool subgroup_size_control = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) { + amd_shader_core_properties = true; + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) { + integer_dot_product = true; + } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + subgroup_size_control = true; + } + } + + if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) { + return vk_device_architecture::OTHER; + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &shader_core_props_amd; + shader_core_props_amd.pNext = &integer_dot_props; + integer_dot_props.pNext = &subgroup_size_control_props; + + device.getProperties2(&props2); + + if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) { + return vk_device_architecture::AMD_GCN; + } + if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) { + // RDNA + if (shader_core_props_amd.wavefrontsPerSimd == 20) { + return vk_device_architecture::AMD_RDNA1; + } + if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) { + return vk_device_architecture::AMD_RDNA3; + } + return vk_device_architecture::AMD_RDNA2; + } + } else if (props.vendorID == VK_VENDOR_ID_INTEL) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool subgroup_size_control = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + subgroup_size_control = true; + } + } + + if (!subgroup_size_control) { + return vk_device_architecture::OTHER; + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &subgroup_size_control_props; + device.getProperties2(&props2); + + if (subgroup_size_control_props.minSubgroupSize == 16) { + // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8. + // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value. + // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html + // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html + return vk_device_architecture::INTEL_XE2; + } + } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool cooperative_matrix = false; + + // Detect "pre-turing" based on lack of coopmat support. + for (const auto& properties : ext_props) { + if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) { + cooperative_matrix = true; + break; + } + } + + if (!cooperative_matrix) { + return vk_device_architecture::NVIDIA_PRE_TURING; + } + } + return vk_device_architecture::OTHER; +} + +enum vk_conv_shapes { + CONV_SHAPE_128x128, + CONV_SHAPE_64x32, + CONV_SHAPE_32x256, + CONV_SHAPE_COUNT, +}; + +enum dmmv_wg_sizes { + DMMV_WG_SIZE_SUBGROUP, + DMMV_WG_SIZE_LARGE, + DMMV_WG_SIZE_COUNT, +}; + +enum FaCodePath { + FA_SCALAR, + FA_COOPMAT1, + FA_COOPMAT2, +}; + +struct vk_fa_pipeline_state { + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc) + : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {} + + uint32_t HSK, HSV; + bool small_rows; + FaCodePath path; + bool aligned; + bool f32acc; + + bool operator<(const vk_fa_pipeline_state &b) const { + return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) < + std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc); + } +}; + +enum shader_reduction_mode { + SHADER_REDUCTION_MODE_SHMEM, + SHADER_REDUCTION_MODE_HYBRID, + SHADER_REDUCTION_MODE_SUBGROUP, + SHADER_REDUCTION_MODE_COUNT, +}; + +static constexpr uint32_t num_argsort_pipelines = 11; +static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); + +struct vk_device_struct { + std::recursive_mutex mutex; + + vk::PhysicalDevice physical_device; + vk::PhysicalDeviceProperties properties; + std::string name; + uint64_t max_memory_allocation_size; + uint64_t max_buffer_size; + uint64_t suballocation_block_size; + bool fp16; + bool bf16; + bool pipeline_robustness; + vk::Device device; + uint32_t vendor_id; + vk::DriverId driver_id; + vk_device_architecture architecture; + vk_queue compute_queue; + vk_queue transfer_queue; + bool single_queue; + uint32_t subgroup_size; + uint32_t shader_core_count; + bool uma; + bool prefer_host_memory; + bool float_controls_rte_fp16; + bool subgroup_arithmetic; + bool subgroup_shuffle; + bool subgroup_ballot; + bool subgroup_clustered; + bool multi_add; + bool shader_int64; + bool buffer_device_address; + + bool add_rms_fusion; + uint32_t partials_binding_alignment; + + bool integer_dot_product; + // 0: default, 1: force mmvq, -1: disable mmvq + int32_t mmvq_mode; + + bool subgroup_size_control; + uint32_t subgroup_min_size; + uint32_t subgroup_max_size; + bool subgroup_require_full_support; + + bool coopmat_support; + bool coopmat_acc_f32_support {}; + bool coopmat_acc_f16_support {}; + bool coopmat_bf16_support {}; + bool coopmat_support_16x16x16_f16acc {}; + bool coopmat_support_16x16x16_f32acc {}; + bool coopmat1_fa_support {}; + uint32_t coopmat_m; + uint32_t coopmat_n; + uint32_t coopmat_k; + + bool coopmat_int_support; + uint32_t coopmat_int_m; + uint32_t coopmat_int_n; + uint32_t coopmat_int_k; + + bool coopmat2; + + bool pipeline_executable_properties_support {}; + + size_t idx; + + bool mul_mat_l[GGML_TYPE_COUNT]; + bool mul_mat_m[GGML_TYPE_COUNT]; + bool mul_mat_s[GGML_TYPE_COUNT]; + bool mul_mat_id_l[GGML_TYPE_COUNT]; + bool mul_mat_id_m[GGML_TYPE_COUNT]; + bool mul_mat_id_s[GGML_TYPE_COUNT]; + + // set to true to indicate that some shaders need to be compiled after the dryrun + bool need_compiles {}; + + vk::DescriptorSetLayout dsl; + + vk_matmul_pipeline pipeline_matmul_f32 {}; + vk_matmul_pipeline pipeline_matmul_f32_f16 {}; + vk_matmul_pipeline pipeline_matmul_bf16 {}; + vk_matmul_pipeline2 pipeline_matmul_f16; + vk_matmul_pipeline2 pipeline_matmul_f16_f32; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT]; + + vk_matmul_pipeline pipeline_matmul_id_f32 {}; + vk_matmul_pipeline pipeline_matmul_id_bf16 {}; + vk_matmul_pipeline2 pipeline_matmul_id_f16; + vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_matmul_split_k_reduce; + vk_pipeline pipeline_quantize_q8_1; + vk_pipeline pipeline_quantize_q8_1_x4; + + vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; + vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; + vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; + vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_acc_f32; + + // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] + vk_pipeline pipeline_add[2][2][2]; + vk_pipeline pipeline_add_norepeat[2][2][2]; + vk_pipeline pipeline_sub[2][2][2]; + vk_pipeline pipeline_sub_norepeat[2][2][2]; + vk_pipeline pipeline_mul[2][2][2]; + vk_pipeline pipeline_mul_norepeat[2][2][2]; + vk_pipeline pipeline_div[2][2][2]; + vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_add_rms[2][2][2]; + vk_pipeline pipeline_add_rms_norepeat[2][2][2]; + + // indexed by num_additional_fused_ops == num_adds - 1 + vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS]; + vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS]; + + vk_pipeline pipeline_add_id_f32; + + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; + vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32; + vk_pipeline pipeline_scale_f32; + vk_pipeline pipeline_sqr_f32; + vk_pipeline pipeline_sqrt_f32; + vk_pipeline pipeline_sin_f32; + vk_pipeline pipeline_cos_f32; + vk_pipeline pipeline_clamp_f32; + vk_pipeline pipeline_pad_f32; + vk_pipeline pipeline_roll_f32; + vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; + vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT]; + vk_pipeline pipeline_norm_f32; + vk_pipeline pipeline_group_norm_f32; + vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_rms_norm_mul_f32; + vk_pipeline pipeline_rms_norm_partials_f32; + vk_pipeline pipeline_rms_norm_mul_partials_f32; + vk_pipeline pipeline_rms_norm_back_f32; + vk_pipeline pipeline_l2_norm_f32; + + // [src/dst 0=fp32,1=fp16] + vk_pipeline pipeline_exp[2]; + vk_pipeline pipeline_gelu[2]; + vk_pipeline pipeline_gelu_erf[2]; + vk_pipeline pipeline_gelu_quick[2]; + vk_pipeline pipeline_silu[2]; + vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_tanh[2]; + vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_hardsigmoid[2]; + vk_pipeline pipeline_hardswish[2]; + + vk_pipeline pipeline_geglu[2]; + vk_pipeline pipeline_reglu[2]; + vk_pipeline pipeline_swiglu[2]; + vk_pipeline pipeline_swiglu_oai[2]; + vk_pipeline pipeline_geglu_erf[2]; + vk_pipeline pipeline_geglu_quick[2]; + + vk_pipeline pipeline_leaky_relu_f32; + vk_pipeline pipeline_silu_back_f32; + vk_pipeline pipeline_diag_mask_inf_f32; + vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; + vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; + vk_pipeline pipeline_soft_max_back_f32; + vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; + vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; + vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; + vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; + vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; + vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_argmax_f32; + vk_pipeline pipeline_count_equal_i32; + vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; + vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; + vk_pipeline pipeline_timestep_embedding_f32; + vk_pipeline pipeline_conv_transpose_1d_f32; + vk_pipeline pipeline_pool2d_f32; + vk_pipeline pipeline_rwkv_wkv6_f32; + vk_pipeline pipeline_rwkv_wkv7_f32; + vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_opt_step_sgd_f32; + vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; + vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; + + std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_flash_attn_split_k_reduce; + + std::vector all_pipelines; + + std::vector> pinned_memory; + + vk::Fence fence; + vk_buffer sync_staging; + + ggml_backend_buffer_type buffer_type; + + bool disable_fusion; + bool disable_host_visible_vidmem; + bool allow_sysmem_fallback; + bool disable_graph_optimize; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + std::unique_ptr memory_logger; +#endif + + // for GGML_VK_PERF_LOGGER + std::unique_ptr perf_logger; + vk::QueryPool query_pool; + int32_t num_queries; + + ~vk_device_struct() { + VK_LOG_DEBUG("destroy device " << name); + + device.destroyFence(fence); + + ggml_vk_destroy_buffer(sync_staging); + + compute_queue.cmd_pool.destroy(device); + transfer_queue.cmd_pool.destroy(device); + + for (auto& pipeline : all_pipelines) { + if (pipeline.expired()) { + continue; + } + + vk_pipeline pl = pipeline.lock(); + ggml_vk_destroy_pipeline(device, pl); + } + all_pipelines.clear(); + + device.destroyDescriptorSetLayout(dsl); + + device.destroy(); + } +}; + +void vk_command_pool::init(vk_device& device, vk_queue *q_) { + cmd_buffer_idx = 0; + q = q_; + + vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index); + pool = device->device.createCommandPool(command_pool_create_info); +} + +void vk_command_pool::destroy(vk::Device& device) { + device.destroyCommandPool(pool); + pool = nullptr; + cmd_buffers.clear(); +} + +struct vk_buffer_struct { + vk::Buffer buffer = VK_NULL_HANDLE; + vk::DeviceMemory device_memory = VK_NULL_HANDLE; + vk::MemoryPropertyFlags memory_property_flags; + void * ptr; + size_t size = 0; + vk::DeviceAddress bda_addr {}; + + vk_device device; + + ~vk_buffer_struct() { + if (size == 0) { + return; + } + VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")"); + + device->device.freeMemory(device_memory); + device->device.destroyBuffer(buffer); + } +}; + +struct vk_subbuffer { + vk_buffer buffer; + uint64_t offset; + uint64_t size; + + operator vk::DescriptorBufferInfo() const { + return { buffer->buffer, offset, size }; + } +}; + +struct vk_semaphore { + vk::Semaphore s; + uint64_t value; +}; + +struct vk_submission { + vk::CommandBuffer buffer; + std::vector wait_semaphores; + std::vector signal_semaphores; +}; + +typedef std::vector vk_sequence; + +struct vk_mat_mat_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t k_split; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; + uint32_t padded_N; +}; +struct vk_mat_vec_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; +}; + +struct vk_mat_mat_id_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; + uint32_t padded_N; +}; +struct vk_mat_vec_id_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t nei0; uint32_t ne11; +}; + +struct vk_flash_attn_push_constants { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + uint32_t nem2; + uint32_t nem3; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask_n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +}; +static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128"); + +struct vk_op_push_constants { + uint32_t KX; + uint32_t KY; + float param1; + float param2; +}; + +struct vk_op_glu_push_constants { + uint32_t N; + uint32_t ne00; + uint32_t ne20; + uint32_t mode; // 0: default, 1: swapped, 2: split + float alpha; // for swiglu_oai + float limit; +}; + +struct vk_op_unary_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + float param1; float param2; + uint32_t ne0_012mp; uint32_t ne0_012L; + uint32_t ne0_01mp; uint32_t ne0_01L; + uint32_t ne0_0mp; uint32_t ne0_0L; + uint32_t ne1_012mp; uint32_t ne1_012L; + uint32_t ne1_01mp; uint32_t ne1_01L; + uint32_t ne1_0mp; uint32_t ne1_0L; +}; +static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); + +static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) { + GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst))); + ne = ne != 0 ? ne : ggml_nelements(dst); + GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max()); + + vk_op_unary_push_constants p{}; + p.ne = (uint32_t)ne; + + size_t src0_tsize = ggml_type_size(src0->type); + p.ne00 = (uint32_t)src0->ne[0]; + p.ne01 = (uint32_t)src0->ne[1]; + p.ne02 = (uint32_t)src0->ne[2]; + p.ne03 = (uint32_t)src0->ne[3]; + p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); + p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); + p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); + p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); + + size_t dst_tsize = ggml_type_size(dst->type); + p.ne10 = (uint32_t)dst->ne[0]; + p.ne11 = (uint32_t)dst->ne[1]; + p.ne12 = (uint32_t)dst->ne[2]; + p.ne13 = (uint32_t)dst->ne[3]; + p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); + p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); + p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); + p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + + return p; // offsets are initialized later in ggml_vk_op +} + +struct vk_op_pad_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + + uint32_t lp0; uint32_t rp0; + uint32_t lp1; uint32_t rp1; + uint32_t lp2; uint32_t rp2; + uint32_t lp3; uint32_t rp3; +}; + +static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) { + int64_t ne = ggml_nelements(dst); + GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max()); + + vk_op_pad_push_constants p{}; + p.ne = (uint32_t)ne; + + size_t src0_tsize = ggml_type_size(src0->type); + p.ne00 = (uint32_t)src0->ne[0]; + p.ne01 = (uint32_t)src0->ne[1]; + p.ne02 = (uint32_t)src0->ne[2]; + p.ne03 = (uint32_t)src0->ne[3]; + p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); + p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); + p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); + p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); + + size_t dst_tsize = ggml_type_size(dst->type); + p.ne10 = (uint32_t)dst->ne[0]; + p.ne11 = (uint32_t)dst->ne[1]; + p.ne12 = (uint32_t)dst->ne[2]; + p.ne13 = (uint32_t)dst->ne[3]; + p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); + p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); + p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); + p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + + p.lp0 = dst->op_params[0]; + p.rp0 = dst->op_params[1]; + p.lp1 = dst->op_params[2]; + p.rp1 = dst->op_params[3]; + p.lp2 = dst->op_params[4]; + p.rp2 = dst->op_params[5]; + p.lp3 = dst->op_params[6]; + p.rp3 = dst->op_params[7]; + + return p; // fastdiv values and offsets are initialized later in ggml_vk_op +} + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) +{ + // compute L = ceil(log2(d)); + L = 0; + while (L < 32 && (uint32_t{1} << L) < d) { + L++; + } + + mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1); +} + +template void init_pushconst_fastdiv(T &p) { + GGML_UNUSED(p); + static_assert(!std::is_const::value, "unexpected type"); +} + +template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) { + // Compute magic values to divide by these six numbers. + init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L); + init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L); + init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L); + init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L); + init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L); + init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L); +} + +struct vk_op_binary_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23; + uint32_t misalign_offsets; + float param1; float param2; int32_t param3; +}; + +struct vk_op_multi_add_push_constants { + // shape for dst + uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; + + // strides for srcs+dst + uint32_t nb[MAX_PARAMETER_COUNT][4]; + + uint32_t rms_partials; +}; +// update multi_add.comp if this changes +static_assert(MAX_PARAMETER_COUNT == 12); +static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); + +struct vk_op_add_id_push_constants { + uint32_t ne0; + uint32_t ne1; + uint32_t s01; + uint32_t s02; + uint32_t s11; + uint32_t s21; +}; + +struct vk_op_diag_mask_push_constants { + uint32_t ncols; + uint32_t rows_per_channel; + int32_t n_past; +}; + +struct vk_op_rope_push_constants { + uint32_t ncols; + uint32_t n_dims; + float freq_scale; + uint32_t p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint32_t has_ff; + uint32_t ne02; + uint32_t s1; + uint32_t s2; + int32_t sections[4]; + uint32_t is_back; +}; + +struct vk_op_soft_max_push_constants { + uint32_t KX; + uint32_t KY; + uint32_t ne00; + uint32_t ne01; + uint32_t ne02; + uint32_t ne12; + uint32_t ne13; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; + uint32_t nrows_x; + uint32_t has_sinks; +}; + +struct vk_op_argsort_push_constants { + uint32_t ncols; + int32_t order; +}; + +struct vk_op_im2col_push_constants { + uint64_t dst_addr; + uint32_t batch_offset; uint32_t offset_delta; + uint32_t IC; + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t KW; uint32_t KH; + uint32_t pelements; + uint32_t CHW; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; + int32_t d0; int32_t d1; +}; + +struct vk_op_im2col_3d_push_constants { + uint64_t dst_addr; + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +}; + +struct vk_op_timestep_embedding_push_constants { + uint32_t nb1; + uint32_t dim; + uint32_t max_period; +}; + +struct vk_op_conv_transpose_1d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +}; + +struct vk_op_pool2d_push_constants { + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t OC; + uint32_t pelements; + uint32_t op; + int32_t k0; int32_t k1; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; +}; + +struct vk_op_rwkv_wkv6_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; + +struct vk_op_rwkv_wkv7_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; + +struct vk_op_conv2d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; +}; + +template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { + // Compute magic values to divide by KW, KW*KH, OW, OW*OH + init_fastdiv_values(p.KW, p.KWmp, p.KWL); + init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); +} + +struct vk_op_conv_transpose_2d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1 + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t s0mp; uint32_t s0L; + uint32_t s1mp; uint32_t s1L; +}; + +template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) { + // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1 + init_fastdiv_values(p.KW, p.KWmp, p.KWL); + init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); + init_fastdiv_values(p.s0, p.s0mp, p.s0L); + init_fastdiv_values(p.s1, p.s1mp, p.s1L); +} + +struct vk_op_conv2d_dw_push_constants { + uint32_t ne; + uint32_t batches; + uint32_t channels; + uint32_t dst_w; + uint32_t dst_h; + uint32_t src_w; + uint32_t src_h; + uint32_t knl_w; + uint32_t knl_h; + int32_t stride_x; + int32_t stride_y; + int32_t pad_x; + int32_t pad_y; + int32_t dilation_x; + int32_t dilation_y; +}; + +struct vk_op_upscale_push_constants { + uint32_t ne; uint32_t a_offset; uint32_t d_offset; + uint32_t ne00; uint32_t ne01; + uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; + float sf0; float sf1; float sf2; float sf3; +}; + +struct vk_op_sum_rows_push_constants +{ + uint32_t n_cols; + uint32_t ne01, ne02; + uint32_t nb01, nb02, nb03; + uint32_t nb11, nb12, nb13; + float weight; + uint32_t misalign_offsets; + uint32_t ne0_12mp, ne0_12L; + uint32_t ne0_1mp, ne0_1L; +}; + +static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) { + uint32_t type_size = (uint32_t)ggml_type_size(src->type); + vk_op_sum_rows_push_constants p = {}; + p.n_cols = (uint32_t)n_cols; + p.ne01 = (uint32_t)src->ne[1]; + p.ne02 = (uint32_t)src->ne[2]; + p.nb01 = (uint32_t)src->nb[1] / type_size; + p.nb02 = (uint32_t)src->nb[2] / type_size; + p.nb03 = (uint32_t)src->nb[3] / type_size; + p.nb11 = (uint32_t)dst->nb[1] / type_size; + p.nb12 = (uint32_t)dst->nb[2] / type_size; + p.nb13 = (uint32_t)dst->nb[3] / type_size; + p.weight = 1.0f; + return p; +} + +template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) { + init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L); + init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L); +} + +// Allow pre-recording command buffers +struct vk_staging_memcpy { + vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} + + void * dst; + const void * src; + size_t n; +}; + +struct vk_staging_memset { + vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {} + + void * dst; + uint32_t val; + size_t n; +}; + +struct vk_context_struct { + vk_submission * s; + std::vector seqs; + + int exit_tensor_idx; + + std::vector in_memcpys; + std::vector out_memcpys; + std::vector memsets; + + vk_command_pool * p {}; +}; +typedef std::shared_ptr vk_context; +typedef std::weak_ptr vk_context_ref; + +struct ggml_vk_garbage_collector { + std::vector tl_semaphores; + std::vector semaphores; + std::vector events; + std::vector temp_buffers; + std::vector contexts; +}; + +#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG) +#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl + +static std::string format_size(size_t size) { + const size_t kib = 1024; + const size_t mib = kib * 1024; + const size_t gib = mib * 1024; + + std::ostringstream oss; + oss << std::fixed << std::setprecision(2); + + if (size >= gib) { + oss << static_cast(size) / gib << " GiB"; + } else if (size >= mib) { + oss << static_cast(size) / mib << " MiB"; + } else if (size >= kib) { + oss << static_cast(size) / kib << " KiB"; + } else { + oss << size << " B"; + } + + return oss.str(); +} + +class vk_memory_logger { +public: + vk_memory_logger(): total_device(0), total_host(0) {} + void log_allocation(vk_buffer_ref buf_ref, size_t size); + void log_deallocation(vk_buffer_ref buf_ref); + +private: + std::map allocations; // Track allocations + size_t total_device; + size_t total_host; +}; +#else +#define VK_LOG_MEMORY(msg) ((void) 0) +#endif // GGML_VULKAN_MEMORY_DEBUG + +class vk_perf_logger { + public: + void print_timings() { + if (timings.empty()) { + return; + } + uint64_t total_all_op_times = 0; + std::cerr << "----------------\nVulkan Timings:" << std::endl; + for (const auto & t : timings) { + uint64_t total_op_times = 0; + for (const auto & time : t.second) { + total_op_times += time; + } + std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0) + << " us"; + + // If we have as many flops entries as timing entries for the op, then compute and log the flops/S. + auto it = flops.find(t.first); + if (it != flops.end() && (it->second).size() == t.second.size()) { + uint64_t total_op_flops = 0; + for (const auto & elem : it->second) { + total_op_flops += elem; + } + std::cerr << " (" + << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) / + (double(total_op_times) / (1000.0 * 1000.0 * 1000.0)) + << " GFLOPS/s)"; + } + + total_all_op_times += total_op_times; + + std::cerr << std::endl; + } + + if (timings.size() > 0) { + std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl; + } + + timings.clear(); + flops.clear(); + } + + void log_timing(const ggml_tensor * node, uint64_t time) { + if (node->op == GGML_OP_UNARY) { + timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); + return; + } + if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { + const uint64_t m = node->src[0]->ne[1]; + const uint64_t n = node->ne[1]; + const uint64_t k = node->src[1]->ne[0]; + const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3]; + std::string name = ggml_op_name(node->op); + if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) || + (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) { + name += "_VEC"; + } + name += " "; + name += ggml_type_name(node->src[0]->type); + name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); + if (batch > 1) { + name += " batch=" + std::to_string(batch); + } + timings[name].push_back(time); + flops[name].push_back(m * n * (k + (k - 1)) * batch); + return; + } + if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) { + std::string name = ggml_op_name(node->op); + ggml_tensor * knl = node->src[0]; + uint64_t OW = node->ne[0]; + uint64_t OH = node->ne[1]; + uint64_t N = node->ne[3]; + uint64_t Cout = node->ne[2]; + uint64_t KW = knl->ne[0]; + uint64_t KH = knl->ne[1]; + uint64_t Cin = node->src[1]->ne[2]; + // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ + uint64_t size_M = Cout; + uint64_t size_K = Cin * KW * KH; + uint64_t size_N = N * OW * OH; + uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1)); + name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) + + ", N=N*OW*OH=" + std::to_string(size_N); + flops[name].push_back(n_flops); + timings[name].push_back(time); + return; + } + if (node->op == GGML_OP_RMS_NORM) { + std::string name = ggml_op_name(node->op); + name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")"; + timings[name].push_back(time); + return; + } + timings[ggml_op_name(node->op)].push_back(time); + } + private: + std::map> timings; + std::map> flops; +}; + +struct ggml_backend_vk_context { + std::string name; + + vk_device device; + + size_t semaphore_idx, event_idx; + ggml_vk_garbage_collector gc; + size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset; + vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials; + vk::Fence fence, almost_ready_fence; + bool almost_ready_fence_pending {}; + // Set before op_add and unset after op_rms_norm to indicate that the add should + // write partial sums to accumulate the square of the vector components + bool do_add_rms_partials; + + // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. + vk_pipeline_struct * prealloc_y_last_pipeline_used {}; + const ggml_tensor * prealloc_y_last_tensor_used {}; + + // Track which nodes have been used since the last sync, and whether they were written to + std::vector unsynced_nodes_written; + std::vector unsynced_nodes_read; + // Track which prealloc buffers have pending reads that need to be synchronized. + // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set), + // and set to true after the buffer contents are consumed. + bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; + + vk_buffer buffer_pool[MAX_VK_BUFFERS]; + + vk_context_ref compute_ctx; + vk_context_ref transfer_ctx; + + std::vector tensor_ctxs; + + std::vector descriptor_pools; + std::vector descriptor_sets; + uint32_t descriptor_set_idx {}; + uint32_t pipeline_descriptor_set_requirements {}; + + vk_command_pool compute_cmd_pool; + vk_command_pool transfer_cmd_pool; + + // number of additional consecutive nodes that are being fused with the + // node currently being processed + int num_additional_fused_ops {}; +}; + +static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT + +static uint64_t vk_tensor_offset(const ggml_tensor * tensor) { + if (tensor->view_src) { + return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base; + } + return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; +} + +struct ggml_backend_vk_buffer_context { + vk_device_ref device; + vk_buffer dev_buffer; + std::string name; + + ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) : + device(device), + dev_buffer(dev_buffer), + name(name) { + } + + ~ggml_backend_vk_buffer_context() { + ggml_vk_destroy_buffer(dev_buffer); + } +}; + +#ifdef GGML_VULKAN_MEMORY_DEBUG +static std::mutex log_mutex; + +void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { + std::lock_guard guard(log_mutex); + vk_buffer buf = buf_ref.lock(); + const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); + const std::string type = device ? "device" : "host"; + allocations[buf->buffer] = size; + total_device += device ? size : 0; + total_host += device ? 0 : size; + VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); +} + +void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { + if (buf_ref.expired() || buf_ref.lock()->size == 0) { + return; + } + + std::lock_guard guard(log_mutex); + vk_buffer buf = buf_ref.lock(); + const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); + std::string type = device ? "device" : "host"; + auto it = allocations.find(buf->buffer); + total_device -= device ? it->second : 0; + total_host -= device ? 0 : it->second; + if (it != allocations.end()) { + VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); + allocations.erase(it); + } else { + VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer); + } +} +#endif // GGML_VULKAN_MEMORY_DEBUG + +struct vk_instance_t { + vk::Instance instance; + + bool debug_utils_support = false; // VK_EXT_debug_utils enabled + PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {}; + PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {}; + PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {}; + PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {}; + PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {}; + PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {}; + + std::vector device_indices; + std::vector device_supports_membudget; + vk_device devices[GGML_VK_MAX_DEVICES]; +}; + +static bool vk_instance_initialized = false; +static vk_instance_t vk_instance; + +static bool vk_perf_logger_enabled = false; + +#ifdef GGML_VULKAN_CHECK_RESULTS +static size_t vk_skip_checks; +static size_t vk_output_tensor; + +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); +static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); +static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); +#endif + +typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +static void ggml_backend_vk_free(ggml_backend_t backend); + +static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) { + const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset}, + VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange}); + return range; +} + +// Wait for ctx->fence to be signaled. +static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { + // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep + // during this wait. + if (ctx->almost_ready_fence_pending) { + VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence"); + ctx->device->device.resetFences({ ctx->almost_ready_fence }); + ctx->almost_ready_fence_pending = false; + } + + // Spin (w/pause) waiting for the graph to finish executing. + vk::Result result; + while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) { + if (result != vk::Result::eNotReady) { + fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__); + exit(1); + } + for (uint32_t i = 0; i < 100; ++i) { + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + } + } + ctx->device->device.resetFences({ ctx->fence }); +} + +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; + +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, + uint32_t parameter_count, std::array wg_denoms, std::vector specialization_constants, + bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { + VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count << + ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << + disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); + GGML_ASSERT(parameter_count > 0); + GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT); + GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT + + vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); + + vk::PushConstantRange pcr( + vk::ShaderStageFlagBits::eCompute, + 0, + pipeline->push_constant_size + ); + + vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr); + pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info); + + std::vector specialization_entries(specialization_constants.size()); + + for (size_t i = 0; i < specialization_constants.size(); i++) { + specialization_entries[i].constantID = i; + specialization_entries[i].offset = i * sizeof(uint32_t); + specialization_entries[i].size = sizeof(uint32_t); + } + + vk::SpecializationInfo specialization_info( + specialization_entries.size(), + specialization_entries.data(), + specialization_constants.size() * sizeof(uint32_t), + specialization_constants.data() + ); + + vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{}; + + if (device->subgroup_require_full_support && require_full_subgroups) { + pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT; + } + + vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( + pipeline_shader_stage_create_flags, + vk::ShaderStageFlagBits::eCompute, + pipeline->shader_module, + entrypoint.c_str(), + &specialization_info); + + vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info; + pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size; + if (device->subgroup_size_control && required_subgroup_size > 0) { + GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size); + pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info); + } + + vk::ComputePipelineCreateInfo compute_pipeline_create_info( + device->pipeline_executable_properties_support ? + vk::PipelineCreateFlagBits::eCaptureStatisticsKHR : + vk::PipelineCreateFlags{}, + pipeline_shader_create_info, + pipeline->layout); + + vk::PipelineRobustnessCreateInfoEXT rci; + + if (device->pipeline_robustness && disable_robustness) { + rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + compute_pipeline_create_info.setPNext(&rci); + } + + try { + pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } + pipeline->compiled = true; + + if (vk_instance.debug_utils_support) { + vk::DebugUtilsObjectNameInfoEXT duoni; + duoni.objectType = vk::ObjectType::ePipeline; + duoni.pObjectName = pipeline->name.c_str(); + duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast(pipeline->pipeline)); + vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast(duoni)); + } + + if (device->pipeline_executable_properties_support) { + vk::PipelineExecutableInfoKHR executableInfo; + executableInfo.pipeline = pipeline->pipeline; + + auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo); + for (auto & s : statistics) { + // "Register Count" is reported by NVIDIA drivers. + if (strcmp(s.name, "Register Count") == 0) { + VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers"); + pipeline->register_count = (uint32_t)s.value.u64; + } + } + } + + { + std::lock_guard guard(device->mutex); + device->all_pipelines.push_back(pipeline); + } + + { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + } + compile_count_cond.notify_all(); +} + +static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")"); + device.destroyPipelineLayout(pipeline->layout); + + device.destroyShaderModule(pipeline->shader_module); + + device.destroyPipeline(pipeline->pipeline); +} + +static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) { + VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); + ctx->pipeline_descriptor_set_requirements += n; + if (!pipeline->compiled) { + pipeline->needed = true; + ctx->device->need_compiles = true; + } +} + +static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) { + + if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) { + // Enough descriptors are available + return; + } + + vk_device& device = ctx->device; + + uint32_t to_alloc = ctx->pipeline_descriptor_set_requirements - ctx->descriptor_sets.size(); + uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; + uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + while (to_alloc > 0) { + const uint32_t alloc_count = std::min(pool_remaining, to_alloc); + to_alloc -= alloc_count; + pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + if (pool_idx >= ctx->descriptor_pools.size()) { + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); + } + + std::vector layouts(alloc_count); + for (uint32_t i = 0; i < alloc_count; i++) { + layouts[i] = device->dsl; + } + vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data()); + std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); + ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end()); + + pool_idx++; + } +} + +static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) { + VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); + + if (p.cmd_buffers.size() > p.cmd_buffer_idx) { + // Reuse command buffer + return p.cmd_buffers[p.cmd_buffer_idx++]; + } + + vk::CommandBufferAllocateInfo command_buffer_alloc_info( + p.pool, + vk::CommandBufferLevel::ePrimary, + 1); + const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); + auto buf = cmd_buffers.front(); + + p.cmd_buffers.push_back(buf); + p.cmd_buffer_idx++; + + return buf; +} + +static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { + if (ctx->seqs.empty()) { + if (fence) { + std::lock_guard guard(queue_mutex); + ctx->p->q->queue.submit({}, fence); + } + return; + } + VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")"); + + std::vector> tl_wait_vals; + std::vector> tl_signal_vals; + std::vector> tl_wait_semaphores; + std::vector> tl_signal_semaphores; + std::vector tl_submit_infos; + std::vector submit_infos; + int idx = -1; + std::vector> stage_flags; + + size_t reserve = 0; + + for (const auto& sequence : ctx->seqs) { + reserve += sequence.size(); + } + + // Pre-reserve vectors to prevent reallocation, which invalidates pointers + tl_wait_semaphores.reserve(reserve); + tl_wait_vals.reserve(reserve); + tl_signal_semaphores.reserve(reserve); + tl_signal_vals.reserve(reserve); + tl_submit_infos.reserve(reserve); + submit_infos.reserve(reserve); + stage_flags.reserve(reserve); + + for (const auto& sequence : ctx->seqs) { + for (const auto& submission : sequence) { + stage_flags.push_back({}); + idx++; + tl_wait_vals.push_back({}); + tl_wait_semaphores.push_back({}); + tl_signal_vals.push_back({}); + tl_signal_semaphores.push_back({}); + for (size_t i = 0; i < submission.wait_semaphores.size(); i++) { + stage_flags[idx].push_back(ctx->p->q->stage_flags); + tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value); + tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s); + } + for (size_t i = 0; i < submission.signal_semaphores.size(); i++) { + tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value); + tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s); + } + tl_submit_infos.push_back({ + (uint32_t) submission.wait_semaphores.size(), + tl_wait_vals[idx].data(), + (uint32_t) submission.signal_semaphores.size(), + tl_signal_vals[idx].data(), + }); + tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo; + tl_submit_infos[idx].pNext = nullptr; + vk::SubmitInfo si{ + (uint32_t) submission.wait_semaphores.size(), + tl_wait_semaphores[idx].data(), + stage_flags[idx].data(), + 1, + &submission.buffer, + (uint32_t) submission.signal_semaphores.size(), + tl_signal_semaphores[idx].data(), + }; + si.setPNext(&tl_submit_infos[idx]); + submit_infos.push_back(si); + } + } + + std::lock_guard guard(queue_mutex); + ctx->p->q->queue.submit(submit_infos, fence); + + ctx->seqs.clear(); +} + +static uint32_t ggml_vk_find_queue_family_index(std::vector& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) { + VK_LOG_DEBUG("ggml_vk_find_queue_family_index()"); + const uint32_t qfsize = queue_family_props.size(); + + // Try with avoid preferences first + for (uint32_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) { + return i; + } + } + + // Fall back to only required + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) { + return i; + } + } + + // Fall back to reusing compute queue + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) { + return i; + } + } + + // Fall back to ignoring min_num_queries + for (size_t i = 0; i < qfsize; i++) { + if (queue_family_props[i].queueFlags & required) { + return i; + } + } + + // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations. + // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional. + if (compute_index >= 0) { + return compute_index; + } + + std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl; + + for(auto &q_family : queue_family_props) { + std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl; + } + abort(); +} + +static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) { + VK_LOG_DEBUG("ggml_vk_create_queue()"); + std::lock_guard guard(device->mutex); + + q.queue_family_index = queue_family_index; + q.transfer_only = transfer_only; + + q.cmd_pool.init(device, &q); + + q.queue = device->device.getQueue(queue_family_index, queue_index); + + q.stage_flags = stage_flags; +} + +static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); + ctx->gc.contexts.emplace_back(result); + result->p = &p; + return result; +} + +static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); + result->p = &p; + return result; +} + +static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); + ctx->gc.semaphores.push_back({ semaphore, 0 }); + return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1]; +} + +static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); + if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) { + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); + ctx->gc.tl_semaphores.push_back({ semaphore, 0 }); + } + return &ctx->gc.tl_semaphores[ctx->semaphore_idx++]; +} + +static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { + if (ctx->event_idx >= ctx->gc.events.size()) { + ctx->gc.events.push_back(ctx->device->device.createEvent({})); + } + return ctx->gc.events[ctx->event_idx++]; +} + +static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) { + VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()"); + + // Requires command buffers to be done + device->device.resetCommandPool(p.pool); + p.cmd_buffer_idx = 0; +} + +static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { + VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()"); + + // Arbitrary frequency to cleanup/reuse command buffers + static constexpr uint32_t cleanup_frequency = 10; + + if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool); + } + if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool); + } +} + + +static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { + for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { + vk::MemoryType memory_type = mem_props->memoryTypes[i]; + if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) && + (flags & memory_type.propertyFlags) == flags && + mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) { + return static_cast(i); + } + } + return UINT32_MAX; +} + +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { + VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); + if (size > device->max_buffer_size) { + throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit"); + } + + vk_buffer buf = std::make_shared(); + + if (size == 0) { + buf->size = 0; + return buf; + } + + vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst; + vk::MemoryAllocateFlags mem_flags {}; + if (device->buffer_device_address) { + usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress; + mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress; + } + + vk::BufferCreateInfo buffer_create_info{ + vk::BufferCreateFlags(), + size, + usage_flags, + vk::SharingMode::eExclusive, + 0, + nullptr, + }; + + buf->buffer = device->device.createBuffer(buffer_create_info); + + vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); + + vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); + + const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags }; + + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; + + uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags); + + if (memory_type_index == UINT32_MAX) { + continue; + } + buf->memory_property_flags = req_flags; + + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info }); + break; + } catch (const vk::SystemError& e) { + // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end()) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } + } + + if (!buf->device_memory) { + device->device.destroyBuffer(buf->buffer); + throw vk::OutOfDeviceMemoryError("No suitable memory type found"); + } + + buf->ptr = nullptr; + + if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + } + + device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); + + buf->device = device; + buf->size = size; + + if (device->buffer_device_address) { + const vk::BufferDeviceAddressInfo addressInfo(buf->buffer); + buf->bda_addr = device->device.getBufferAddress(addressInfo); + } + +#ifdef GGML_VULKAN_MEMORY_DEBUG + device->memory_logger->log_allocation(buf, size); +#endif + + return buf; +} + +static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { + try { + return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags}); + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } +} + +static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { + vk_buffer buf; + try { + if (device->prefer_host_memory) { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal}); + } else if (device->uma) { + // Fall back to host memory type + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + } else if (device->disable_host_visible_vidmem) { + if (device->allow_sysmem_fallback) { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + } else { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + } + } else { + // use rebar if available, otherwise fallback to device only visible memory + if (device->allow_sysmem_fallback) { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + } else { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal}); + } + } + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } + + return buf; +} + +static void ggml_vk_destroy_buffer(vk_buffer& buf) { + if (buf == nullptr) { + return; + } + +#ifdef GGML_VULKAN_MEMORY_DEBUG + if (buf->device != nullptr) { + buf->device->memory_logger->log_deallocation(buf); + } +#endif + + buf.reset(); +} + +static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) { + return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) }; +} + +static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) { + VK_LOG_DEBUG("ggml_vk_sync_buffers()"); + + const bool transfer_queue = subctx->p->q->transfer_only; + + if (ctx) { + ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; + } + + subctx->s->buffer.pipelineBarrier( + subctx->p->q->stage_flags, + subctx->p->q->stage_flags, + {}, + { { + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) } + } }, + {}, + {} + ); +} + +static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events) { + VK_LOG_DEBUG("ggml_vk_wait_events()"); + if (events.empty()) { + return; + } + + ctx->s->buffer.waitEvents( + events, + ctx->p->q->stage_flags, + ctx->p->q->stage_flags, + {}, + {}, + {} + ); +} + +// number of rows/cols for flash attention shader +static constexpr uint32_t flash_attention_num_small_rows = 32; +static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; + +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { + if (hsv >= 192) { + return 2; + } else { + return 8; + } +} + +// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. +// 128 threads split into four subgroups, each subgroup does 1/4 +// of the Bc dimension. +static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; +static constexpr uint32_t scalar_flash_attention_Bc = 64; +static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; + +static uint32_t get_fa_num_small_rows(FaCodePath path) { + if (path == FA_COOPMAT2) { + return flash_attention_num_small_rows; + } else { + return scalar_flash_attention_num_small_rows; + } +} + +static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) { + GGML_UNUSED(clamp); + GGML_UNUSED(hsv); + + if (path == FA_SCALAR) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, 64}; + } else { + if ((hsv | hsk) & 8) { + // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter + // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. + return {get_fa_scalar_num_large_rows(hsv), 64}; + } else { + return {get_fa_scalar_num_large_rows(hsv), 32}; + } + } + } + + if (path == FA_COOPMAT1) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; + } else { + return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + } + } + + // small rows, large cols + if (small_rows) { + return {get_fa_num_small_rows(FA_COOPMAT2), 32}; + } + + // small cols to reduce register count + if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) { + if (hsk >= 512 || hsv >= 512) { + return {32, 32}; + } else { + return {64, 32}; + } + } + return {64, 64}; +} + +static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) { + return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1]; +} + +static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { + + uint32_t lut_size = 0; + switch (src0_type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + lut_size = 2*2048; + break; + case GGML_TYPE_IQ2_XXS: + lut_size = 8*256; + break; + case GGML_TYPE_IQ2_XS: + lut_size = 8*512; + break; + case GGML_TYPE_IQ2_S: + lut_size = 8*1024; + break; + case GGML_TYPE_IQ3_XXS: + lut_size = 4*256; + break; + case GGML_TYPE_IQ3_S: + lut_size = 4*512; + break; + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: + lut_size = 4*16; + break; + default: + break; + } + + // Needs to be kept up to date on shader changes + const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; + const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const uint32_t warps = warptile[0] / warptile[10]; + + const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; + const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0; + const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; + const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0; + + const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported); + + return supported; +} + +struct GpuPipelineConfig { + // GPU architecture identifier. + // Example: vk_device_architecture::AMD_GCN + vk_device_architecture arch; + + // Mapping of pipeline names to their specific subgroup sizes. + // Example: {"soft_max_f32", 64} + std::unordered_map pipelines; + + // Default subgroup size for this GPU. + // Defaults to 0 if not explicitly provided. + uint32_t default_subgroup_size = 0; +}; + +// Pipeline configuration for RDNA1 GPUs. +static const std::unordered_map rdna1_pipelines = { + {"soft_max", 64}, {"im2col", 64}, + {"argmax", 64}, {"mul_mat_vec", 64}, + {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32} +}; + +// Pipeline configuration for RDNA2 GPUs. +static const std::unordered_map rdna2_pipelines = { + {"soft_max", 64}, {"im2col", 64}, +}; + +static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32; + +// Define configurations for different GPUs. +static std::vector gpu_pipeline_configs = { + { + vk_device_architecture::AMD_RDNA1, + { + rdna1_pipelines, + }, + RDNA_DEFAULT_SUBGROUP_SIZE + }, + { + vk_device_architecture::AMD_RDNA2, + { + rdna2_pipelines, + }, + RDNA_DEFAULT_SUBGROUP_SIZE + }, +}; + +static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) { + for (const auto &config : gpu_pipeline_configs) { + if (config.arch == arch) { + auto pipIt = config.pipelines.find(pipeline_name); + if (pipIt != config.pipelines.end()) { + return pipIt->second; + } + std::vector> sorted_pipelines(config.pipelines.begin(), config.pipelines.end()); + std::sort(sorted_pipelines.begin(), sorted_pipelines.end(), + [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); }); + for (const auto &entry : sorted_pipelines) { + if (pipeline_name.find(entry.first) != std::string::npos) { + return entry.second; + } + } + return config.default_subgroup_size; + } + } + return 0; // If no matching configuration is found +} + +static void ggml_vk_load_shaders(vk_device& device) { + VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); + + // some shaders have a minimum subgroup size + const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); + const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); + const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); + + const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; + const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u); + const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u); + const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u); + + const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) || + (device->subgroup_size_control && device->subgroup_max_size >= 16); + + // mulmat + std::vector l_warptile, m_warptile, s_warptile, + l_warptile_id, m_warptile_id, s_warptile_id, + l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, + l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, + l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, + l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; + std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, + l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, + l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, + l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; + + uint32_t l_align, m_align, s_align; + if (device->coopmat2) { + // spec constants and tile sizes for non-quant matmul/matmul_id + l_warptile = { 256, 128, 256, 64, 1 }; + m_warptile = { 256, 128, 128, 64, 0 }; + s_warptile = { 128, 64, 64, 64, 0 }; + l_wg_denoms = {128, 256, 1 }; + m_wg_denoms = {128, 128, 1 }; + s_wg_denoms = { 64, 64, 1 }; + + // spec constants and tile sizes for quant matmul (non-Qi_K) + l_warptile_mmq = { 256, 128, 256, 64, 1 }; + m_warptile_mmq = { 256, 128, 128, 64, 1 }; + s_warptile_mmq = { 256, 32, 64, 128, 0 }; + l_mmq_wg_denoms = { 128, 256, 1 }; + m_mmq_wg_denoms = { 128, 128, 1 }; + s_mmq_wg_denoms = { 32, 64, 1 }; + + // spec constants and tile sizes for quant matmul (Qi_K) + l_warptile_mmq_k = { 256, 128, 256, 64, 1 }; + m_warptile_mmq_k = { 256, 128, 128, 64, 1 }; + s_warptile_mmq_k = { 256, 32, 64, 128, 0 }; + l_mmq_wg_denoms_k = { 128, 256, 1 }; + m_mmq_wg_denoms_k = { 128, 128, 1 }; + s_mmq_wg_denoms_k = { 32, 64, 1 }; + + // spec constants and tile sizes for quant matmul_id + l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size }; + m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; + s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; + l_mmqid_wg_denoms = { 128, 128, 1 }; + m_mmqid_wg_denoms = { 128, 64, 1 }; + s_mmqid_wg_denoms = { 128, 64, 1 }; + + l_align = 128; + m_align = 64; + s_align = 32; + } else { + // Matrix cores require different warp group sizes + const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2; + const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4; + const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; + + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + + l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + + l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + + l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; + m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; + s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; + + l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 }; + m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; + s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + + // chip specific tuning + if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { + m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } + + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; + m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; + s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; + l_align = 128; + m_align = 64; + s_align = 32; + + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { + ggml_type t = (ggml_type)i; + // Disable medium and large matrix multiplication if not enough shared memory is available + // Check mmq warptiles as the largest configuration + // Throw an error if not enough for any matrix multiplication is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) { + std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl; + throw std::runtime_error("Shared memory size too small for matrix multiplication."); + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) { + device->mul_mat_m[i] = false; + device->mul_mat_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) { + device->mul_mat_l[i] = false; + } + + // Disable mul_mat_id if not enough shared memory is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) { + device->mul_mat_id_s[i] = false; + device->mul_mat_id_m[i] = false; + device->mul_mat_id_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) { + device->mul_mat_id_m[i] = false; + device->mul_mat_id_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) { + device->mul_mat_id_l[i] = false; + } + } + } + + if (!device->pipeline_matmul_f32) { + device->pipeline_matmul_f32 = std::make_shared(); + } + if (!device->pipeline_matmul_f32_f16) { + device->pipeline_matmul_f32_f16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_f32) { + device->pipeline_matmul_id_f32 = std::make_shared(); + } + if (!device->pipeline_matmul_bf16) { + device->pipeline_matmul_bf16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_bf16) { + device->pipeline_matmul_id_bf16 = std::make_shared(); + } + + std::vector> compiles; + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, + uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, + uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + + if (!require_full_subgroups && required_subgroup_size == 0) { + required_subgroup_size = get_subgroup_size(name, device->architecture); + } + + if (!pipeline) { + pipeline = std::make_shared(); + } + if (!pipeline->initialized) { + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + pipeline->initialized = true; + } + + if (!pipeline->needed || pipeline->compiled) { + return; + } + { + // wait until fewer than N compiles are in progress + uint32_t N = std::max(1u, std::thread::hardware_concurrency()); + std::unique_lock guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, + parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + }; + + auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint, + uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, + uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint, + parameter_count, push_constant_size, wg_denoms, specialization_constants, + align, disable_robustness, require_full_subgroups, required_subgroup_size); + }; + + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; + }; + + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + // For scalar, use 128 (arbitrary) + // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. + const uint32_t D = (hsk|hsv); + uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) + ? scalar_flash_attention_workgroup_size + : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows); + + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. + // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. + const uint32_t D_lsb = D ^ (D & (D-1)); + uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); + + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; + }; + +#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ + for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ + uint32_t HSK = fa.first.HSK; \ + uint32_t HSV = fa.first.HSV; \ + bool small_rows = fa.first.small_rows; \ + FaCodePath path = fa.first.path; \ + bool aligned = fa.first.aligned; \ + bool f32acc = fa.first.f32acc; \ + if (path == FAPATH) { \ + if (aligned) { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } \ + } else { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } \ + } \ + } \ + } + + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + } +#endif +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) + } +#endif +#undef CREATE_FA + +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) + } +#endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + + GGML_ASSERT(device->subgroup_ballot); + + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + } +#endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) +#undef CREATE_MM +#undef CREATE_MM2 + } else +#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat_support) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->coopmat_acc_f16_support) { \ + CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + if (device->coopmat_acc_f32_support) { \ + CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ) + } +#endif + + if (device->coopmat_acc_f16_support) { + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } else { + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } + + GGML_ASSERT(device->subgroup_ballot); + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + } +#endif + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); +#undef CREATE_MM2 +#undef CREATE_MM + } else +#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->fp16) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + } \ + if (device->mul_mat ## ID ## _m[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + } \ + if (device->mul_mat ## ID ## _s[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + } \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + + if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + } else { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + } +#undef CREATE_MM2 +#undef CREATE_MMQ +#undef CREATE_MM + } else { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + + if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + } else { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + } + } + // reusing CREATE_MM from the fp32 path + if ((device->coopmat2 || device->coopmat_support) +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + && !device->coopmat_bf16_support +#endif + ) { + // use scalar tile sizes + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 }; + + l_wg_denoms = {128, 128, 1 }; + m_wg_denoms = { 64, 64, 1 }; + s_wg_denoms = { 32, 32, 1 }; + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); + } +#undef CREATE_MM + + // mul mat vec + + // the number of rows computed per shader depends on GPU model and quant + uint32_t rm_stdq = 1; + uint32_t rm_kq = 2; + if (device->vendor_id == VK_VENDOR_ID_AMD) { + if (device->architecture == AMD_GCN) { + rm_stdq = 2; + rm_kq = 4; + } + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) + rm_stdq = 2; + uint32_t rm_iq = 2 * rm_kq; + + const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN; + // Ensure a subgroup size >= 16 is available + const bool use_subgroups16 = use_subgroups && subgroup_min_size_16; + + const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size; + const uint32_t subgroup_size16 = std::max(subgroup_size, 16u); + + const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0; + const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0; + + for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) { + const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4); + const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4); + + const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP : + (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID : + SHADER_REDUCTION_MODE_SHMEM; + + const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP : + (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID : + SHADER_REDUCTION_MODE_SHMEM; + + for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32", arr_dmmv_iq1_s_f32_f32_len[reduc16], arr_dmmv_iq1_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32", arr_dmmv_iq1_m_f32_f32_len[reduc16], arr_dmmv_iq1_m_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32", arr_dmmv_iq2_xs_f32_f32_len[reduc16], arr_dmmv_iq2_xs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32", arr_dmmv_iq2_s_f32_f32_len[reduc16], arr_dmmv_iq2_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32", arr_dmmv_iq3_s_f32_f32_len[reduc16], arr_dmmv_iq3_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32", arr_dmmv_iq1_s_f16_f32_len[reduc16], arr_dmmv_iq1_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32", arr_dmmv_iq1_m_f16_f32_len[reduc16], arr_dmmv_iq1_m_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32", arr_dmmv_iq2_xs_f16_f32_len[reduc16], arr_dmmv_iq2_xs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32", arr_dmmv_iq2_s_f16_f32_len[reduc16], arr_dmmv_iq2_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32", arr_dmmv_iq3_s_f16_f32_len[reduc16], arr_dmmv_iq3_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; + const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + } +#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT + } + } + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + + // dequant shaders + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + + // get_rows + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k", get_rows_q2_k_len, get_rows_q2_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k", get_rows_q3_k_len, get_rows_q3_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k", get_rows_q4_k_len, get_rows_q4_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k", get_rows_q5_k_len, get_rows_q5_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k", get_rows_q6_k_len, get_rows_q6_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32", get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + + if (device->subgroup_clustered && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + } + + for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { + ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); + } else { + ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); + } + } + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + } + +#define SET_ROWS(itype, rte) \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + + if (device->float_controls_rte_fp16) { + SET_ROWS(_i32, _rte) + SET_ROWS(_i64, _rte) + } else { + SET_ROWS(_i32, ) + SET_ROWS(_i64, ) + } +#undef SET_ROWS + + + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + + bool rte = device->float_controls_rte_fp16; +#define CREATE_BINARY(name, namemod, spec, bindings) \ + for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ + ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ + "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); + + CREATE_BINARY(add, , {0}, 4) + CREATE_BINARY(add, _norepeat, {1}, 4) + CREATE_BINARY(sub, , {0}, 3) + CREATE_BINARY(sub, _norepeat, {1}, 3) + CREATE_BINARY(mul, , {0}, 3) + CREATE_BINARY(mul, _norepeat, {1}, 3) + CREATE_BINARY(div, , {0}, 3) + CREATE_BINARY(div, _norepeat, {1}, 3) + CREATE_BINARY(add_rms, , {0}, 4) + CREATE_BINARY(add_rms, _norepeat, {1}, 4) +#undef CREATE_BINARY + + if (device->multi_add) { + for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) { + ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); + ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); + } + } + + ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); + ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); + ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + +#define CREATE_UNARY(name) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + CREATE_UNARY(gelu) + CREATE_UNARY(gelu_erf) + CREATE_UNARY(gelu_quick) + CREATE_UNARY(silu) + CREATE_UNARY(relu) + CREATE_UNARY(tanh) + CREATE_UNARY(sigmoid) + CREATE_UNARY(hardsigmoid) + CREATE_UNARY(hardswish) +#undef CREATE_UNARY + +#define CREATE_UNARY_RTE(name) \ + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + } + CREATE_UNARY_RTE(exp) +#undef CREATE_UNARY_RTE + +#define CREATE_GLU(name) \ + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + } + + CREATE_GLU(geglu) + CREATE_GLU(reglu) + CREATE_GLU(swiglu) + CREATE_GLU(swiglu_oai) + CREATE_GLU(geglu_erf) + CREATE_GLU(geglu_quick) +#undef CREATE_GLU + + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } + + for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { + ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); + +#define IM2COL(bda) \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + } + if (device->shader_int64 && device->buffer_device_address) { + IM2COL(_bda) + } else { + IM2COL() + } + + ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + // conv2d, conv_transpose_2d + for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { + uint32_t conv2d_WG_SIZE = 256; + uint32_t conv2d_BS_K = 128; + uint32_t conv2d_BS_CRS = 16; + uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. + uint32_t conv2d_BS_NPQ = 128; + uint32_t conv2d_TS_K = 8; + uint32_t conv2d_SHMEM_PAD = 4; + bool conv2d_UNROLL = true; + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + conv2d_SHMEM_PAD = 8; // 8 float16_t + } +#endif + + if (device->vendor_id == VK_VENDOR_ID_INTEL) { + conv2d_SHMEM_PAD = 0; + conv2d_UNROLL = false; + } else if (device->vendor_id == VK_VENDOR_ID_AMD) { + conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4; + } + + switch (s) { + default: + case CONV_SHAPE_128x128: + conv2d_BS_K = 128; + conv2d_BS_NPQ = 128; + conv2d_BS_CRS = 16; + if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) { + conv2d_UNROLL = false; + } + break; + case CONV_SHAPE_64x32: + conv2d_BS_K = 64; + conv2d_BS_NPQ = 32; + conv2d_BS_CRS = 32; + conv2d_TS_K = 4; + break; + case CONV_SHAPE_32x256: + conv2d_BS_K = 32; + conv2d_BS_NPQ = 256; + conv2d_BS_CRS = 16; + break; + } + + // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math. + bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA || + device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; + bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD || + device->architecture == vk_device_architecture::AMD_GCN; + + if (device->subgroup_shuffle && + device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316. + allow_collectives_nv && + allow_collectives_amd) { + use_collectives = 1; + conv2d_BS_CRS = std::min( + device->subgroup_size, + conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used. + } + + uint32_t conv2d_shmem_req = + (conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float); + if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { + conv2d_BS_CRS = 8; + if (use_collectives) { + conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); + } + } + + std::array wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 }; + std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; + +#define CREATE_CONV(name, type_suffix, spv_suffix) \ + ggml_vk_create_pipeline( \ + device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \ + name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ + sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); +#define CREATE_CONVS(spv_suffix) \ + CREATE_CONV(conv2d, _f32, spv_suffix) \ + CREATE_CONV(conv2d, _f16_f32, spv_suffix) \ + if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \ + CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \ + CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \ + } +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + CREATE_CONVS(_cm2) + } else +#endif + if (conv2d_UNROLL) { + CREATE_CONVS(_unroll) + } else { + CREATE_CONVS( ) + } +#undef CREATE_CONV +#undef CREATE_CONVS + } + + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + + for (auto &c : compiles) { + c.wait(); + } + device->need_compiles = false; +} + +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); + +static vk_device ggml_vk_get_device(size_t idx) { + VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); + + if (vk_instance.devices[idx] == nullptr) { + VK_LOG_DEBUG("Initializing new vk_device"); + vk_device device = std::make_shared(); + vk_instance.devices[idx] = device; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + device->memory_logger = std::unique_ptr(new vk_memory_logger()); +#endif + if (vk_perf_logger_enabled) { + device->perf_logger = std::unique_ptr(new vk_perf_logger()); + } + + size_t dev_num = vk_instance.device_indices[idx]; + + std::vector physical_devices = vk_instance.instance.enumeratePhysicalDevices(); + + if (dev_num >= physical_devices.size()) { + std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl; + throw std::runtime_error("Device not found"); + } + + device->physical_device = physical_devices[dev_num]; + const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + + device->architecture = get_device_architecture(device->physical_device); + + const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); + device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; + + const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM"); + device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr; + + const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK"); + device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr; + + const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE"); + device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr; + + bool fp16_storage = false; + bool fp16_compute = false; + bool maintenance4_support = false; + bool sm_builtins = false; + bool amd_shader_core_properties2 = false; + bool pipeline_robustness = false; + bool coopmat2_support = false; + bool pipeline_executable_properties_support = false; + device->coopmat_support = false; + device->integer_dot_product = false; + bool bfloat16_support = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { + maintenance4_support = true; + } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { + fp16_storage = true; + } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { + fp16_compute = true; + } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { + sm_builtins = true; + } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) { + amd_shader_core_properties2 = true; + } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { + pipeline_robustness = true; + } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + device->subgroup_size_control = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + device->coopmat_support = true; + device->coopmat_m = 0; + device->coopmat_n = 0; + device->coopmat_k = 0; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; +#endif +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + device->integer_dot_product = true; +#endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; +#endif + } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) { + pipeline_executable_properties_support = true; + } + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceMaintenance4Properties props4; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; + vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; + vk::PhysicalDeviceVulkan11Properties vk11_props; + vk::PhysicalDeviceVulkan12Properties vk12_props; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + driver_props.pNext = &vk11_props; + vk11_props.pNext = &vk12_props; + + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; + + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&props4; + last_struct = (VkBaseOutStructure *)&props4; + } + if (sm_builtins) { + last_struct->pNext = (VkBaseOutStructure *)&sm_props; + last_struct = (VkBaseOutStructure *)&sm_props; + } + if (amd_shader_core_properties2) { + last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + } + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_props; + } + +#if defined(VK_NV_cooperative_matrix2) + vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props; + last_struct = (VkBaseOutStructure *)&coopmat2_props; + } +#endif + + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + + device->physical_device.getProperties2(&props2); + device->properties = props2.properties; + device->vendor_id = device->properties.vendorID; + device->driver_id = driver_props.driverID; + + const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); + + if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { + device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); + } else if (maintenance4_support) { + device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); + } else { + device->max_memory_allocation_size = props3.maxMemoryAllocationSize; + } + + const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE"); + + if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) { + device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE); + } else if (maintenance4_support) { + device->max_buffer_size = props4.maxBufferSize; + } else { + device->max_buffer_size = device->max_memory_allocation_size; + } + + const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE"); + + if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { + device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE); + } else { + // Limit batching of allocations to 1GB by default to avoid fragmentation issues + device->suballocation_block_size = 1024*1024*1024; + } + device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); + + device->subgroup_size = subgroup_props.subgroupSize; + device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + if (sm_builtins) { + device->shader_core_count = sm_props.shaderSMCount; + } else if (amd_shader_core_properties2) { + device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else { + device->shader_core_count = 0; + } + device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + + device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); +#ifdef __APPLE__ + // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846) + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_arithmetic = false; + } +#endif + device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); + device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered); + + device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot); + + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; + + device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + + if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) { + device->coopmat_support = false; + } + + device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); + + // Try to find a non-graphics compute queue and transfer-focused queues + const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); + const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); + + const float priorities[] = { 1.0f, 1.0f }; + device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; + + std::vector device_queue_create_infos; + if (compute_queue_family_index != transfer_queue_family_index) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1}); + } else if(!device->single_queue) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities}); + } else { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); + } + vk::DeviceCreateInfo device_create_info; + std::vector device_extensions; + vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); + + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = nullptr; + device_features2.features = (VkPhysicalDeviceFeatures)device_features; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + VkPhysicalDeviceVulkan12Features vk12_features; + vk12_features.pNext = nullptr; + vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; + vk11_features.pNext = &vk12_features; + + last_struct = (VkBaseOutStructure *)&vk12_features; + + VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; + pl_robustness_features.pNext = nullptr; + pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; + pl_robustness_features.pipelineRobustness = VK_FALSE; + + if (pipeline_robustness) { + last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features; + last_struct = (VkBaseOutStructure *)&pl_robustness_features; + device_extensions.push_back("VK_EXT_pipeline_robustness"); + } + + VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; + subgroup_size_control_features.pNext = nullptr; + subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; + subgroup_size_control_features.computeFullSubgroups = false; + subgroup_size_control_features.subgroupSizeControl = false; + + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; + } + +#if defined(VK_KHR_cooperative_matrix) + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (device->coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } +#endif + +#if defined(VK_NV_cooperative_matrix2) + VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; + coopmat2_features.pNext = nullptr; + coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; + last_struct = (VkBaseOutStructure *)&coopmat2_features; + device_extensions.push_back("VK_NV_cooperative_matrix2"); + } +#endif + +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.pNext = nullptr; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif + + VkPhysicalDeviceMaintenance4Features maint4_features {}; + maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&maint4_features; + last_struct = (VkBaseOutStructure *)&maint4_features; + device_extensions.push_back("VK_KHR_maintenance4"); + } + + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + device_extensions.push_back("VK_KHR_shader_integer_dot_product"); + } + + VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {}; + pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR; + if (pipeline_executable_properties_support) { + last_struct->pNext = (VkBaseOutStructure *)&pep_features; + last_struct = (VkBaseOutStructure *)&pep_features; + device_extensions.push_back("VK_KHR_pipeline_executable_properties"); + } + + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); + + device->pipeline_executable_properties_support = pipeline_executable_properties_support; + + device->fp16 = device->fp16 && vk12_features.shaderFloat16; + +#if defined(VK_KHR_shader_bfloat16) + device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; +#else + device->bf16 = false; +#endif + + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; + + device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && + device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) && + vk12_features.runtimeDescriptorArray && + device->vendor_id != VK_VENDOR_ID_INTEL && + getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr; + + device->shader_int64 = device_features2.features.shaderInt64; + device->buffer_device_address = vk12_features.bufferDeviceAddress; + + if (device->subgroup_size_control) { + device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; + device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; + device_extensions.push_back("VK_EXT_subgroup_size_control"); + } + + device->subgroup_size_control = device->subgroup_size_control && + (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) && + subgroup_size_control_features.subgroupSizeControl; + + device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; + +#if defined(VK_KHR_cooperative_matrix) + device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; + + // coopmat1 fa shader currently assumes 32 invocations per subgroup + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && + device->subgroup_size_control && device->subgroup_min_size <= 32 && + device->subgroup_max_size >= 32; +#endif + + if (coopmat2_support) { +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (coopmat2_features.cooperativeMatrixWorkgroupScope && + coopmat2_features.cooperativeMatrixFlexibleDimensions && + coopmat2_features.cooperativeMatrixReductions && + coopmat2_features.cooperativeMatrixConversions && + coopmat2_features.cooperativeMatrixPerElementOperations && + coopmat2_features.cooperativeMatrixTensorAddressing && + coopmat2_features.cooperativeMatrixBlockLoads && + vk12_features.bufferDeviceAddress) { + + std::vector flexible_dimensions; + uint32_t count = 0; + + PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV = + (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV) + vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV"); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr); + + VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {}; + empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV; + flexible_dimensions.resize(count, empty_prop); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data()); + + bool found_fp16_128 = false, + found_fp16_256 = false, + found_fp32_128 = false, + found_fp32_256 = false; + // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 + // with 32x16x16 and 256 with 32x32x16. + for (auto &prop : flexible_dimensions) { + if (prop.saturatingAccumulation == VK_FALSE && + prop.scope == VK_SCOPE_WORKGROUP_KHR && + prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } + } + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } + } + } + } + if (found_fp16_128 && found_fp16_256 && + found_fp32_128 && found_fp32_256 && + coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { + device->coopmat2 = true; + } + } +#endif + } + + if (!vk11_features.storageBuffer16BitAccess) { + std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl; + throw std::runtime_error("Unsupported device"); + } + + device_extensions.push_back("VK_KHR_16bit_storage"); + +#ifdef GGML_VULKAN_VALIDATE + device_extensions.push_back("VK_KHR_shader_non_semantic_info"); +#endif + + if (device->fp16) { + device_extensions.push_back("VK_KHR_shader_float16_int8"); + } + +#if defined(VK_KHR_cooperative_matrix) + if (device->coopmat_support) { + // Query supported shapes + std::vector cm_props; + + PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR = + (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"); + + uint32_t cm_props_num; + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr); + + cm_props.resize(cm_props_num); + + for (auto& prop : cm_props) { + prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; + } + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data()); + + VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size()); + + for (auto& prop : cm_props) { + VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope)); + + if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f32_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f32_support = true; + } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f32acc = true; + } + } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f16_support = true; + } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f16acc = true; + } + } + } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup && + device->coopmat_int_m == 0 + ) { + device->coopmat_int_support = true; + device->coopmat_int_m = prop.MSize; + device->coopmat_int_n = prop.NSize; + device->coopmat_int_k = prop.KSize; + } +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_bf16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_bf16_support = true; + } + } +#endif + } + + if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { + // No suitable matmul mode found + GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n"); + device->coopmat_support = false; + } + if (getenv("GGML_VK_DISABLE_BFLOAT16")) { + device->coopmat_bf16_support = false; + } + } + + if (device->coopmat_support) { + device_extensions.push_back("VK_KHR_cooperative_matrix"); + } +#if defined(VK_KHR_shader_bfloat16) + if (device->coopmat_bf16_support) { + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif +#endif + device->name = GGML_VK_NAME + std::to_string(idx); + + device_create_info = { + vk::DeviceCreateFlags(), + device_queue_create_infos, + {}, + device_extensions + }; + device_create_info.setPNext(&device_features2); + device->device = device->physical_device.createDevice(device_create_info); + + // Queues + ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false); + + // Shaders + // Disable matmul tile sizes early if performance low or not supported + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { + switch (device->vendor_id) { +#ifndef GGML_VULKAN_RUN_TESTS + case VK_VENDOR_ID_AMD: + case VK_VENDOR_ID_INTEL: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; + case VK_VENDOR_ID_APPLE: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = false; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = false; + break; +#endif + default: + device->mul_mat_l[i] = true; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = true; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; + } + } + + + std::vector dsl_binding; + std::vector dsl_binding_flags; + for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) { + dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); + dsl_binding_flags.push_back({}); + } + + vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; + + vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( + {}, + dsl_binding); + descriptor_set_layout_create_info.setPNext(&dslbfci); + device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); + + ggml_vk_load_shaders(device); + + if (!device->single_queue) { + const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; + ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); + } else { + // TODO: Use pointer or reference to avoid copy + device->transfer_queue.copyFrom(device->compute_queue); + device->transfer_queue.cmd_pool.init(device, &device->transfer_queue); + } + + device->buffer_type = { + /* .iface = */ ggml_backend_vk_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx), + /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device }, + }; + + device->fence = device->device.createFence({}); + + device->idx = idx; + + device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; + + device->add_rms_fusion = !device->disable_fusion && + device->subgroup_arithmetic && + device->vendor_id != VK_VENDOR_ID_INTEL; + device->partials_binding_alignment = + std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment); + + device->mmvq_mode = 0; + if (getenv("GGML_VK_DISABLE_MMVQ")) { + device->mmvq_mode = -1; + } else if (getenv("GGML_VK_FORCE_MMVQ")) { + device->mmvq_mode = 1; + } + + return device; + } + + return vk_instance.devices[idx]; +} + +static void ggml_vk_print_gpu_info(size_t idx) { + GGML_ASSERT(idx < vk_instance.device_indices.size()); + size_t dev_num = vk_instance.device_indices[idx]; + VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")"); + GGML_ASSERT(vk_instance_initialized); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + if (dev_num >= devices.size()) { + std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl; + throw std::runtime_error("Device not found"); + } + + vk::PhysicalDevice physical_device = devices[dev_num]; + std::vector ext_props = physical_device.enumerateDeviceExtensionProperties(); + + bool fp16_storage = false; + bool fp16_compute = false; + bool coopmat_support = false; + bool coopmat2_support = false; + bool integer_dot_product = false; + bool bfloat16_support = false; + + for (auto properties : ext_props) { + if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { + fp16_storage = true; + } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { + fp16_compute = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + coopmat_support = true; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; +#endif +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + integer_dot_product = true; +#endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; +#endif + } + } + + const vk_device_architecture device_architecture = get_device_architecture(physical_device); + + const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); + bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; + + bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + + // Pointer to the last chain element + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props; + + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + + physical_device.getProperties2(&props2); + + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = nullptr; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + VkPhysicalDeviceVulkan12Features vk12_features; + vk12_features.pNext = nullptr; + vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; + vk11_features.pNext = &vk12_features; + + // Pointer to the last chain element + last_struct = (VkBaseOutStructure *)&vk12_features; + +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } +#endif + + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + } + +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + } +#endif + + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); + + fp16 = fp16 && vk12_features.shaderFloat16; + +#if defined(VK_KHR_shader_bfloat16) + bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; +#else + bool bf16 = false; +#endif + + uint32_t default_subgroup_size = get_subgroup_size("", device_architecture); + const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize; + const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + + integer_dot_product = integer_dot_product + && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated + && shader_integer_dot_product_features.shaderIntegerDotProduct; + + coopmat_support = coopmat_support +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + && coopmat_features.cooperativeMatrix +#endif + && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); + + std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; + + std::string device_name = props2.properties.deviceName.data(); + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size, + props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); + + if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { + GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); + } +} + +static bool ggml_vk_instance_validation_ext_available(); +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); +static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); +static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev); + +static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance; +DispatchLoaderDynamic & ggml_vk_default_dispatcher() { + return ggml_vk_default_dispatcher_instance; +} + +static void ggml_vk_instance_init() { + if (vk_instance_initialized) { + return; + } + VK_LOG_DEBUG("ggml_vk_instance_init()"); + + // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- + ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr); + + uint32_t api_version = vk::enumerateInstanceVersion(); + + if (api_version < VK_API_VERSION_1_2) { + std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl; + throw vk::SystemError(vk::Result::eErrorFeatureNotPresent, "Vulkan 1.2 required"); + } + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; + + const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); + const bool validation_ext = ggml_vk_instance_validation_ext_available(); +#ifdef __APPLE__ + const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); +#endif + const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr; + std::vector layers; + + if (validation_ext) { + layers.push_back("VK_LAYER_KHRONOS_validation"); + } + std::vector extensions; + if (validation_ext) { + extensions.push_back("VK_EXT_validation_features"); + } +#ifdef __APPLE__ + if (portability_enumeration_ext) { + extensions.push_back("VK_KHR_portability_enumeration"); + } +#endif + if (debug_utils_ext) { + extensions.push_back("VK_EXT_debug_utils"); + } + vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); +#ifdef __APPLE__ + if (portability_enumeration_ext) { + instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; + } +#endif + + std::vector features_enable; + vk::ValidationFeaturesEXT validation_features; + + if (validation_ext) { + features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; + validation_features = { + features_enable, + {}, + }; + validation_features.setPNext(nullptr); + instance_create_info.setPNext(&validation_features); + GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); + } + vk_instance.instance = vk::createInstance(instance_create_info); + vk_instance_initialized = true; + + if (debug_utils_ext) { + vk_instance.debug_utils_support = true; + vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT"); + vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT"); + vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT"); + } + + vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; + + // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- + VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan + char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); + if (devices_env != nullptr) { + size_t num_available_devices = devices.size(); + + std::string devices(devices_env); + std::replace(devices.begin(), devices.end(), ',', ' '); + + std::stringstream ss(devices); + size_t tmp; + while (ss >> tmp) { + if(tmp >= num_available_devices) { + std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl; + throw std::runtime_error("Invalid Vulkan device index"); + } + vk_instance.device_indices.push_back(tmp); + } + } else { + // If no vulkan devices are found, return early + if (devices.empty()) { + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); + return; + } + + // Default to using all dedicated GPUs + for (size_t i = 0; i < devices.size(); i++) { + vk::PhysicalDeviceProperties2 new_props; + vk::PhysicalDeviceDriverProperties new_driver; + vk::PhysicalDeviceIDProperties new_id; + new_props.pNext = &new_driver; + new_driver.pNext = &new_id; + devices[i].getProperties2(&new_props); + + if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) { + // Check if there are two physical devices corresponding to the same GPU + auto old_device = std::find_if( + vk_instance.device_indices.begin(), + vk_instance.device_indices.end(), + [&devices, &new_id](const size_t k){ + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceIDProperties old_id; + old_props.pNext = &old_id; + devices[k].getProperties2(&old_props); + return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); + } + ); + if (old_device == vk_instance.device_indices.end()) { + vk_instance.device_indices.push_back(i); + } else { + // There can be two physical devices corresponding to the same GPU if there are 2 different drivers + // This can cause error when splitting layers aross the devices, need to keep only 1 + VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID"); + + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceDriverProperties old_driver; + old_props.pNext = &old_driver; + devices[*old_device].getProperties2(&old_props); + + std::map driver_priorities {}; + int old_priority = std::numeric_limits::max(); + int new_priority = std::numeric_limits::max(); + + // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id + // Smaller number -> higher priority + switch (old_props.properties.vendorID) { + case VK_VENDOR_ID_AMD: + driver_priorities[vk::DriverId::eMesaRadv] = 1; + driver_priorities[vk::DriverId::eAmdOpenSource] = 2; + driver_priorities[vk::DriverId::eAmdProprietary] = 3; + break; + case VK_VENDOR_ID_INTEL: + driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1; + driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2; + break; + case VK_VENDOR_ID_NVIDIA: + driver_priorities[vk::DriverId::eNvidiaProprietary] = 1; +#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235 + driver_priorities[vk::DriverId::eMesaNvk] = 2; +#endif + break; + } + + if (driver_priorities.count(old_driver.driverID)) { + old_priority = driver_priorities[old_driver.driverID]; + } + if (driver_priorities.count(new_driver.driverID)) { + new_priority = driver_priorities[new_driver.driverID]; + } + + if (new_priority < old_priority) { + auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device); + vk_instance.device_indices.erase(r, vk_instance.device_indices.end()); + vk_instance.device_indices.push_back(i); + + VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName); + } + else { + VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl); + } + } + } + } + + // If no GPUs found, fall back to the first non-CPU device. + // If only CPU devices are available, return without devices. + if (vk_instance.device_indices.empty()) { + for (size_t i = 0; i < devices.size(); i++) { + if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) { + vk_instance.device_indices.push_back(i); + break; + } + } + } + + if (vk_instance.device_indices.empty()) { + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); + return; + } + } + GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); + + for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { + vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]]; + std::vector extensionprops = vkdev.enumerateDeviceExtensionProperties(); + + bool membudget_supported = false; + for (const auto & ext : extensionprops) { + if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) { + membudget_supported = true; + break; + } + } + + vk_instance.device_supports_membudget.push_back(membudget_supported); + + ggml_vk_print_gpu_info(i); + } +} + +static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { + VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")"); + ggml_vk_instance_init(); + GGML_ASSERT(idx < vk_instance.device_indices.size()); + + ctx->name = GGML_VK_NAME + std::to_string(idx); + + ctx->device = ggml_vk_get_device(idx); + + ctx->semaphore_idx = 0; + ctx->event_idx = 0; + + ctx->prealloc_size_x = 0; + ctx->prealloc_size_y = 0; + ctx->prealloc_size_split_k = 0; + + ctx->fence = ctx->device->device.createFence({}); + ctx->almost_ready_fence = ctx->device->device.createFence({}); + + ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); + ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + +#ifdef GGML_VULKAN_CHECK_RESULTS + const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); + vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); + const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR"); + vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor)); +#endif +} + +static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) { + VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant[type]; +} + +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")"); + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f32; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f32_f16; + } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_bf16; + } + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f32acc; + } + } + + // MMQ + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; + + if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + return nullptr; + } + + return pipelines; + } + + if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { + return nullptr; + } + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return nullptr; + } + + if (ctx->device->coopmat2) { + assert(src1_type == GGML_TYPE_F16); + return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc; + } + if (ctx->device->coopmat_support) { + return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; + } + return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) { + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1); + GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); + + if (b_type == GGML_TYPE_Q8_1) { + switch (a_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + break; + default: + return nullptr; + } + } + + switch (a_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return nullptr; + } + + // heuristic to choose workgroup size + uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + // Prefer larger workgroups when M is small, to spread the work out more + // and keep more SMs busy. + // q6_k seems to prefer small workgroup size even for "medium" values of M. + if (a_type == GGML_TYPE_Q6_K) { + if (m < 4096 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } else { + if (m <= 8192 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } + } + + if (b_type == GGML_TYPE_Q8_1) { + if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + } + return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1]; + } + + return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1]; +} + +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()"); + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f32; + } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_id_bf16; + } + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f32acc; + } + } + + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return nullptr; + } + + // XXX TODO 'prec' is not actually allowed in mul_mat_id. + bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/; + bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr; + bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr; + + if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) { + return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc; + } else { + GGML_ASSERT(support_fp32acc); + return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; + } +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()"); + GGML_ASSERT(b_type == GGML_TYPE_F32); + + switch (a_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; +} + +static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { + VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")"); + VK_LOG_MEMORY("ggml_vk_pool_malloc"); + + int best_i = -1; + size_t best_size = std::numeric_limits::max(); //smallest unused buffer that fits our needs + int worst_i = -1; + size_t worst_size = 0; //largest unused buffer seen so far + for (int i = 0; i < MAX_VK_BUFFERS; ++i) { + vk_buffer &b = ctx->buffer_pool[i]; + if (b != nullptr && b->size >= size && b->size < best_size) { + best_i = i; + best_size = b->size; + } + if (b != nullptr && b->size > worst_size) { + worst_i = i; + worst_size = b->size; + } + } + if(best_i != -1) { + //found the smallest buffer that fits our needs + vk_buffer b = ctx->buffer_pool[best_i]; + ctx->buffer_pool[best_i].reset(); + return b; + } + if(worst_i != -1) { + //no buffer that fits our needs, resize largest one to save memory + vk_buffer& b = ctx->buffer_pool[worst_i]; + ggml_vk_destroy_buffer(b); + } + + return ggml_vk_create_buffer_device(ctx->device, size); +} + +static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) { + VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")"); + for (int i = 0; i < MAX_VK_BUFFERS; ++i) { + vk_buffer& b = ctx->buffer_pool[i]; + if (b == nullptr) { + b = buffer; + return; + } + } + std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl; + ggml_vk_destroy_buffer(buffer); +} + +// Returns an available temporary buffer that may only be used temporarily, it will be reused +static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) { + // Try to find existing temp buffer with enough capacity + for (auto& buffer : ctx->gc.temp_buffers) { + if (buffer->size >= size) { + return buffer; + } + } + + VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")"); + + // Otherwise create new buffer + vk_buffer buf = ggml_vk_pool_malloc(ctx, size); + ctx->gc.temp_buffers.push_back(buf); + + return buf; +} + +static void * ggml_vk_host_malloc(vk_device& device, size_t size) { + VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); + vk_buffer buf = ggml_vk_create_buffer(device, size, + {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + + if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) { + fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n", + size/1024.0/1024.0); + device->device.freeMemory(buf->device_memory); + device->device.destroyBuffer(buf->buffer); + return nullptr; + } + + std::lock_guard guard(device->mutex); + device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); + + return buf->ptr; +} + +static void ggml_vk_host_free(vk_device& device, void* ptr) { + if (ptr == nullptr) { + return; + } + VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); + std::lock_guard guard(device->mutex); + + vk_buffer buf; + size_t index; + for (size_t i = 0; i < device->pinned_memory.size(); i++) { + const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); + const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); + if (ptr >= addr && ptr < endr) { + buf = std::get<2>(device->pinned_memory[i]); + index = i; + break; + } + } + if (buf == nullptr) { + fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n"); + return; + } + + ggml_vk_destroy_buffer(buf); + + device->pinned_memory.erase(device->pinned_memory.begin() + index); +} + +static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { + std::lock_guard guard(device->mutex); + buf = nullptr; + buf_offset = 0; + for (size_t i = 0; i < device->pinned_memory.size(); i++) { + const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); + const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); + if (ptr >= addr && ptr < endr) { + buf = std::get<2>(device->pinned_memory[i]); + buf_offset = ((const uint8_t *)ptr) - addr; + break; + } + } +} + +static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) { + vk_submission s; + s.buffer = ggml_vk_create_cmd_buffer(device, p); + if (one_time) { + s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); + } else { + s.buffer.begin({ vk::CommandBufferUsageFlags{} }); + } + + return s; +} + +template size_t push_constant_size(const T &t) { + static_assert(std::is_class::value, "T must be a struct/class"); + GGML_UNUSED(t); + return sizeof(T); +} +template size_t push_constant_size(const std::vector &t) { + GGML_UNUSED(t); + return sizeof(T) * t.size(); +} +template size_t push_constant_size(const std::array &t) { + GGML_UNUSED(t); + return sizeof(T) * N; +} + +template const T *push_constant_data(const T &t) { + static_assert(std::is_class::value, "T must be a struct/class"); + return &t; +} +template const T *push_constant_data(const std::vector &t) { + return t.data(); +} +template const T *push_constant_data(const std::array &t) { + return t.data(); +} + +template +static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, const T &push_constants, std::array elements) { + const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); + const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); + const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); + VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {"; + for (auto& buffer : descriptor_buffer_infos) { + std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; + } + std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); + GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); + GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); + GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); + + vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; + vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; + ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); + + subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants)); + subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); + subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, + pipeline->layout, + 0, + { descriptor_set }, + {}); + subctx->s->buffer.dispatch(wg0, wg1, wg2); +} + +static void ggml_vk_end_submission(vk_submission& s, std::vector wait_semaphores, std::vector signal_semaphores) { + s.buffer.end(); + + s.wait_semaphores = std::move(wait_semaphores); + s.signal_semaphores = std::move(signal_semaphores); +} + +static void ggml_vk_ctx_end(vk_context& ctx) { + VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")"); + if (ctx->s == nullptr) { + return; + } + + ctx->s->buffer.end(); + ctx->s = nullptr; +} + +static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { + VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")"); + if (subctx->s != nullptr) { + ggml_vk_ctx_end(subctx); + } + + subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) }); + subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); +} + +static size_t ggml_vk_align_size(size_t width, size_t align) { + VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")"); + return CEIL_DIV(width, align) * align; +} + +static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector* memcpys = nullptr) { + if (memcpys == nullptr) { + memcpy(dst, src, size); + } else { + memcpys->emplace_back(dst, src, size); + } +} + +static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector* memsets = nullptr) { + if (memsets == nullptr) { + memset(dst, val, size); + } else { + memsets->emplace_back(dst, val, size); + } +} + +static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) { + if (device->sync_staging == nullptr || device->sync_staging->size < size) { + VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")"); + ggml_vk_destroy_buffer(device->sync_staging); + device->sync_staging = ggml_vk_create_buffer_check(device, size, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + } +} + +static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")"); + GGML_ASSERT(!ggml_is_contiguous(tensor)); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl; + GGML_ABORT("fatal error"); + } + // Check if src is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset); + + const uint64_t ne0 = tensor->ne[0]; + const uint64_t ne1 = tensor->ne[1]; + const uint64_t ne2 = tensor->ne[2]; + const uint64_t ne3 = tensor->ne[3]; + const uint64_t nb0 = tensor->nb[0]; + const uint64_t nb1 = tensor->nb[1]; + const uint64_t nb2 = tensor->nb[2]; + const uint64_t nb3 = tensor->nb[3]; + const ggml_type type = tensor->type; + const uint64_t ts = ggml_type_size(type); + const uint64_t bs = ggml_blck_size(type); + + const uint64_t dstnb0 = ts; + const uint64_t dstnb1 = dstnb0*(ne0/bs); + const uint64_t dstnb2 = dstnb1*ne1; + const uint64_t dstnb3 = dstnb2*ne2; + + const uint64_t ne = ggml_nelements(tensor); + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + std::vector slices; + + for (uint64_t i3 = 0; i3 < ne3; i3++) { + for (uint64_t i2 = 0; i2 < ne2; i2++) { + // Find longest contiguous slice + if (ne1*nb1 == dstnb2) { + slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 }); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + if (ne0*nb0/bs == dstnb1) { + slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 }); + } else { + const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; + const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + for (uint64_t i0 = 0; i0 < ne0; i0++) { + slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); + } + } + } + } + } + } + + ggml_vk_sync_buffers(ctx, subctx); + subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + return; + } + + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + + // Staging buffer required + vk_buffer& staging = ctx->device->sync_staging; + const uint64_t copy_size = ts*ne/bs; + ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); + VkBufferCopy buf_copy{ 0, offset, copy_size }; + + ggml_vk_sync_buffers(ctx, subctx); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + + for (uint64_t i3 = 0; i3 < ne3; i3++) { + for (uint64_t i2 = 0; i2 < ne2; i2++) { + // Find longest contiguous slice + if (ne1*nb1 == dstnb2) { + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + if (ne0*nb0/bs == dstnb1) { + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys); + } else { + const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; + const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + for (uint64_t i0 = 0; i0 < ne0; i0++) { + deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys); + } + } + } + } + } + } +} + +static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl; + GGML_ABORT("fatal error"); + } + // Check if src is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(dst->device, src, buf, buf_offset); + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + std::vector slices(1); + if (width == spitch) { + // Only do single write if stride is equal + slices[0].srcOffset = buf_offset; + slices[0].dstOffset = offset; + slices[0].size = width * height; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = buf_offset + i * spitch; + slices[i].dstOffset = offset + i * width; + slices[i].size = width; + } + } + + ggml_vk_sync_buffers(nullptr, subctx); + subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + return; + } + VK_LOG_DEBUG("STAGING"); + + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + + // Staging buffer required + const size_t copy_size = width*height; + ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); + + vk_buffer& staging_buffer = dst->device->sync_staging; + + VkBufferCopy buf_copy = { + 0, + offset, + copy_size}; + + ggml_vk_sync_buffers(nullptr, subctx); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + + if (width == spitch) { + deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); + } else { + for (size_t i = 0; i < height; i++) { + deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); + } + } +} + +static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); + return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); +} + +static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { + VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); + + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); + } + } else { + std::lock_guard guard(dst->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(dst->device, subctx); + ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + ggml_vk_ctx_end(subctx); + + for (auto& cpy : subctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + for (auto& mset : subctx->memsets) { + memset(mset.dst, mset.val, mset.n); + } + + ggml_vk_submit(subctx, dst->device->fence); + VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); + dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); + } +} + +static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); + ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); +} + +static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")"); + GGML_ASSERT(width > 0); + GGML_ASSERT(height > 0); + GGML_ASSERT(src != nullptr); + + // TODO: staging_offset is not used + + // Check if dst is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(src->device, dst, buf, buf_offset); + + std::vector slices(1); + if (width == spitch && width == dpitch) { + // Only do single write if stride is equal + slices[0].srcOffset = offset; + slices[0].dstOffset = buf_offset; + slices[0].size = width * height; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = offset + i * spitch; + slices[i].dstOffset = buf_offset + i * dpitch; + slices[i].size = width; + } + } + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + ggml_vk_sync_buffers(nullptr, subctx); + subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); + + return; + } + VK_LOG_DEBUG("STAGING"); + + if (!sync_staging) { + GGML_ABORT("Asynchronous read from non-pinned memory not supported"); + } + + // Fall back to staging buffer + const size_t copy_size = dpitch * height; + ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); + + vk_buffer& staging_buffer = src->device->sync_staging; + + ggml_vk_sync_buffers(nullptr, subctx); + subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); + + deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); +} + +static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) { + return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); +} + +static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); + + // If the device is not an UMA device the memory is host-accessible through rebar. While writing + // through PCIe is sufficient fast reading back data from PCIe is slower than going through + // the HW device to host copy path. + if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { + GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); + + memcpy(dst, (uint8_t *) src->ptr + offset, size); + } else { + std::lock_guard guard(src->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(src->device, subctx); + ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); + ggml_vk_ctx_end(subctx); + + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); + src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); + + for (auto& cpy : subctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + } +} + +static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); + // Make sure both buffers are on same device + GGML_ASSERT(src->device == dst->device); + + VkBufferCopy bc{ src_offset, dst_offset, size }; + + vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); +} + +static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { + if (src->device == dst->device) { + std::lock_guard guard(src->device->mutex); + VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); + // Copy within the device + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(src->device, subctx); + ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); + ggml_vk_ctx_end(subctx); + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); + src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); + } else { + VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); + // Copy device to device + ggml_vk_ensure_sync_staging_buffer(src->device, size); + ggml_vk_ensure_sync_staging_buffer(dst->device, size); + + // Copy to src staging buffer + ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); + // memcpy to dst staging buffer + memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size); + // Copy to dst buffer + ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size); + } +} + +static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")"); + + if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && + dst->device->uma) { + deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets); + return; + } + + // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers + ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); +} + +static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); + + if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && + dst->device->uma) { + memset((uint8_t*)dst->ptr + offset, c, size); + return; + } + + std::lock_guard guard(dst->device->mutex); + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(dst->device, subctx); + subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); + ggml_vk_ctx_end(subctx); + + ggml_vk_submit(subctx, dst->device->fence); + VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); + dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); +} + +static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")"); + + if (disable_split_k) { + return 1; + } + + uint32_t split_k = 1; + if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) { + // If k is 'large' and the SMs will fill less than halfway, use split_k. + uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); + uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); + + if (k >= 2048) { + if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) { + split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); + } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) { + split_k = 3; + } + // Cap the split at 8x. Unless k is huge this is a lot of overhead. + split_k = std::min(split_k, 8u); + + // ggml_vk_matmul will align the splits to be a multiple of 256. + // If this rounded up size would cause the last split to be empty, + // then reduce the split count. + while (true) { + if (split_k == 1) { + break; + } + uint32_t k_split = CEIL_DIV(k, split_k); + k_split = ROUNDUP_POW2(k_split, 256); + if (k_split * (split_k - 1) < k) { + break; + } + split_k--; + } + } + } + + return split_k; +} + +static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + + if (ctx->device->coopmat2) { + const uint32_t shader_core_count = ctx->device->shader_core_count; + const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); + const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]); + + // Use large shader when the N dimension is greater than the medium shader's tile size + uint32_t crossover_large = mmp->m->wg_denoms[1]; + + // Prefer large over medium if either: + // - medium or large tiles would overfill the GPU + // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not + // (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead) + bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count || + // split_k==3 with large tiles likely better than medium tiles with no split_k. + (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); + + if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + return aligned ? mmp->a_l : mmp->l; + } + // Use medium shader when the N dimension is greater than the small shader's tile size + uint32_t crossover_medium = mmp->s->wg_denoms[1]; + if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_l : mmp->l; + + GGML_UNUSED(src1_type); +} + +static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; +} + +static void ggml_vk_matmul( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, + uint32_t padded_n) { + VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); + if (split_k == 1) { + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); + return; + } + + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + GGML_ASSERT(batch_stride_d == m * n); + + // Round the split size up to a multiple of 256 (k-quant alignment) + uint32_t k_split = CEIL_DIV(k, split_k); + k_split = ROUNDUP_POW2(k_split, 256); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; + // Make sure enough workgroups get assigned for split k to work + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_vk_sync_buffers(ctx, subctx); + const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); + ctx->prealloc_split_k_need_sync = true; +} + +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); + + if (ctx->device->coopmat2) { + // Use large shader when the N dimension is greater than the medium shader's tile size + uint32_t crossover_large = mmp->m->wg_denoms[1]; + if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + return aligned ? mmp->a_l : mmp->l; + } + // Use medium shader when the N dimension is greater than the small shader's tile size + uint32_t crossover_medium = mmp->s->wg_denoms[1]; + if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_l : mmp->l; +} + +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; +} + +static void ggml_vk_matmul_id( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11, + uint32_t padded_n) { + VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << + "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << + "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << + "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); + const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, + nei0, nei1, nbi1, ne11, padded_n }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as }); +} + +static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && + (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]); +} + +static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { + + // Choose "contiguous copy" shader if src/dst are contiguous + bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); + + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_cpy_f32_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f16; + } else { + return ctx->device->pipeline_cpy_f32_f16; + } + } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f16; + } else { + return ctx->device->pipeline_cpy_f16_f16; + } + } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f32; + } else { + return ctx->device->pipeline_cpy_f16_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_bf16; + } else { + return ctx->device->pipeline_cpy_f32_bf16; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_i32; + } else { + return ctx->device->pipeline_cpy_f32_i32; + } + } + if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_i32_f32; + } else { + return ctx->device->pipeline_cpy_i32_f32; + } + } + if (src->type == GGML_TYPE_F32) { + switch (to) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_f32_quant[to]; + default: + break; + } + } + + if (to == GGML_TYPE_F32) { + switch (src->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_quant_f32[src->type]; + default: + break; + } + } + + if (src->type == to) { + // Copy two or four bytes at a time, depending on block size. + // For quantized types, we scale by block size/type size. But + // this path is also used for bf16->bf16 for example, where the + // type size must be exactly 2 or 4. + GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4); + if ((ggml_type_size(src->type) % 4) == 0) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_cpy_f32_f32; + } + } else { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f16; + } else { + return ctx->device->pipeline_cpy_f16_f16; + } + } + } + + std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; + GGML_ABORT("fatal error"); +} + +static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) { + VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), "; + std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); + const int tensor_type_size = ggml_type_size(tensor->type); + + const uint32_t ne = ggml_nelements(tensor); + std::array elements; + + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + vk_op_unary_push_constants pc = { + (uint32_t)ne, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]), + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + init_pushconst_fastdiv(pc); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); +} + +static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) { + switch(type) { + case GGML_TYPE_Q8_1: + return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1; + default: + std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl; + GGML_ABORT("fatal error"); + } +} + +static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) { + VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")"); + + vk_pipeline pipeline = use_x4_blocks ? ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true) : ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); + ggml_vk_sync_buffers(ctx, subctx); +} + +static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne21 = dst->ne[1]; + const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type); + const uint32_t stride_batch_d = stride_d*ne21; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || + !ggml_vk_dim01_contiguous(src1); + + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; + + // Check for mmq first + vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; + + if (mmp == nullptr) { + // Fall back to f16 dequant mul mat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + quantize_y = false; + } + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); + + if (qx_needs_dequant) { + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; + const int x_ne = ne01 * ne00; + const int y_ne = padded_n * ne10; + const int d_ne = ne11 * ne01; + + const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline); + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); + } + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } + const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1); + } + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig || qx_needs_dequant) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_sync_buffers(ctx, subctx); + } + if (y_non_contig) { + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + + // compute + ggml_vk_matmul( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + ne01, ne11, ne10, + ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d, + split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n + ); // NOLINT + + if (x_non_contig || qx_needs_dequant) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig || quantize_y) { + ctx->prealloc_y_need_sync = true; + } +} + +// Device tuning +static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) { + if (device->mmvq_mode == 1) { + return true; + } else if (device->mmvq_mode == -1) { + return false; + } + + // MMVQ is generally good for batches + if (n > 1) { + return true; + } + + switch (device->vendor_id) { + case VK_VENDOR_ID_NVIDIA: + switch (src0_type) { + case GGML_TYPE_Q8_0: + return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; + default: + return true; + } + case VK_VENDOR_ID_AMD: + switch (src0_type) { + case GGML_TYPE_Q8_0: + return device->architecture == vk_device_architecture::AMD_GCN; + default: + return true; + } + case VK_VENDOR_ID_INTEL: + switch (src0_type) { + // From tests on A770 Linux, may need more tuning + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_1: + return false; + default: + return true; + } + default: + return true; + } + + GGML_UNUSED(m); + GGML_UNUSED(k); +} + +static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + // batch_n indicates that we need to compute a few vector results, and this assumes + // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides. + GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1); + bool batch_n = ne11 > 1; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type); + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + + // Check for mmq first + vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (dmmv == nullptr) { + // Fall back to f16 dequant mul mat + dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00); + quantize_y = false; + } + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); + } + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne11 * ne01; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + const uint64_t d_sz = sizeof(float) * d_ne; + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + + // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride + uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01; + uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11); + uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21); + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + uint32_t groups_x = ne01; + uint32_t groups_z = 1; + + if (ne01 > max_groups_x) { + groups_z = 64; + groups_x = CEIL_DIV(groups_x, groups_z); + } + + // TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + + // compute + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + stride_batch_x, stride_batch_y, stride_batch_d, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz_total }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, + pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + + if (x_non_contig) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig || quantize_y) { + ctx->prealloc_y_need_sync = true; + } +} + +static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + // const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + // const uint64_t ne13 = src1->ne[3]; + + GGML_ASSERT(ne11 == 1); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src1_uma = d_Qy != nullptr; + } + + const uint64_t x_ne = ne00 * ne01 * ne02; + const uint64_t y_ne = ne10 * ne11 * ne12; + const uint64_t d_ne = ne01 * ne11 * ne12; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t d_sz = sizeof(float) * d_ne; + + // With grouped query attention there are > 1 Q matrices per K, V matrix. + uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02; + if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) { + gqa_ratio = 1; + } + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_Qx = src0_buf_ctx->dev_buffer; + const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + + const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; + + const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; + + // compute + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + + uint32_t workgroups_z = (uint32_t)ne12; + // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups + if (gqa_ratio > 1) { + workgroups_z /= gqa_ratio; + } + + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z }); +} + +static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(!ggml_is_permuted(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t nb01 = src0->nb[1]; + const uint64_t nb02 = src0->nb[2]; + + const uint64_t nb12 = src1->nb[2]; + + // const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + // const uint64_t ne13 = src1->ne[3]; + + const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t)); + const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float)); + const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float)); + + GGML_ASSERT(ne11 == 1); + GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src1_uma = d_Qy != nullptr; + } + + const uint64_t d_ne = ne01 * ne11 * ne12 * ne03; + + const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); + const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); + const uint32_t channel_stride_y = nb12 / sizeof(float); + + const uint64_t qx_sz = ggml_nbytes(src0); + const uint64_t qy_sz = ggml_nbytes(src1); + const uint64_t d_sz = sizeof(float) * d_ne; + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_Qx = src0_buf_ctx->dev_buffer; + const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + + const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; + + const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; + + // compute + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); +} + +static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); + + // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases + // where the M dimension is very large. + // Split_k doesn't work with M splitting. + const size_t nbytes = ggml_nbytes(src0); + const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange; + if (needs_split) { + // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets) + const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]); + uint32_t m_offset = 0; + while (m_offset < dst->ne[0]) { + const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset)); + ggml_tensor dst2 = *dst; + ggml_tensor src02 = *src0; + + dst2.view_src = dst->view_src ? dst->view_src : dst; + src02.view_src = src0->view_src ? src0->view_src : src0; + + dst2.view_offs += m_offset * dst->nb[0]; + src02.view_offs += m_offset * src0->nb[1]; + dst2.ne[0] = cur_M_size; + src02.ne[1] = cur_M_size; + + ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true, dryrun); + + m_offset += cur_M_size; + } + } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && + // detect 0213 permutation, and batch size of 1 + src0->nb[0] <= src0->nb[2] && + src0->nb[2] <= src0->nb[1] && + src0->nb[1] <= src0->nb[3] && + src1->nb[0] <= src1->nb[2] && + src1->nb[2] <= src1->nb[1] && + src1->nb[1] <= src1->nb[3] && + src0->ne[3] == 1 && + src1->ne[3] == 1) { + ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun); + } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && + !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { + ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun); + // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) + // when ne12 and ne13 are one. + } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) { + ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); + } else { + ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false, dryrun); + } +} + +static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t nei0 = ids->ne[0]; + const uint64_t nei1 = ids->ne[1]; + + const uint32_t nbi1 = ids->nb[1]; + const uint32_t nbi2 = ids->nb[2]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t n_as = ne02; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + ids_uma = d_ids != nullptr; + } + + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || + !ggml_vk_dim01_contiguous(src1); + + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig; + + if (qx_needs_dequant) { + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = padded_n * ne10; + const uint64_t d_ne = ne21 * ne20; + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig || qx_needs_dequant) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_sync_buffers(ctx, subctx); + } + if (y_non_contig) { + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + // compute + ggml_vk_matmul_id( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, + ne01, ne21, ne10, ne10, ne10, ne01, + stride_batch_x, stride_batch_y, ne20*ne21, + n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n + ); // NOLINT + + if (x_non_contig || qx_needs_dequant) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } +} + +static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t nei0 = ids->ne[0]; + const uint64_t nei1 = ids->ne[1]; + + const uint64_t nbi2 = ids->nb[2]; + + GGML_ASSERT(nei1 == 1); + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + ids_uma = d_ids != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne21 * ne20; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if(!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + + if (x_non_contig) { + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + uint32_t groups_x = ne01; + uint32_t groups_z = 1; + + if (ne01 > max_groups_x) { + groups_z = 64; + groups_x = CEIL_DIV(groups_x, groups_z); + } + + // compute + const vk_mat_vec_id_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), + (uint32_t)nei0, (uint32_t)ne11, + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, + vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, + pc, { groups_x, (uint32_t)nei0, groups_z }); + + if (x_non_contig) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } +} + +static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")"); + if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + } else { + ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + } +} + +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) { + // Needs to be kept up to date on shader changes + GGML_UNUSED(hsv); + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = get_fa_scalar_num_large_rows(hsv); + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); + + const uint32_t masksh = Bc * Br * sizeof(float); + + const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) { + // Needs to be kept up to date on shader changes + GGML_UNUSED(hsv); + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = coopmat1_flash_attention_num_large_rows; + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); + + const uint32_t acctype = f32acc ? 4 : 2; + const uint32_t f16vec4 = 8; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * acctype; + + const uint32_t qstride = hsk_pad / 4 + 2; + const uint32_t Qf = Br * qstride * f16vec4; + + const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; + const uint32_t sfsh = Bc * sfshstride * acctype; + + const uint32_t kshstride = hsk_pad / 4 + 2; + const uint32_t ksh = Bc * kshstride * f16vec4; + + const uint32_t slope = Br * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + +static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; + std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; + std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + if (sinks) { + std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3]; + } + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const uint32_t nem1 = mask ? mask->ne[1] : 0; + const uint32_t nem2 = mask ? mask->ne[2] : 0; + const uint32_t nem3 = mask ? mask->ne[3] : 0; + + const uint32_t HSK = nek0; + const uint32_t HSV = nev0; + uint32_t N = neq1; + const uint32_t KV = nek1; + + GGML_ASSERT(ne0 == HSV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == HSK); + + GGML_ASSERT(neq1 == N); + + GGML_ASSERT(nev1 == nek1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + assert(dst->type == GGML_TYPE_F32); + assert(q->type == GGML_TYPE_F32); + assert(k->type == v->type); + + FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : + ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT1) { + const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || + (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); + + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32); + + if (!coopmat_shape_supported || !coopmat_shmem_supported) { + path = FA_SCALAR; + } + } + + uint32_t gqa_ratio = 1; + uint32_t qk_ratio = neq2 / nek2; + uint32_t workgroups_x = (uint32_t)neq1; + uint32_t workgroups_y = (uint32_t)neq2; + uint32_t workgroups_z = (uint32_t)neq3; + + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. + // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). + uint32_t max_gqa; + switch (path) { + case FA_SCALAR: + case FA_COOPMAT1: + // We may switch from coopmat1 to scalar, so use the scalar limit for both + max_gqa = get_fa_scalar_num_large_rows(HSV); + break; + case FA_COOPMAT2: + max_gqa = get_fa_num_small_rows(FA_COOPMAT2); + break; + default: + GGML_ASSERT(0); + } + + if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { + // grouped query attention - make the N dimension equal to gqa_ratio, reduce + // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 + // and change addressing calculations to index Q's dimension 2. + gqa_ratio = qk_ratio; + N = gqa_ratio; + workgroups_y /= N; + } + + bool small_rows = N <= get_fa_num_small_rows(path); + + // coopmat1 does not actually support "small rows" (it needs 16 rows). + // So use scalar instead. + if (small_rows && path == FA_COOPMAT1) { + path = FA_SCALAR; + } + + // scalar is faster than coopmat2 when N==1 + if (N == 1 && path == FA_COOPMAT2) { + path = FA_SCALAR; + } + + // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory + if (path == FA_SCALAR && + !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { + small_rows = true; + } + + const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); + const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); + const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + + uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); + bool aligned = (KV % alignment) == 0 && + // the "aligned" shader variant will forcibly align strides, for performance + (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + + // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned. + if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { + aligned = false; + } + + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc); + + vk_pipeline pipeline = nullptr; + + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + auto it = pipelines.find(fa_pipeline_state); + if (it != pipelines.end()) { + pipeline = it->second; + } else { + pipelines[fa_pipeline_state] = pipeline = std::make_shared(); + } + + assert(pipeline); + + uint32_t split_kv = KV; + uint32_t split_k = 1; + + // Use a placeholder core count if one isn't available. split_k is a big help for perf. + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + + // Try to use split_k when KV is large enough to be worth the overhead + if (workgroups_x == 1 && shader_core_count > 0) { + // Try to run two workgroups per SM. + split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); + split_k = CEIL_DIV(KV, split_kv); + workgroups_x = split_k; + } + } + + // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. + const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; + if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + } + return; + } + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head_kv = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr; + size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0; + + bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); + ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset); + Q_uma = d_Q != nullptr; + K_uma = d_K != nullptr; + V_uma = d_V != nullptr; + D_uma = d_D != nullptr; + if (mask) { + ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); + M_uma = d_M != nullptr; + } + if (sinks) { + ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset); + S_uma = d_S != nullptr; + } + } + + + ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + + if (!Q_uma) { + d_Q = q_buf_ctx->dev_buffer; + q_buf_offset = vk_tensor_offset(q) + q->view_offs; + } + if (!K_uma) { + d_K = k_buf_ctx->dev_buffer; + k_buf_offset = vk_tensor_offset(k) + k->view_offs; + } + if (!V_uma) { + d_V = v_buf_ctx->dev_buffer; + v_buf_offset = vk_tensor_offset(v) + v->view_offs; + } + if (!D_uma) { + d_D = d_buf_ctx->dev_buffer; + d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + if (!M_uma) { + d_M = d_Q; + m_buf_offset = q_buf_offset; + if (mask) { + ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context; + d_M = m_buf_ctx->dev_buffer; + m_buf_offset = vk_tensor_offset(mask) + mask->view_offs; + } + } + + if (!S_uma) { + d_S = d_Q; + s_buf_offset = q_buf_offset; + if (sinks) { + ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context; + d_S = s_buf_ctx->dev_buffer; + s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs; + } + } + + uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; + + const vk_flash_attn_push_constants pc = { N, KV, + (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, + (uint32_t)neq2, (uint32_t)neq3, + (uint32_t)nek2, (uint32_t)nek3, + (uint32_t)nev2, (uint32_t)nev3, + nem1, nem2, nem3, + q_stride, (uint32_t)nbq2, (uint32_t)nbq3, + k_stride, (uint32_t)nbk2, (uint32_t)nbk3, + v_stride, (uint32_t)nbv2, (uint32_t)nbv3, + scale, max_bias, logit_softcap, + mask_n_head_log2, m0, m1, + gqa_ratio, split_kv, split_k }; + + if (split_k > 1) { + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + ggml_vk_subbuffer(ctx, d_Q, q_buf_offset), + ggml_vk_subbuffer(ctx, d_K, k_buf_offset), + ggml_vk_subbuffer(ctx, d_V, v_buf_offset), + ggml_vk_subbuffer(ctx, d_M, m_buf_offset), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0), + }, + // We only use split_k when group query attention is enabled, which means + // there's no more than one tile of rows (i.e. workgroups_x would have been + // one). We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. + pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + + ggml_vk_sync_buffers(ctx, subctx); + const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, + { + ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), + }, + pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + ctx->prealloc_split_k_need_sync = true; + } else { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + ggml_vk_subbuffer(ctx, d_Q, q_buf_offset), + ggml_vk_subbuffer(ctx, d_K, k_buf_offset), + ggml_vk_subbuffer(ctx, d_V, v_buf_offset), + ggml_vk_subbuffer(ctx, d_M, m_buf_offset), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), + }, + pc, { workgroups_x, workgroups_y, workgroups_z }); + } +} + +static std::array ggml_vk_get_conv_elements(const ggml_tensor *dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + // src0 - kernel: [KW, KH, Cin, Cout] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + // parallelize in {OW/BS_K, OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[3]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); + int64_t NPQ = N * OW * OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + std::array elements = { static_cast(Cout), static_cast(NPQ), 1 }; + return elements; +} + +static std::array ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + // src0 - kernel: [KW, KH, Cout, Cin] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins - 1) * s - 2 * p + (ks - 1) * d + 1; + }; + // parallelize in {OW/BS_K, OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[2]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1); + int64_t NPQ = N * OW * OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + std::array elements = { static_cast(Cout), static_cast(NPQ), 1 }; + return elements; +} + +static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) { + switch (op) { + case GGML_OP_GET_ROWS: + GGML_ASSERT(src1->type == GGML_TYPE_I32); + if (dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_get_rows[src0->type]; + } + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_get_rows_f32[src0->type]; + } + return nullptr; + case GGML_OP_ACC: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_acc_f32; + } + return nullptr; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) { + return nullptr; + } + switch (op) { + case GGML_OP_ADD: + { + if (ctx->num_additional_fused_ops > 0) { + if (ctx->do_add_rms_partials) { + return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops]; + } else { + return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; + } + } + if (ctx->do_add_rms_partials) { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } else { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + } + case GGML_OP_SUB: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_MUL: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_DIV: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + default: + break; + } + return nullptr; + case GGML_OP_ADD_ID: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_add_id_f32; + } + return nullptr; + case GGML_OP_CONCAT: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_concat_f32; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_concat_f16; + } + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_concat_i32; + } + return nullptr; + case GGML_OP_UPSCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + int mode = ggml_get_op_params_i32(dst, 0); + switch (mode) { + case GGML_SCALE_MODE_NEAREST: + return ctx->device->pipeline_upscale_nearest_f32; + case GGML_SCALE_MODE_BILINEAR: + return ctx->device->pipeline_upscale_bilinear_f32; + case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS: + return ctx->device->pipeline_upscale_bilinear_ac_f32; + } + } + return nullptr; + case GGML_OP_SCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_scale_f32; + } + return nullptr; + case GGML_OP_SQR: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sqr_f32; + } + return nullptr; + case GGML_OP_SQRT: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sqrt_f32; + } + return nullptr; + case GGML_OP_SIN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sin_f32; + } + return nullptr; + case GGML_OP_COS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cos_f32; + } + return nullptr; + case GGML_OP_CLAMP: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_clamp_f32; + } + return nullptr; + case GGML_OP_PAD: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pad_f32; + } + return nullptr; + case GGML_OP_ROLL: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_roll_f32; + } + return nullptr; + case GGML_OP_REPEAT: + if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { + return ctx->device->pipeline_repeat_f32; + } + return nullptr; + case GGML_OP_REPEAT_BACK: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_repeat_back_f32; + } + return nullptr; + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); + case GGML_OP_SET_ROWS: + if (src1->type == GGML_TYPE_I64) { + return ctx->device->pipeline_set_rows_i64[dst->type]; + } else { + return ctx->device->pipeline_set_rows_i32[dst->type]; + } + case GGML_OP_SILU_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_silu_back_f32; + } + return nullptr; + case GGML_OP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_norm_f32; + } + return nullptr; + case GGML_OP_GROUP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_group_norm_f32; + } + return nullptr; + case GGML_OP_RMS_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ctx->do_add_rms_partials) { + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32; + } else { + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; + } + } + return nullptr; + case GGML_OP_RMS_NORM_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rms_norm_back_f32; + } + return nullptr; + case GGML_OP_L2_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_l2_norm_f32; + } + return nullptr; + case GGML_OP_UNARY: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + + switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_EXP: + return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SILU: + return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_GELU: + return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_GELU_ERF: + return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_GELU_QUICK: + return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_RELU: + return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_TANH: + return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SIGMOID: + return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_HARDSIGMOID: + return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_HARDSWISH: + return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16]; + default: + break; + } + return nullptr; + case GGML_OP_GLU: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_GEGLU: + return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_REGLU: + return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_SWIGLU: + return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_SWIGLU_OAI: + return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_ERF: + return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_QUICK: + return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16]; + default: + break; + } + return nullptr; + case GGML_OP_DIAG_MASK_INF: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_diag_mask_inf_f32; + } + return nullptr; + case GGML_OP_SOFT_MAX: + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); + + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; + } + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; + } + return nullptr; + case GGML_OP_SOFT_MAX_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_soft_max_back_f32; + } + return nullptr; + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + { + const int mode = ((const int32_t *) dst->op_params)[2]; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_neox) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_neox_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_neox_f16; + } + } else if (is_mrope && !is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_multi_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_multi_f16; + } + } else if (is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_vision_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_vision_f16; + } + } else { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_norm_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_norm_f16; + } + } + return nullptr; + } + case GGML_OP_ARGSORT: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); + return ctx->device->pipeline_argsort_f32[idx]; + } + return nullptr; + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sum_rows_f32; + } + return nullptr; + case GGML_OP_ARGMAX: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_argmax_f32; + } + return nullptr; + case GGML_OP_COUNT_EQUAL: + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) { + return ctx->device->pipeline_count_equal_i32; + } + return nullptr; + case GGML_OP_IM2COL: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_f32_f16; + } + return nullptr; + case GGML_OP_IM2COL_3D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_3d_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_3d_f32_f16; + } + return nullptr; + case GGML_OP_TIMESTEP_EMBEDDING: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_timestep_embedding_f32; + } + return nullptr; + case GGML_OP_CONV_TRANSPOSE_1D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv_transpose_1d_f32; + } + return nullptr; + case GGML_OP_POOL_2D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pool2d_f32; + } + return nullptr; + case GGML_OP_RWKV_WKV6: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv6_f32; + } + return nullptr; + case GGML_OP_RWKV_WKV7: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv7_f32; + } + return nullptr; + case GGML_OP_OPT_STEP_ADAMW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_opt_step_adamw_f32; + } + return nullptr; + case GGML_OP_OPT_STEP_SGD: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_opt_step_sgd_f32; + } + return nullptr; + case GGML_OP_LEAKY_RELU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_leaky_relu_f32; + } + return nullptr; + case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + std::array elements; + if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst); + else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst); + vk_conv_shapes shape; + + uint32_t tiles[CONV_SHAPE_COUNT]; + for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) { + tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]); + } + + // We can't query number of shader cores on Intel, use 32 as a placeholder + // so small convolutions will still choose a smaller tile. + const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32; + + if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) { + shape = CONV_SHAPE_128x128; + } else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) { + shape = CONV_SHAPE_32x256; + } else { + shape = CONV_SHAPE_64x32; + } + + if (op == GGML_OP_CONV_2D) { + if (src0->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv2d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + return ctx->device->pipeline_conv2d_f16_f32[shape]; + } + } else if (op == GGML_OP_CONV_TRANSPOSE_2D) { + if (src0->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv_transpose_2d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape]; + } + } + } + return nullptr; + case GGML_OP_CONV_2D_DW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f32; + } + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f16_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32; + } + } + return nullptr; + default: + return nullptr; + } + + GGML_UNUSED(src2); +} + +static bool ggml_vk_op_supports_incontiguous(ggml_op op) { + switch (op) { + case GGML_OP_CPY: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_ADD_ID: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_ROPE: + case GGML_OP_RMS_NORM: + case GGML_OP_CONV_2D_DW: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_SET_ROWS: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + return true; + default: + return false; + } +} + +static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t) +{ + return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));; +} + +template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + GGML_UNUSED(p); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(dst); + static_assert(!std::is_const::value, "unexpected type"); + GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); + GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); + GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0); + GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src0); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0)); + + p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; + + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.a_offset = a_offset; + p.d_offset = d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template +static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + if (src1 != nullptr) { + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + } + if (src2 != nullptr) { + std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; + } + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT + GGML_ASSERT(dst->buffer != nullptr); + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + const uint64_t ne0 = ne00 * ne01; + + const bool use_src1 = src1 != nullptr; + const uint64_t ne10 = use_src1 ? src1->ne[0] : 0; + const uint64_t ne11 = use_src1 ? src1->ne[1] : 0; + const uint64_t ne12 = use_src1 ? src1->ne[2] : 0; + const uint64_t ne13 = use_src1 ? src1->ne[3] : 0; + const uint64_t ne1 = ne10 * ne11; + // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0; + + const bool use_src2 = src2 != nullptr; + const uint64_t ne20 = use_src2 ? src2->ne[0] : 0; + const uint64_t ne21 = use_src2 ? src2->ne[1] : 0; + const uint64_t ne22 = use_src2 ? src2->ne[2] : 0; + const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; + const uint64_t ne2 = ne20 * ne21; + + const uint64_t ned0 = dst->ne[0]; + const uint64_t ned1 = dst->ne[1]; + const uint64_t ned2 = dst->ne[2]; + const uint64_t ned3 = dst->ne[3]; + const uint64_t ned = ned0 * ned1; + + init_pushconst_fastdiv(pc); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); + + if (pipeline == nullptr) { + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type); + if (src1 != nullptr) { + std::cerr << " and " << ggml_type_name(src1->type); + } + std::cerr << " to " << ggml_type_name(dst->type) << std::endl; + GGML_ABORT("fatal error"); + } + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; + ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; + + vk_buffer d_X = nullptr; + size_t x_buf_offset = 0; + vk_buffer d_Y = nullptr; + size_t y_buf_offset = 0; + vk_buffer d_Z = nullptr; + size_t z_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool src2_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset); + src0_uma = d_X != nullptr; + if (use_src1) { + ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset); + src1_uma = d_Y != nullptr; + } + if (use_src2) { + ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset); + src2_uma = d_Z != nullptr; + } + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + + GGML_ASSERT(d_D != nullptr); + uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + if(!src0_uma) { + d_X = src0_buf_ctx->dev_buffer; + x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_X != nullptr); + } + if (use_src1 && !src1_uma) { + d_Y = src1_buf_ctx->dev_buffer; + y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Y != nullptr); + } + if (use_src2 && !src2_uma) { + d_Z = src2_buf_ctx->dev_buffer; + z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; + GGML_ASSERT(d_Z != nullptr); + } + // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets. + init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst); + x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + + std::array elements; + + // Single call if dimension 2 is contiguous + GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); + + switch (op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + { + const uint32_t nr = ggml_nrows(src0); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } break; + case GGML_OP_RMS_NORM: + if (ctx->do_add_rms_partials) { + // Run one element per thread, 128 threads per workgroup + elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 }; + } else { + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + } + break; + + case GGML_OP_SUM: + // We use GGML_OP_SUM_ROWS with 1 row. + elements = { 1, 1, 1 }; + break; + case GGML_OP_GROUP_NORM: + { + const uint32_t num_groups = dst->op_params[0]; + elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; + } break; + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; + break; + case GGML_OP_GET_ROWS: + elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + break; + case GGML_OP_ARGSORT: + elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; + break; + case GGML_OP_IM2COL: + { + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t batch = src1->ne[is_2D ? 3 : 2]; + + elements = { OW * KW * KH, OH, batch * IC }; + } break; + case GGML_OP_IM2COL_3D: + { + const uint32_t IC = ((const uint32_t *)(dst->op_params))[9]; + + const uint32_t N = ne13 / IC; + + const uint32_t KD = ne02; + const uint32_t KH = ne01; + const uint32_t KW = ne00; + + const uint32_t OD = ned3 / N; + const uint32_t OH = ned2; + const uint32_t OW = ned1; + + const uint32_t IC_KD_KH_KW = IC*KD*KH*KW; + const uint32_t N_OD_OH = N*OD*OH; + + elements = { IC_KD_KH_KW, OW, N_OD_OH }; + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + const uint32_t dim = dst->op_params[0]; + uint32_t half_ceil = (dim + 1) / 2; + elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1} + } break; + case GGML_OP_POOL_2D: + { + const uint32_t N = dst->ne[3]; + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + elements = { N * OC * OH * OW, 1, 1}; + } break; + case GGML_OP_CONV_2D: + { + elements = ggml_vk_get_conv_elements(dst); + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + elements = ggml_vk_get_conv_transpose_2d_elements(dst); + } break; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_DIV: + case GGML_OP_MUL: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_ROLL: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_CPY: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_UNARY: + case GGML_OP_GLU: + case GGML_OP_CONV_2D_DW: + { + uint32_t ne = ggml_nelements(dst); + if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + // copy_to_quant has block size of 32, and each thread does QUANT_K elements. + // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements. + // So divide by block size here before splitting into 512x512 groups. + if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + ne = CEIL_DIV(ne, ggml_blck_size(dst->type)); + } + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + } break; + case GGML_OP_ADD_ID: + { + elements = { (uint32_t)ne01, (uint32_t)ne02, 1 }; + } break; + case GGML_OP_SET_ROWS: + { + uint32_t ne = ggml_nelements(src0); + if (ggml_is_quantized(dst->type)) { + // quants run 32 threads each doing QUANT_K elements + ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type)); + } else { + // scalar types do one element per thread, running 512 threads + ne = CEIL_DIV(ne, 512); + } + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + } + break; + default: + elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; + break; + } + + uint64_t x_sz, y_sz, z_sz, d_sz; + + if (op_supports_incontiguous) { + x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); + y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; + d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); + + if (x_buf_offset + x_sz >= d_X->size) { + x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset); + } + if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { + y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset); + } + if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { + z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset); + } + if (d_buf_offset + d_sz >= d_D->size) { + d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset); + } + } else { + x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03; + y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0; + z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0; + d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3; + } + + if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { + vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X; + size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { vk_subbuffer{ d_X, x_buf_offset, x_sz }, + vk_subbuffer{ d_Y, y_buf_offset, y_sz }, + vk_subbuffer{ d_D, d_buf_offset, d_sz }, + ggml_vk_subbuffer(ctx, d_A, a_buf_offset), + }, pc, elements); + } else if (op == GGML_OP_GLU) { + // Empty src1 is possible in glu, but the shader needs a buffer + vk_subbuffer subbuf_y; + if (use_src1) { + subbuf_y = { d_Y, y_buf_offset, y_sz }; + } else { + subbuf_y = { d_X, 0, x_sz }; + } + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_SOFT_MAX) { + // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer + vk_subbuffer subbuf_y; + if (use_src1) { + subbuf_y = { d_Y, y_buf_offset, y_sz }; + } else { + subbuf_y = { d_X, 0, x_sz }; + } + + vk_subbuffer subbuf_z; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; + } + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { + // Empty src2 is possible in rope, but the shader needs a buffer + vk_subbuffer subbuf_z; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; + } + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { + if (ctx->device->shader_int64 && ctx->device->buffer_device_address) { + // buffer device address path doesn't use dst buffer + d_sz = 1; + } + // im2col uses only src1 and dst buffers + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_COUNT_EQUAL) { + // count_equal assumes that destination buffer is initialized with zeroes + ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_OPT_STEP_SGD) { + // OPT_STEP_SGD works on src0, it does not need dst + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements); + } else if (use_src2) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (use_src1) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } +} + +static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, offset, + }, dryrun); +} + +static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) { + const ggml_tensor *first_node = cgraph->nodes[node_idx]; + const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + + // Make a list of all the tensors used by the op. + // Last element of the list is the dest tensor. + const ggml_tensor *tensors[MAX_PARAMETER_COUNT]; + uint32_t num_srcs = ctx->num_additional_fused_ops + 2; + uint32_t num_tensors = num_srcs + 1; + GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT); + + tensors[0] = first_node->src[0]; + tensors[1] = first_node->src[1]; + for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) { + // check whether the previous result is src[0] or src[1] + if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) { + tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1]; + } else { + tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0]; + } + } + tensors[num_srcs] = dst; + + vk_op_multi_add_push_constants pc; + pc.ne20 = (uint32_t)dst->ne[0]; + pc.ne21 = (uint32_t)dst->ne[1]; + pc.ne22 = (uint32_t)dst->ne[2]; + pc.ne23 = (uint32_t)dst->ne[3]; + + for (uint32_t i = 0; i < num_tensors; ++i) { + const ggml_tensor *t = tensors[i]; + pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float); + pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float); + pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float); + pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float); + } + pc.rms_partials = ctx->do_add_rms_partials; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op); + + if (pipeline == nullptr) { + std::cerr << "ggml_vulkan: Error: Missing multi_add"; + GGML_ABORT("fatal error"); + } + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT]; + vk_buffer buf[MAX_PARAMETER_COUNT]; + size_t offset[MAX_PARAMETER_COUNT]; + bool uma[MAX_PARAMETER_COUNT]; + + for (uint32_t i = 0; i < num_tensors; ++i) { + buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context; + buf[i] = nullptr; + offset[i] = 0; + uma[i] = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]); + uma[i] = buf[i] != nullptr; + } + if (!uma[i]) { + buf[i] = buf_ctx[i]->dev_buffer; + offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs; + } + GGML_ASSERT(buf[i] != nullptr); + } + // If any remaining descriptors are unused, just point them at src[0] + for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) { + buf[i] = buf[0]; + offset[i] = 0; + } + if (ctx->do_add_rms_partials) { + buf[num_tensors] = ctx->prealloc_add_rms_partials; + offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset; + } + + std::array elements; + + uint32_t ne = ggml_nelements(dst); + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + static_assert(MAX_PARAMETER_COUNT == 12); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + ggml_vk_subbuffer(ctx, buf[0], offset[0]), + ggml_vk_subbuffer(ctx, buf[1], offset[1]), + ggml_vk_subbuffer(ctx, buf[2], offset[2]), + ggml_vk_subbuffer(ctx, buf[3], offset[3]), + ggml_vk_subbuffer(ctx, buf[4], offset[4]), + ggml_vk_subbuffer(ctx, buf[5], offset[5]), + ggml_vk_subbuffer(ctx, buf[6], offset[6]), + ggml_vk_subbuffer(ctx, buf[7], offset[7]), + ggml_vk_subbuffer(ctx, buf[8], offset[8]), + ggml_vk_subbuffer(ctx, buf[9], offset[9]), + ggml_vk_subbuffer(ctx, buf[10], offset[10]), + ggml_vk_subbuffer(ctx, buf[11], offset[11]), + }, pc, elements); +} + +static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, ctx->do_add_rms_partials, + }, dryrun); +} + +static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t src2_type_size = ggml_type_size(src2->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, { + (uint32_t)dst->ne[0], + (uint32_t)dst->ne[1], + (uint32_t)src0->nb[1] / src0_type_size, + (uint32_t)src0->nb[2] / src0_type_size, + (uint32_t)src1->nb[1] / src1_type_size, + (uint32_t)src2->nb[1] / src2_type_size, + }, dryrun); +} + +static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) { + GGML_ASSERT(version == 6 || version == 7); + int num_srcs = version == 6 ? 6 : 7; + + for (int i = 0; i < num_srcs; i++) { + GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); + } + + GGML_ASSERT(dst->buffer != nullptr); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + for (int i = 0; i < num_srcs; i++) { + src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; + } + + vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 }; + bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false }; + + if (ctx->device->uma) { + for (int i = 0; i < num_srcs; i++) { + ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]); + srcs_uma[i] = d_srcs[i] != nullptr; + } + + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); + dst_uma = d_D != nullptr; + } + + uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 }; + for (int i = 0; i < num_srcs; i++) { + src_sizes[i] = ggml_nbytes(dst->src[i]); + if (!srcs_uma[i]) { + d_srcs[i] = src_buf_ctxs[i]->dev_buffer; + src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs; + } + } + + const uint64_t dst_size = ggml_nbytes(dst); + if (!dst_uma) { + d_D = dst_buf_ctx->dev_buffer; + dst_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + std::array elements = { + (uint32_t)(pc.B * pc.H), + 1, + 1 + }; + + if (version == 6) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, pc, elements); + } else if (version == 7) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, pc, elements); + } else { + // shouldn't happen + GGML_ASSERT(false); + } +} + +static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[5]->ne[1]; + + ggml_vk_op_f32_wkv( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + 6, + dryrun + ); +} + +static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[6]->ne[1]; + + ggml_vk_op_f32_wkv( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + 7, + dryrun + ); +} + +static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * x = dst->src[0]; + const ggml_tensor * g = dst->src[1]; + const ggml_tensor * gm = dst->src[2]; + const ggml_tensor * gv = dst->src[3]; + const ggml_tensor * p = dst->src[4]; + + GGML_ASSERT(x->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(gm->type == GGML_TYPE_F32); + GGML_ASSERT(gv->type == GGML_TYPE_F32); + GGML_ASSERT(p->type == GGML_TYPE_F32); + GGML_ASSERT(dst->buffer != nullptr); + GGML_ASSERT(ggml_is_contiguous(x)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(gm)); + GGML_ASSERT(ggml_is_contiguous(gv)); + GGML_ASSERT(ggml_is_contiguous(p)); + GGML_ASSERT(ggml_are_same_shape(x, g)); + GGML_ASSERT(ggml_are_same_shape(x, gm)); + GGML_ASSERT(ggml_are_same_shape(x, gv)); + GGML_ASSERT(ggml_nelements(p) == 7); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context; + ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context; + ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context; + ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context; + ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context; + + vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr; + size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0; + bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, x->data, d_X, x_offset); + ggml_vk_host_get(ctx->device, g->data, d_G, g_offset); + ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset); + ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset); + ggml_vk_host_get(ctx->device, p->data, d_P, p_offset); + + X_uma = d_X != nullptr; + G_uma = d_G != nullptr; + GM_uma = d_GM != nullptr; + GV_uma = d_GV != nullptr; + P_uma = d_P != nullptr; + } + + if (!X_uma) { + d_X = x_buf_ctx->dev_buffer; + x_offset = vk_tensor_offset(x) + x->view_offs; + } + if (!G_uma) { + d_G = g_buf_ctx->dev_buffer; + g_offset = vk_tensor_offset(g) + g->view_offs; + } + if (!GM_uma) { + d_GM = gm_buf_ctx->dev_buffer; + gm_offset = vk_tensor_offset(gm) + gm->view_offs; + } + if (!GV_uma) { + d_GV = gv_buf_ctx->dev_buffer; + gv_offset = vk_tensor_offset(gv) + gv->view_offs; + } + if (!P_uma) { + d_P = p_buf_ctx->dev_buffer; + p_offset = vk_tensor_offset(p) + p->view_offs; + } + + const uint64_t x_size = ggml_nbytes(x); + const uint64_t g_size = ggml_nbytes(g); + const uint64_t gm_size = ggml_nbytes(gm); + const uint64_t gv_size = ggml_nbytes(gv); + const uint64_t p_size = ggml_nbytes(p); + + std::array elements = { (uint32_t)ggml_nelements(x), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_X, x_offset, x_size }, + vk_subbuffer{ d_G, g_offset, g_size }, + vk_subbuffer{ d_GM, gm_offset, gm_size }, + vk_subbuffer{ d_GV, gv_offset, gv_size }, + vk_subbuffer{ d_P, p_offset, p_size }, + }, pc, elements); +} + +static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t n = ggml_nelements(dst->src[0]); + + ggml_vk_op_f32_opt_step_adamw( + ctx, subctx, dst, + { (uint32_t)n, 0, 0.0f, 0.0f }, + dryrun + ); +} + +static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const size_t n = ggml_nelements(dst->src[0]); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + int * op_params = (int *)dst->op_params; + + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, op_params[0], + }, dryrun); +} + +static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0); + + float sf0 = (float)dst->ne[0] / src0->ne[0]; + float sf1 = (float)dst->ne[1] / src0->ne[1]; + float sf2 = (float)dst->ne[2] / src0->ne[2]; + float sf3 = (float)dst->ne[3] / src0->ne[3]; + + if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); + sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + } + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { + (uint32_t)ggml_nelements(dst), 0, 0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], + (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], + sf0, sf1, sf2, sf3, + }, dryrun); +} + +static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = ggml_get_op_params_f32(dst, 0); + p.param2 = ggml_get_op_params_f32(dst, 1); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun); +} + +static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun); +} + +static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun); +} + +static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun); +} + +static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun); +} + +static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = ggml_get_op_params_f32(dst, 0); + p.param2 = ggml_get_op_params_f32(dst, 1); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun); +} + +static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); +} + +static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t s2 = ggml_get_op_params_i32(dst, 2); + const int32_t s3 = ggml_get_op_params_i32(dst, 3); + const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000); + const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000); + + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + memcpy(&p.param1, &s01_packed, sizeof(float)); + memcpy(&p.param2, &s23_packed, sizeof(float)); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun); +} + +static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun); +} + +static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun); +} + +static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + uint32_t ne = (uint32_t)ggml_nelements(src0); + if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun); +} + +static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + // Skip empty skip_rows operations. For most ops the empty check at the start + // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst + // with empty srcs. + if (ggml_is_empty(src0) || ggml_is_empty(src1)) { + return; + } + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const int * int_op_params = (const int *)dst->op_params; + const float * float_op_params = (const float *)dst->op_params; + + const uint32_t num_groups = int_op_params[0]; + const float eps = float_op_params[1]; + const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); +} + +static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { + const uint32_t ne = (uint32_t)node->ne[0]; + const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0]; + const uint32_t num_partials = CEIL_DIV(ne, denom); + return num_partials; +} + +static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) { + const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node); + const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment); + return num_bytes; +} + +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, (int32_t)param3, + }, dryrun); + + if (ctx->do_add_rms_partials) { + ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0); + ctx->do_add_rms_partials = false; + } +} + +static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const float * op_params_f = (const float *)dst->op_params; + + const bool swapped = (bool)dst->op_params[1]; + const bool split = src1 != nullptr; + const float alpha = op_params_f[2]; + const float limit = op_params_f[3]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + + if (!split) { + GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); + } else { + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + GGML_ASSERT(src0->ne[0] == dst->ne[0]); + GGML_ASSERT(src0->type == src1->type); + } + + const uint32_t mode = split ? 2 : (swapped ? 1 : 0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, + { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], + (uint32_t)dst->ne[0], + mode, + alpha, + limit + }, dryrun); +} + +static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); +} + +static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + + float scale = op_params[0]; + float max_bias = op_params[1]; + + const uint32_t ncols = (uint32_t)src0->ne[0]; + const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); + const uint32_t nrows_y = (uint32_t)src0->ne[1]; + + const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u; + const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u; + const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u; + const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u; + const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u; + + const uint32_t n_head_kv = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, { + ncols, + src1 != nullptr ? nrows_y : (uint32_t)0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], + ne12, ne13, + nb11, nb12, nb13, + scale, max_bias, + m0, m1, + n_head_log2, + nrows_x, + src2 != nullptr + }, dryrun); +} + +static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); +} + +static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + // const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + const float freq_base = ((float *) dst->op_params)[5]; + const float freq_scale = ((float *) dst->op_params)[6]; + const float ext_factor = ((float *) dst->op_params)[7]; + const float attn_factor = ((float *) dst->op_params)[8]; + const float beta_fast = ((float *) dst->op_params)[9]; + const float beta_slow = ((float *) dst->op_params)[10]; + int sections[4] {}; + if (mode & GGML_ROPE_TYPE_MROPE) { + memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); + } + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type); + uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { + (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, + src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, + { sections[0], sections[1], sections[2], sections[3] }, backprop + }, dryrun); +} + +static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; + + uint32_t ncols = src0->ne[0]; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { + ncols, + op_params[0], + }, dryrun); +} + +static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun); +} + +static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun); +} + +static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + p.weight = 1.0f / (float)src0->ne[0]; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun); +} + +static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const int32_t s0 = dst->op_params[0]; + const int32_t s1 = dst->op_params[1]; + const int32_t p0 = dst->op_params[2]; + const int32_t p1 = dst->op_params[3]; + const int32_t d0 = dst->op_params[4]; + const int32_t d1 = dst->op_params[5]; + + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t IW = src1->ne[0]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 + + const uint32_t pelements = OW * KW * KH; + + const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + const vk_buffer d_buf = d_buf_ctx->dev_buffer; + + const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { + dst_addr, + batch_offset, offset_delta, + IC, IW, IH, OW, OH, KW, KH, + pelements, + IC * KH * KW, + s0, s1, p0, p1, d0, d1, + }, dryrun); +} + +static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + const vk_buffer d_buf = d_buf_ctx->dev_buffer; + + const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs; + + vk_op_im2col_3d_push_constants pc {}; + + pc.dst_addr = dst_addr; + pc.nb10 = nb10 / ggml_type_size(src1->type); + pc.nb11 = nb11 / ggml_type_size(src1->type); + pc.nb12 = nb12 / ggml_type_size(src1->type); + pc.nb13 = nb13 / ggml_type_size(src1->type); + pc.s0 = s0; + pc.s1 = s1; + pc.s2 = s2; + pc.p0 = p0; + pc.p1 = p1; + pc.p2 = p2; + pc.d0 = d0; + pc.d1 = d1; + pc.d2 = d2; + pc.IW = IW; + pc.IH = IH; + pc.ID = ID; + pc.IC = IC; + pc.KW = KW; + pc.OH = OH; + pc.KD_KH_KW = KD*KH*KW; + pc.KH_KW = KH*KW; + pc.IC_KD_KH_KW = IC*KD*KH*KW; + pc.N_OD_OH = N*OD*OH; + pc.OD_OH = OD*OH; + pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); +} + +static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t dim = dst->op_params[0]; + const uint32_t max_period = dst->op_params[1]; + const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { + nb1, dim, max_period, + }, dryrun); +} + +static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + // src0: (K, Cout, Cin, 1) -- kernel + // src1: (L, Cin, 1, 1) -- input + // dst: (*, Cout, 1, 1) + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + const int32_t s0 = dst->op_params[0]; + + vk_op_conv_transpose_1d_push_constants p{}; + p.Cout = static_cast(ne01); + p.Cin = static_cast(ne02); + p.K = static_cast(ne00); + p.L = static_cast(ne10); + p.KL = static_cast(ne0); + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb11 = static_cast(nb11 / nb10); + p.nb1 = static_cast(nb1 / nb0); + p.s0 = static_cast(s0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun); +} + +static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + uint32_t op = static_cast(dst->op_params[0]); + const int32_t k1 = dst->op_params[1]; + const int32_t k0 = dst->op_params[2]; + const int32_t s1 = dst->op_params[3]; + const int32_t s0 = dst->op_params[4]; + const int32_t p1 = dst->op_params[5]; + const int32_t p0 = dst->op_params[6]; + + const uint32_t IH = src0->ne[1]; + const uint32_t IW = src0->ne[0]; + + const uint32_t N = dst->ne[3]; + + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + + const uint32_t parallel_elements = N * OC * OH * OW; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, { + IW, IH, OW, OH, OC, + parallel_elements, + op, + k0, k1, s0, s1, p0, p1, + }, dryrun); +} + +static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv2d_push_constants p{}; + p.Cout = static_cast(ne03); + p.Cin = static_cast(ne02); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[1]); + p.p0 = static_cast(dst->op_params[2]); + p.p1 = static_cast(dst->op_params[3]); + p.d0 = static_cast(dst->op_params[4]); + p.d1 = static_cast(dst->op_params[5]); + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + GGML_ASSERT(ne03 == ne2); + GGML_ASSERT(ne02 == ne12); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); +} + +static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv_transpose_2d_push_constants p{}; + p.Cout = static_cast(ne02); + p.Cin = static_cast(ne03); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[0]); + p.p0 = 0; + p.p1 = 0; + p.d0 = 1; + p.d1 = 1; + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne12); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun); +} + +static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + vk_op_conv2d_dw_push_constants p{}; + p.ne = ggml_nelements(dst); + p.channels = dst->ne[2]; + p.batches = dst->ne[3]; + p.dst_w = dst->ne[0]; + p.dst_h = dst->ne[1]; + p.src_w = src1->ne[0]; + p.src_h = src1->ne[1]; + p.knl_w = src0->ne[0]; + p.knl_h = src0->ne[1]; + p.stride_x = dst->op_params[0]; + p.stride_y = dst->op_params[1]; + p.pad_x = dst->op_params[2]; + p.pad_y = dst->op_params[3]; + p.dilation_x = dst->op_params[4]; + p.dilation_y = dst->op_params[5]; + + GGML_ASSERT(src0->ne[3] == p.channels); + GGML_ASSERT(src1->ne[3] == p.batches); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); +} + +static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const float * op_params = (const float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); +} + +#ifdef GGML_VULKAN_RUN_TESTS +static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) { + if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) { + float val; + if (type == GGML_TYPE_F32) { + val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0); + } else if (type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0)); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +template +static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) { + VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")"); + const size_t x_ne = m * k * batch; + const size_t y_ne = k * n * batch; + const size_t d_ne = m * n * batch; + + vk_pipeline p; + std::string shname; + if (shader_size == 0) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_s; + shname = "F32_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_s; + shname = "F32_F16_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s; + shname = "F16_F32_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_s; + shname = "F16_ALIGNED_S"; + } else { + GGML_ABORT("fatal error"); + } + } else if (shader_size == 1) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_m; + shname = "F32_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_m; + shname = "F32_F16_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m; + shname = "F16_F32_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_m; + shname = "F16_ALIGNED_M"; + } else { + GGML_ABORT("fatal error"); + } + } else if (shader_size == 2) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_l; + shname = "F32_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_l; + shname = "F32_F16_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l; + shname = "F16_F32_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_l; + shname = "F16_ALIGNED_L"; + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ASSERT(0); + } + + const size_t kpad = ggml_vk_align_size(k, p->align); + + if (k != kpad) { + if (shader_size == 0) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->s; + shname = "F32_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->s; + shname = "F32_F16_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->s; + shname = "F16_F32_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->s; + shname = "F16_S"; + } + } else if (shader_size == 1) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->m; + shname = "F32_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->m; + shname = "F32_F16_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->m; + shname = "F16_F32_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->m; + shname = "F16_M"; + } + } else if (shader_size == 2) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->l; + shname = "F32_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->l; + shname = "F32_F16_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->l; + shname = "F16_F32_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->l; + shname = "F16_L"; + } + } + } + + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); + + if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + } + } + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + + ggml_pipeline_allocate_descriptor_sets(ctx); + + vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + + X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne); + Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne); + float* d = (float *) malloc(sizeof(float) * d_ne); + + for (size_t i = 0; i < x_ne; i++) { + if (std::is_same()) { + x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = 1.0f; + // x[i] = i + 1; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + } else if (std::is_same()) { + x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // x[i] = ggml_fp32_to_fp16(1.0f); + // x[i] = ggml_fp32_to_fp16(i + 1); + // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + } else { + GGML_ABORT("fatal error"); + } + } + for (size_t i = 0; i < y_ne; i++) { + if (std::is_same()) { + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 1.0f : 0.0f; + // y[i] = i + 1; + } else if (std::is_same()) { + y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + // y[i] = ggml_fp32_to_fp16(i + 1); + } else { + GGML_ABORT("fatal error"); + } + } + + ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); + ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ggml_vk_ctx_begin(ctx->device, subctx); + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k), + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); + + auto end = std::chrono::high_resolution_clock::now(); + double time = std::chrono::duration_cast(end-begin).count() / 1000.0; + + // copy dst to host + ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne); + + float * d_chk = (float *) malloc(sizeof(float) * d_ne); + + ggml_init_params iparams = { + /*.mem_size =*/ 1024*1024*1024, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ggml_ctx = ggml_init(iparams); + + ggml_type src0_type; + ggml_type src1_type; + + if (std::is_same()) { + src0_type = GGML_TYPE_F32; + } else if (std::is_same()) { + src0_type = GGML_TYPE_F16; + } else { + GGML_ABORT("fatal error"); + } + if (std::is_same()) { + src1_type = GGML_TYPE_F32; + } else if (std::is_same()) { + src1_type = GGML_TYPE_F16; + } else { + GGML_ABORT("fatal error"); + } + + ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch); + ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch); + ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); + + src0_ggml->data = x; + src1_ggml->data = y; + tensor_ggml->data = d_chk; + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_ggml); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); + + ggml_free(ggml_ctx); + + double avg_err = 0.0; + int first_err_n = -1; + int first_err_m = -1; + int first_err_b = -1; + + for (size_t i = 0; i < m*n*batch; i++) { + double err = std::fabs(d[i] - d_chk[i]); + avg_err += err; + + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { + first_err_b = i / (m * n); + first_err_n = (i % (m * n)) / m; + first_err_m = (i % (m * n)) % m; + } + } + + avg_err /= m * n; + + double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1 || std::isnan(avg_err)) { + std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << "Expected result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + if (split_k > 1) { + float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); + ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); + + std::cerr << "d_buf0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf2: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf3: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + free(split_k_buf); + } + } + + free(d_chk); + + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); + + ggml_vk_destroy_buffer(d_X); + ggml_vk_destroy_buffer(d_Y); + ggml_vk_destroy_buffer(d_D); + + free(x); + free(y); + free(d); +} + +static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + i3 = std::max(i3, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) { + ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr); +} + +static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) { + if (quant == GGML_TYPE_F32) { + memcpy(to, from, sizeof(float) * ne); + return; + } + + const auto * tt = ggml_get_type_traits(quant); + + ggml_to_float_t dequant_fn = tt->to_float; + + dequant_fn(from, to, ne); +} + +static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { + VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")"); + const size_t x_sz = sizeof(float) * ne; + const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne; + const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); + float * x = (float *) malloc(x_sz); + void * qx = malloc(qx_sz); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + float * x_ref = (float *) malloc(x_sz); + ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16); + + for (size_t i = 0; i < ne; i++) { + x[i] = rand() / (float)RAND_MAX; + } + + vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant); + + ggml_vk_quantize_data(x, qx, ne, quant); + ggml_vk_dequantize_data(qx, x_ref, ne, quant); + + ggml_pipeline_request_descriptor_sets(ctx, p, 1); + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + + ggml_pipeline_allocate_descriptor_sets(ctx); + + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ggml_vk_ctx_begin(ctx->device, subctx); + const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; + ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1}); + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); + + auto end = std::chrono::high_resolution_clock::now(); + + double ms_dequant = std::chrono::duration_cast(end-begin).count() / 1000.0; + ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16); + + int first_err = -1; + + double avg_err = 0.0; + for (size_t i = 0; i < ne; i++) { + double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i])); + avg_err += error; + + if (first_err < 0 && error > 0.05) { + first_err = i; + } + } + + avg_err /= ne; + + std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1) { + std::cerr << "first_error = " << first_err << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { + std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", "; + } + std::cerr << std::endl << "Expected result: " << std::endl << std::endl; + for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { + std::cerr << x_ref[i] << ", "; + } + std::cerr << std::endl; + } + + ggml_vk_destroy_buffer(x_buf); + ggml_vk_destroy_buffer(qx_buf); + + free(x); + free(qx); + free(x_ref); + free(x_chk); +} + +// This does not work without ggml q8_1 quantization support +// +// typedef uint16_t ggml_half; +// typedef uint32_t ggml_half2; +// +// #define QK8_1 32 +// typedef struct { +// union { +// struct { +// ggml_half d; // delta +// ggml_half s; // d * sum(qs[i]) +// } GGML_COMMON_AGGR_S; +// ggml_half2 ds; +// } GGML_COMMON_AGGR_U; +// int8_t qs[QK8_1]; // quants +// } block_q8_1; +// +// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { +// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")"); +// GGML_ASSERT(quant == GGML_TYPE_Q8_1); +// +// const size_t x_sz = sizeof(float) * ne; +// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); +// float * x = (float *) malloc(x_sz); +// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz); +// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz); +// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); +// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); +// +// for (size_t i = 0; i < ne; i++) { +// x[i] = rand() / (float)RAND_MAX; +// } +// +// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant); +// +// ggml_pipeline_request_descriptor_sets(ctx, p, 1); +// +// if (ctx->device->need_compiles) { +// ggml_vk_load_shaders(ctx->device); +// } +// +// ggml_pipeline_allocate_descriptor_sets(ctx); +// +// ggml_vk_buffer_write(x_buf, 0, x, x_sz); +// +// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); +// ggml_vk_ctx_begin(ctx->device, subctx); +// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne); +// ggml_vk_ctx_end(subctx); +// +// auto begin = std::chrono::high_resolution_clock::now(); +// +// ggml_vk_submit(subctx, ctx->fence); +// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences"); +// ctx->device->device.resetFences({ ctx->fence }); +// ggml_vk_queue_command_pools_cleanup(ctx->device); +// +// auto end = std::chrono::high_resolution_clock::now(); +// +// double ms_quant = std::chrono::duration_cast(end-begin).count() / 1000.0; +// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz); +// +// ggml_vk_quantize_data(x, qx_res, ne, quant); +// +// int first_err = -1; +// +// for (size_t i = 0; i < ne / 32; i++) { +// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// for (size_t j = 0; j < 32; j++) { +// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]); +// +// if (first_err < 0 && error > 1) { +// first_err = i; +// } +// } +// } +// +// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl; +// +// if (first_err != -1) { +// std::cerr << "first_error = " << first_err << std::endl; +// std::cerr << "Actual result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " "; +// } +// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " "; +// } +// std::cerr << std::endl; +// } +// +// ggml_vk_destroy_buffer(x_buf); +// ggml_vk_destroy_buffer(qx_buf); +// +// free(x); +// free(qx); +// free(qx_res); +// } + +static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) { + VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); + const size_t x_ne = m * k * batch; + const size_t y_ne = k * n * batch; + const size_t d_ne = m * n * batch; + + vk_matmul_pipeline2 * pipelines; + + if (mmq) { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1; + } else { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat; + } + + const bool fp16acc = ctx->device->fp16; + + vk_pipeline p; + std::string shname; + if (shader_size == 0) { + p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; + } else if (shader_size == 1) { + p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; + } else if (shader_size == 2) { + p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; + } else { + GGML_ASSERT(0); + } + + const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align); + + if (mmq || k != kpad) { + if (shader_size == 0) { + p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s; + shname = std::string(ggml_type_name(quant)) + "_S"; + } else if (shader_size == 1) { + p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m; + shname = std::string(ggml_type_name(quant)) + "_M"; + } else if (shader_size == 2) { + p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l; + shname = std::string(ggml_type_name(quant)) + "_L"; + } else { + GGML_ASSERT(0); + } + } + + if (p == nullptr) { + std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl; + return; + } + + const size_t x_sz = sizeof(float) * x_ne; + const size_t y_sz = sizeof(float) * y_ne; + const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); + const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz; + const size_t d_sz = sizeof(float) * d_ne; + float * x = (float *) malloc(x_sz); + float * y = (float *) malloc(y_sz); + void * qx = malloc(qx_sz); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + float * d = (float *) malloc(d_sz); + float * d_chk = (float *) malloc(d_sz); + + for (size_t i = 0; i < x_ne; i++) { + x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + // x[i] = i % k; + } + + ggml_vk_quantize_data(x, qx, x_ne, quant); + + for (size_t i = 0; i < y_ne; i++) { + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 1.0f : 0.0f; + // y[i] = i % k; + } + + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); + + if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + } + } + if (mmq) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it); + } + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + + ggml_pipeline_allocate_descriptor_sets(ctx); + + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); + ggml_vk_buffer_write(y_buf, 0, y, y_sz); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ggml_vk_ctx_begin(ctx->device, subctx); + if (mmq) { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne); + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + } else { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + } + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); + + auto end = std::chrono::high_resolution_clock::now(); + + double time_ms = std::chrono::duration_cast(end-begin).count() / 1000.0; + ggml_vk_buffer_read(d_buf, 0, d, d_sz); + + ggml_init_params iparams = { + /*.mem_size =*/ 1024*1024*1024, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ggml_ctx = ggml_init(iparams); + + ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch); + ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch); + ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); + + src0_ggml->data = qx; + src1_ggml->data = y; + tensor_ggml->data = d_chk; + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_ggml); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); + + ggml_free(ggml_ctx); + + double avg_err = 0.0; + int first_err_n = -1; + int first_err_m = -1; + int first_err_b = -1; + + for (size_t i = 0; i < m*n*batch; i++) { + double err = std::fabs(d[i] - d_chk[i]); + avg_err += err; + + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { + first_err_b = i / (m * n); + first_err_n = (i % (m * n)) / m; + first_err_m = (i % (m * n)) % m; + } + } + + avg_err /= m * n; + + double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST dequant matmul " << shname; + if (mmq) { + std::cerr << " mmq"; + } + std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.01 || std::isnan(avg_err)) { + std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "Expected result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "src0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "src1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b); + + if (split_k > 1) { + float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); + ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); + + std::cerr << "d_buf0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf2: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf3: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + free(split_k_buf); + } + } + + ggml_vk_destroy_buffer(qx_buf); + ggml_vk_destroy_buffer(y_buf); + ggml_vk_destroy_buffer(qy_buf); + ggml_vk_destroy_buffer(d_buf); + + free(x); + free(qx); + free(y); + free(d); + free(d_chk); +} +#endif + +static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { +#if defined(GGML_VULKAN_RUN_TESTS) + const std::vector vals { + 512, 512, 128, + 128, 512, 512, + 4096, 512, 4096, + 11008, 512, 4096, + 4096, 512, 11008, + 32000, 512, 4096, + 8, 8, 8, + 100, 46, 576, + 623, 111, 128, + 100, 46, 558, + 512, 1, 256, + 128, 110, 622, + 511, 511, 127, + 511, 511, 7, + 511, 511, 17, + 49, 49, 128, + 128, 49, 49, + 4096, 49, 4096, + }; + const size_t num_it = 100; + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true); + + abort(); + + for (size_t i = 0; i < vals.size(); i += 3) { + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2); + std::cerr << '\n' << std::endl; + + if (vals[i + 2] % 32 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0); + std::cerr << '\n' << std::endl; + } + + if (vals[i + 2] % 256 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K); + std::cerr << '\n' << std::endl; + } + } + + GGML_ABORT("fatal error"); +#endif + + if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")"); + // Resize buffer + if (ctx->prealloc_x != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_x); + } + ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x); + } + if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")"); + // Resize buffer + if (ctx->prealloc_y != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_y); + } + ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + } + if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); + } + if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")"); + // Resize buffer + if (ctx->prealloc_add_rms_partials != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials); + } + ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials); + } +} + +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); + +// Returns true if node has enqueued work into the queue, false otherwise +// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. +static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ + ggml_tensor * node = cgraph->nodes[node_idx]; + if (ggml_is_empty(node) || !node->buffer) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); + ctx->semaphore_idx = 0; + + ggml_tensor * src0 = node->src[0]; + ggml_tensor * src1 = node->src[1]; + ggml_tensor * src2 = node->src[2]; + ggml_tensor * src3 = node->src[3]; + + switch (node->op) { + // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + return false; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: + break; + default: + return false; + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + break; + default: + return false; + } + break; + case GGML_OP_ADD: + { + int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops; + if (next_node_idx < cgraph->n_nodes && + cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM && + cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] && + ggml_nrows(cgraph->nodes[next_node_idx]) == 1 && + ctx->device->add_rms_fusion) { + if (dryrun) { + ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); + } + ctx->do_add_rms_partials = true; + } + } break; + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD_ID: + case GGML_OP_ACC: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_ROLL: + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_SILU_BACK: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + break; + default: + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; + GGML_ABORT("fatal error"); + } + + vk_context compute_ctx; + + if (!dryrun) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + } else { + switch (node->op) { + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_SILU_BACK: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_UNARY: + case GGML_OP_GLU: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_LEAKY_RELU: + case GGML_OP_OPT_STEP_SGD: + { + // These operations all go through ggml_vk_op_f32, so short-circuit and + // do the only thing needed for the dryrun. + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (node->op == GGML_OP_RMS_NORM) { + ctx->do_add_rms_partials = false; + } + return false; + } + default: + break; + } + } + + if (!dryrun) { + // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers + // to synchronize them. This handles most "normal" synchronization when computing the graph, and when + // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers + // outside of this logic. When a node uses one of the prealloc buffers for something like + // dequantization or split_k, additional synchronization is needed between those passes. + bool need_sync = false; + + // Check whether "node" requires synchronization. The node requires synchronization if it + // overlaps in memory with another unsynchronized node and at least one of them is a write. + // Destination nodes are checked against both the written/read lists. Source nodes are only + // checked against the written list. Two nodes overlap in memory if they come from the same + // buffer and the tensor or view ranges overlap. + auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector &unsynced_nodes) -> bool { + if (unsynced_nodes.size() == 0) { + return false; + } + auto n_base = vk_tensor_offset(node) + node->view_offs; + auto n_size = ggml_nbytes(node); + ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context; + vk_buffer a_buf = a_buf_ctx->dev_buffer; + for (auto &other : unsynced_nodes) { + ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context; + vk_buffer o_buf = o_buf_ctx->dev_buffer; + if (a_buf == o_buf) { + auto o_base = vk_tensor_offset(other) + other->view_offs; + auto o_size = ggml_nbytes(other); + + if ((o_base <= n_base && n_base < o_base + o_size) || + (n_base <= o_base && o_base < n_base + n_size)) { + return true; + } + } + } + return false; + }; + + // For all fused ops, check if the destination node or any of the source + // nodes require synchronization. + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; + } + if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } + } + } + if (need_sync) { + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ggml_vk_sync_buffers(ctx, compute_ctx); + } + // Add the last fused node and all fused source nodes to the unsynchronized list. + const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + ctx->unsynced_nodes_written.push_back(last_node); + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; + } + ctx->unsynced_nodes_read.push_back(cur_node->src[j]); + } + } + } + + switch (node->op) { + case GGML_OP_REPEAT: + ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_REPEAT_BACK: + ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ACC: + ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_GET_ROWS: + ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ADD: + if (ctx->num_additional_fused_ops) { + ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx, dryrun); + } else { + ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); + } + break; + case GGML_OP_SUB: + ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_MUL: + ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_DIV: + ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ADD_ID: + ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + case GGML_OP_CONCAT: + ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_UPSCALE: + ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SCALE: + ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SQR: + ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SQRT: + ggml_vk_sqrt(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SIN: + ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COS: + ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CLAMP: + ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_PAD: + ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ROLL: + ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SET_ROWS: + ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_SILU_BACK: + ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_NORM: + ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_GROUP_NORM: + ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_RMS_NORM: + if (ctx->num_additional_fused_ops > 0) { + // fused rms_norm + mul + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0]; + ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun); + } else { + ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun); + } + break; + case GGML_OP_RMS_NORM_BACK: + ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_L2_NORM: + ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: + ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); + break; + default: + return false; + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun); + break; + default: + return false; + } + break; + case GGML_OP_DIAG_MASK_INF: + ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SOFT_MAX: + ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + case GGML_OP_SOFT_MAX_BACK: + ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ROPE: + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun); + + break; + case GGML_OP_ROPE_BACK: + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun); + + break; + case GGML_OP_ARGSORT: + ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SUM: + ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SUM_ROWS: + ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_MEAN: + ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ARGMAX: + ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COUNT_EQUAL: + ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_IM2COL: + ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_IM2COL_3D: + ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_TIMESTEP_EMBEDDING: + ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_POOL_2D: + ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CONV_2D: + ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_CONV_TRANSPOSE_2D: + ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_CONV_2D_DW: + ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_LEAKY_RELU: + ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_MUL_MAT: + ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_MUL_MAT_ID: + ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + + case GGML_OP_FLASH_ATTN_EXT: + ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun); + + break; + + case GGML_OP_RWKV_WKV6: + ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + + break; + + case GGML_OP_RWKV_WKV7: + ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun); + + break; + + case GGML_OP_OPT_STEP_ADAMW: + ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); + + break; + + case GGML_OP_OPT_STEP_SGD: + ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + default: + return false; + } + + if (dryrun) { + return false; + } + + ctx->tensor_ctxs[node_idx] = compute_ctx; + +#if defined(GGML_VULKAN_CHECK_RESULTS) + // Force context reset on each node so that each tensor ends up in its own context + // and can be run and compared to its CPU equivalent separately + last_node = true; +#endif + + if (submit || last_node) { + ggml_vk_ctx_end(compute_ctx); + + // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward + if (last_node) { + compute_ctx->exit_tensor_idx = node_idx_begin; + } + else { + compute_ctx->exit_tensor_idx = -1; + } + + ctx->compute_ctx.reset(); + + bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready); + if (!ok) { + if (node->op == GGML_OP_UNARY) { + std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } else if (node->op == GGML_OP_GLU) { + std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } else { + std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; + } + } + + } + return true; +} + +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { + GGML_UNUSED(cgraph); + ggml_backend_buffer * buf = nullptr; + + switch (tensor->op) { + case GGML_OP_ADD: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_ADD_ID: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_ROLL: + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_SILU_BACK: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + case GGML_OP_LEAKY_RELU: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + buf = tensor->buffer; + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: + buf = tensor->buffer; + break; + default: + return false; + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + buf = tensor->buffer; + break; + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_FLASH_ATTN_EXT: + buf = tensor->buffer; + + break; + default: + return false; + } + + if (buf == nullptr) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); + + vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); + + // always wait for the GPU work to be done for the last submit + if (tensor_idx == subctx->exit_tensor_idx) { + use_fence = true; + } + + // Only run if ctx hasn't been submitted yet + if (!subctx->seqs.empty()) { +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_0(ctx, cgraph, tensor_idx); + use_fence = true; +#endif + + // Do staging buffer copies + for (auto& cpy : subctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + for (auto& mset : subctx->memsets) { + memset(mset.dst, mset.val, mset.n); + } + + if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) { + ggml_vk_submit(subctx, ctx->almost_ready_fence); + ctx->almost_ready_fence_pending = true; + } else { + ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); + } + + if (use_fence) { + ggml_vk_wait_for_fence(ctx); + } +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_1(ctx, cgraph, tensor_idx); +#endif + } + + if (tensor_idx == subctx->exit_tensor_idx) { + // Do staging buffer copies + for (auto& cpy : subctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + subctx->in_memcpys.clear(); + subctx->out_memcpys.clear(); + subctx->memsets.clear(); + } + + return true; +} + +// Clean up after graph processing is done +static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); + for (auto& buffer : ctx->gc.temp_buffers) { + ggml_vk_pool_free(ctx, buffer); + } + ctx->gc.temp_buffers.clear(); + ctx->prealloc_y_last_pipeline_used = {}; + + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; + + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); + + for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { + ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); + } + ctx->gc.semaphores.clear(); + + for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) { + ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s }); + } + ctx->gc.tl_semaphores.clear(); + ctx->semaphore_idx = 0; + + ctx->event_idx = 0; + + for (auto& event : ctx->gc.events) { + ctx->device->device.resetEvent(event); + } + + ctx->tensor_ctxs.clear(); + ctx->gc.contexts.clear(); + ctx->pipeline_descriptor_set_requirements = 0; + ctx->descriptor_set_idx = 0; +} + +// Clean up on backend free +static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); + ggml_vk_graph_cleanup(ctx); + + ggml_vk_destroy_buffer(ctx->prealloc_x); + ggml_vk_destroy_buffer(ctx->prealloc_y); + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + ctx->prealloc_y_last_pipeline_used = nullptr; + + for (auto& buffer : ctx->buffer_pool) { + ggml_vk_destroy_buffer(buffer); + } + + ctx->prealloc_size_x = 0; + ctx->prealloc_size_y = 0; + ctx->prealloc_size_split_k = 0; + + for (auto& event : ctx->gc.events) { + ctx->device->device.destroyEvent(event); + } + ctx->gc.events.clear(); + + ctx->device->device.destroyFence(ctx->fence); + ctx->device->device.destroyFence(ctx->almost_ready_fence); + + for (auto& pool : ctx->descriptor_pools) { + ctx->device->device.destroyDescriptorPool(pool); + } + ctx->descriptor_pools.clear(); + ctx->descriptor_sets.clear(); + + ctx->compute_cmd_pool.destroy(ctx->device->device); + ctx->transfer_cmd_pool.destroy(ctx->device->device); +} + +static int ggml_vk_get_device_count() { + ggml_vk_instance_init(); + + return vk_instance.device_indices.size(); +} + +static void ggml_vk_get_device_description(int device, char * description, size_t description_size) { + ggml_vk_instance_init(); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + vk::PhysicalDeviceProperties props; + devices[device].getProperties(&props); + + snprintf(description, description_size, "%s", props.deviceName.data()); +} + +static std::string ggml_vk_get_device_id(int device) { + ggml_vk_instance_init(); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + vk::PhysicalDeviceProperties2 props; + vk::PhysicalDeviceIDProperties deviceIDProps; + props.pNext = &deviceIDProps; + devices[device].getProperties2(&props); + + const auto& uuid = deviceIDProps.deviceUUID; + char id[64]; + snprintf(id, sizeof(id), + "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + uuid[0], uuid[1], uuid[2], uuid[3], + uuid[4], uuid[5], + uuid[6], uuid[7], + uuid[8], uuid[9], + uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15] + ); + return std::string(id); +} + +// backend interface + +#define UNUSED GGML_UNUSED + +// device backend + +static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name; +} + +static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { + VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()"); + ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + ggml_vk_destroy_buffer(ctx->dev_buffer); + delete ctx; + delete buffer; +} + +static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { + return vk_ptr_base; + + UNUSED(buffer); +} + +static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + } + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + uint32_t val32 = (uint32_t)value * 0x01010101; + ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size); +} + +static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + if (ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + + return true; + } + return false; + + UNUSED(buffer); +} + +static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size); +} + +static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { + /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, + /* .get_base = */ ggml_backend_vk_buffer_get_base, + /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, + /* .clear = */ ggml_backend_vk_buffer_clear, + /* .reset = */ NULL, +}; + +// vk buffer type +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")"); + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + + vk_buffer dev_buffer = nullptr; + try { + dev_buffer = ggml_vk_create_buffer_device(ctx->device, size); + } catch (const vk::SystemError& e) { + return nullptr; + } + + ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name); + + return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size); +} + +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + return ctx->device->properties.limits.minStorageBufferOffsetAlignment; +} + +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + return ctx->device->suballocation_block_size; +} + +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_nbytes(tensor); + + UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { + ggml_vk_instance_init(); + + VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")"); + + vk_device dev = ggml_vk_get_device(dev_num); + + return &dev->buffer_type; +} + +// host buffer type + +static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) { + return GGML_VK_NAME "_Host"; + + UNUSED(buft); +} + +static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { + return GGML_VK_NAME "_Host"; + + UNUSED(buffer); +} + +static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); + ggml_vk_host_free(vk_instance.devices[0], buffer->context); + delete buffer; +} + +static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")"); + + size += 32; // Behave like the CPU buffer type + void * ptr = nullptr; + try { + ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what()); + // fallback to cpu buffer + return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + } + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer; + + return buffer; + + UNUSED(buft); +} + +static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment; + + UNUSED(buft); +} + +static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + return vk_instance.devices[0]->suballocation_block_size; + + UNUSED(buft); +} + +// Should be changed to return device-specific host buffer type +// but that probably requires changes in llama.cpp +ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = { + /* .iface = */ { + /* .get_name = */ ggml_backend_vk_host_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_vk_host_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0), + /* .context = */ nullptr, + }; + + // Make sure device 0 is initialized + ggml_vk_instance_init(); + ggml_vk_get_device(0); + + return &ggml_backend_vk_buffer_type_host; +} + + +// backend + +static const char * ggml_backend_vk_name(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return ctx->name.c_str(); +} + +static void ggml_backend_vk_free(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")"); + + ggml_vk_cleanup(ctx); + + delete ctx; + delete backend; +} + +static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return &ctx->device->buffer_type; +} + +static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + return true; + } + + return false; +} + +static void ggml_backend_vk_synchronize(ggml_backend_t backend) { + VK_LOG_DEBUG("ggml_backend_vk_synchronize()"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if(ctx->transfer_ctx.expired()) { + return; + } + + vk_context transfer_ctx = ctx->transfer_ctx.lock(); + + ggml_vk_ctx_end(transfer_ctx); + + for (auto& cpy : transfer_ctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ggml_vk_submit(transfer_ctx, ctx->fence); + ggml_vk_wait_for_fence(ctx); + + for (auto& cpy : transfer_ctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ctx->transfer_ctx.reset(); +} + +static bool ggml_vk_is_empty(ggml_tensor * node) { + return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; +} + +static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + // additional constraints specific to this fusion + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && + !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + return true; +} + +static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { + + const ggml_tensor *first_node = cgraph->nodes[node_idx]; + if (first_node->op != GGML_OP_ADD) { + return 0; + } + + if (!ctx->device->multi_add) { + return 0; + } + + int32_t num_adds = 1; + while (node_idx + num_adds < cgraph->n_nodes && + cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD && + num_adds < MAX_FUSED_ADDS) { + num_adds++; + } + + // The shader currently requires same shapes (but different strides are allowed), + // everything f32, and no misalignment + for (int32_t i = 0; i < num_adds; ++i) { + const ggml_tensor *next_node = cgraph->nodes[node_idx + i]; + if (!ggml_are_same_shape(first_node, next_node->src[0]) || + !ggml_are_same_shape(first_node, next_node->src[1]) || + next_node->type != GGML_TYPE_F32 || + next_node->src[0]->type != GGML_TYPE_F32 || + next_node->src[1]->type != GGML_TYPE_F32 || + get_misalign_bytes(ctx, next_node) || + get_misalign_bytes(ctx, next_node->src[0]) || + get_misalign_bytes(ctx, next_node->src[1])) { + num_adds = i; + } + } + + // Verify we can fuse these + ggml_op adds[MAX_FUSED_ADDS]; + for (int32_t i = 0; i < num_adds; ++i) { + adds[i] = GGML_OP_ADD; + } + + // decrease num_adds if they can't all be fused + while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) { + num_adds--; + } + + // a single add is not "fused", so just return zero + if (num_adds == 1) { + return 0; + } + return num_adds; +} + +static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + if (vk_instance.debug_utils_support) { + vk::DebugUtilsLabelEXT dul = {}; + dul.pLabelName = "ggml_backend_vk_graph_compute"; + dul.color = std::array{1.0f, 1.0f, 1.0f, 1.0f}; + vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast(&dul)); + } + + ctx->prealloc_size_add_rms_partials = 0; + ctx->prealloc_size_add_rms_partials_offset = 0; + ctx->do_add_rms_partials = false; + + uint64_t total_mat_mul_bytes = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + if (!ctx->device->disable_fusion) { + uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); + if (num_adds) { + ctx->num_additional_fused_ops = num_adds - 1; + } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } + } + ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); + if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { + total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D || cgraph->nodes[i]->op == GGML_OP_CONV_TRANSPOSE_2D) { + // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode. + auto CRS_size = + cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[1]->ne[2]; + auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3]; + total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type); + } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; + } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + ggml_vk_preallocate_buffers(ctx); + ggml_pipeline_allocate_descriptor_sets(ctx); + + int last_node = cgraph->n_nodes - 1; + + // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly + while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { + last_node -= 1; + } + + // Reserve tensor context space for all nodes + ctx->tensor_ctxs.resize(cgraph->n_nodes); + + bool first_node_in_batch = true; // true if next node will be first node in a batch + int submit_node_idx = 0; // index to first node in a batch + + vk_context compute_ctx; + if (vk_perf_logger_enabled) { + // allocate/resize the query pool + if (ctx->device->num_queries < cgraph->n_nodes + 1) { + if (ctx->device->query_pool) { + ctx->device->device.destroyQueryPool(ctx->device->query_pool); + } + vk::QueryPoolCreateInfo query_create_info; + query_create_info.queryType = vk::QueryType::eTimestamp; + query_create_info.queryCount = cgraph->n_nodes + 100; + ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info); + ctx->device->num_queries = query_create_info.queryCount; + } + + ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1); + + GGML_ASSERT(ctx->compute_ctx.expired()); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0); + } + + ctx->prealloc_y_last_pipeline_used = nullptr; + ctx->prealloc_y_last_tensor_used = nullptr; + + if (ctx->prealloc_size_add_rms_partials) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + // initialize partial sums to zero. + ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials); + ggml_vk_sync_buffers(ctx, compute_ctx); + } + + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. + // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB + // (and scaled down based on model size, so smaller models submit earlier). + // Also submit at least every 100 nodes, in case there are workloads without as much matmul. + int nodes_per_submit = 100; + int submitted_nodes = 0; + int submit_count = 0; + uint64_t mul_mat_bytes = 0; + uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u); + for (int i = 0; i < cgraph->n_nodes; i++) { + if (first_node_in_batch) { + submit_node_idx = i; + } + + if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { + mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } + + if (!ctx->device->disable_fusion) { + uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); + if (num_adds) { + ctx->num_additional_fused_ops = num_adds - 1; + } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } + } + + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) + bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; + bool submit = (submitted_nodes >= nodes_per_submit) || + (mul_mat_bytes >= mul_mat_bytes_per_submit) || + (i + ctx->num_additional_fused_ops == last_node) || + (almost_ready && !ctx->almost_ready_fence_pending); + + bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit); + + if (vk_perf_logger_enabled) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple + for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) { + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1); + } + } + + if (enqueued) { + ++submitted_nodes; + +#ifndef GGML_VULKAN_CHECK_RESULTS + if (first_node_in_batch) { + first_node_in_batch = false; + } +#endif + } + + if (submit && enqueued) { + first_node_in_batch = true; + submitted_nodes = 0; + mul_mat_bytes = 0; + if (submit_count < 3) { + mul_mat_bytes_per_submit *= 2; + } + submit_count++; + } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; + } + + if (vk_perf_logger_enabled) { + // End the command buffer and submit/wait + GGML_ASSERT(!ctx->compute_ctx.expired()); + compute_ctx = ctx->compute_ctx.lock(); + ggml_vk_ctx_end(compute_ctx); + + ggml_vk_submit(compute_ctx, ctx->device->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); + ctx->device->device.resetFences({ ctx->device->fence }); + + // Get the results and pass them to the logger + std::vector timestamps(cgraph->n_nodes + 1); + VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); + for (int i = 0; i < cgraph->n_nodes; i++) { + if (!ggml_vk_is_empty(cgraph->nodes[i])) { + ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod)); + } + } + + ctx->device->perf_logger->print_timings(); + } + + ggml_vk_graph_cleanup(ctx); + + return GGML_STATUS_SUCCESS; + + UNUSED(backend); +} + +// Sort the graph for improved parallelism. +static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph) +{ + VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + if (ctx->device->disable_graph_optimize) { + return; + } + + auto const &is_empty = [](ggml_tensor * node) -> bool { + return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor *src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + // This function tries to reorder the graph to allow nodes to run in parallel. + // This helps with small batches, but for large batches its a slowdown, probably + // due to cache contention. So only reorder if the majority of nodes have few rows. + int num_small_nodes = 0; + int num_counted_nodes = 0; + for (int i = 0; i < graph->n_nodes; ++i) { + if (!is_empty(graph->nodes[i]) && + graph->nodes[i]->op != GGML_OP_SET_ROWS) { + if (ggml_nrows(graph->nodes[i]) <= 8) { + num_small_nodes++; + } + num_counted_nodes++; + } + } + if (num_small_nodes < num_counted_nodes / 2) { + return; + } + + std::vector new_order; + std::vector used(graph->n_nodes, false); + int first_unused = 0; + while (first_unused < graph->n_nodes) { + std::vector current_set; + + // First, grab the next unused node. + current_set.push_back(first_unused); + + // Loop through the next N nodes. Grab any that don't depend on other nodes that + // haven't already been run. Nodes that have already been run have used[i] set + // to true. Allow nodes that depend on the previous node if it's a fusion pattern + // that we support (e.g. RMS_NORM + MUL). + // This first pass only grabs "real" (non-view nodes). Second pass grabs view nodes. + // The goal is to not interleave real and view nodes in a way that breaks fusion. + const int NUM_TO_CHECK = 20; + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + // Second pass grabs view nodes. + // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add). + if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) { + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (!is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end(); + // skip views whose srcs haven't been processed. + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !c_in_current_set) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + } + + // Push the current set into new_order + for (auto c : current_set) { + new_order.push_back(graph->nodes[c]); + used[c] = true; + } + while (first_unused < graph->n_nodes && used[first_unused]) { + first_unused++; + } + } + // Replace the graph with the new order. + for (int i = 0; i < graph->n_nodes; ++i) { + graph->nodes[i] = new_order[i]; + } +} + +// TODO: enable async and synchronize +static ggml_backend_i ggml_backend_vk_interface = { + /* .get_name = */ ggml_backend_vk_name, + /* .free = */ ggml_backend_vk_free, + /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, + /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async, + /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, + /* .synchronize = */ NULL, // ggml_backend_vk_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_vk_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ ggml_vk_graph_optimize, +}; + +static ggml_guid_t ggml_backend_vk_guid() { + static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; + return &guid; +} + +ggml_backend_t ggml_backend_vk_init(size_t dev_num) { + VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")"); + + ggml_backend_vk_context * ctx = new ggml_backend_vk_context; + ggml_vk_init(ctx, dev_num); + + ggml_backend_t vk_backend = new ggml_backend { + /* .guid = */ ggml_backend_vk_guid(), + /* .iface = */ ggml_backend_vk_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), + /* .context = */ ctx, + }; + + return vk_backend; +} + +bool ggml_backend_is_vk(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid()); +} + +int ggml_backend_vk_get_device_count() { + return ggml_vk_get_device_count(); +} + +void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + int dev_idx = vk_instance.device_indices[device]; + ggml_vk_get_device_description(dev_idx, description, description_size); +} + +std::string ggml_backend_vk_get_device_id(int device) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + int dev_idx = vk_instance.device_indices[device]; + return ggml_vk_get_device_id(dev_idx); +} + +////////////////////////// + +struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; + bool is_integrated_gpu; + // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) + std::string pci_id; + std::string id; + std::string uuid; + int major; + int minor; + int driver_major; + int driver_minor; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; +}; + +void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { + GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); + GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size()); + + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; + + vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + vk::PhysicalDeviceProperties2 props2; + vkdev.getProperties2(&props2); + + if (!ctx->is_integrated_gpu) + { + // Use vendor specific management libraries for best VRAM reporting if available + switch (props2.properties.vendorID) { + case VK_VENDOR_ID_AMD: + if (ggml_hip_mgmt_init() == 0) { + int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_hip_mgmt_release(); + return; + } + ggml_hip_mgmt_release(); + } + break; + case VK_VENDOR_ID_NVIDIA: + if (ggml_nvml_init() == 0) { + int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_nvml_release(); + return; + } + ggml_nvml_release(); + } + break; + } + } + // else fallback to memory budget if supported + + *total = 0; + *free = 0; + vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; + vk::PhysicalDeviceMemoryProperties2 memprops2; + memprops2.pNext = &mem_budget_props; + vkdev.getMemoryProperties2(&memprops2); + for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { + if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total += memprops2.memoryProperties.memoryHeaps[i].size; + } else if (ctx->is_integrated_gpu) { + // Include shared memory on iGPUs + *total += memprops2.memoryProperties.memoryHeaps[i].size; + } + } + for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { + if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *free += mem_budget_props.heapBudget[i]; + } else if (ctx->is_integrated_gpu) { + *free += mem_budget_props.heapBudget[i]; + } + } + if (*total > 0 && *free > 0) { + return; + } else if (*total > 0) { + *free = *total; + return; + } + + // else just report the physical memory + for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { + if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total = heap.size; + *free = heap.size; + break; + } + } +} + +static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) { + GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; + + vk::PhysicalDeviceProperties2 props = {}; + device.getProperties2(&props); + + return props.properties.deviceType; +} + +static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { + GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; + + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool ext_support = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) { + ext_support = true; + break; + } + } + + if (!ext_support) { + return ""; + } + + vk::PhysicalDeviceProperties2 props = {}; + vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {}; + + props.pNext = &pci_bus_info; + + device.getProperties2(&props); + + const uint32_t pci_domain = pci_bus_info.pciDomain; + const uint32_t pci_bus = pci_bus_info.pciBus; + const uint32_t pci_device = pci_bus_info.pciDevice; + const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning + + char pci_bus_id[16] = {}; + snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); + + return std::string(pci_bus_id); +} + +static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { + if (id.empty()) return false; + unsigned int d = 0, b = 0, dev = 0, func = 0; + // Expected format: dddd:bb:dd.f (all hex) + int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); + if (n < 4) return false; + if (domain) *domain = (int) d; + if (bus) *bus = (int) b; + if (device) *device = (int) dev; + return true; +} + +static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->description.c_str(); +} + +static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->id.c_str(); +} + +static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; + ggml_backend_vk_get_device_memory(ctx, free, total); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) { + UNUSED(dev); + return ggml_backend_vk_host_buffer_type(); +} + +static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + + return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + + props->name = ggml_backend_vk_device_get_name(dev); + props->description = ggml_backend_vk_device_get_description(dev); + props->id = ggml_backend_vk_device_get_id(dev); + props->type = ggml_backend_vk_device_get_type(dev); + props->device_id = ctx->pci_id.empty() ? nullptr : ctx->pci_id.c_str(); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ true, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; + + props->compute_major = ctx->major; + props->compute_minor = ctx->minor; + props->driver_major = ctx->driver_major; + props->driver_minor = ctx->driver_minor; + props->integrated = ctx->is_integrated_gpu; + props->pci_bus_id = ctx->pci_bus_id; + props->pci_device_id = ctx->pci_device_id; + props->pci_domain_id = ctx->pci_domain_id; + props->library = GGML_VK_NAME; + props->numeric_id = ctx->id.empty() ? nullptr : ctx->id.c_str(); +} + +static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { + UNUSED(params); + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_init(ctx->device); +} + +static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); + default: + return false; + } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); + default: + return false; + } + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + ggml_type src0_type = op->src[0]->type; + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + if (op->op == GGML_OP_MUL_MAT_ID) { + if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { + // If there's not enough shared memory for row_ids and the result tile, fallback to CPU + return false; + } + } + switch (src0_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return false; + } + struct ggml_tensor * a; + struct ggml_tensor * b; + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != b->ne[3]) { + return false; + } + if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) || + !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { + return false; + } + if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) { + // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader. + // So don't support this combination for now. + return false; + } + + return true; + } + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + bool coopmat2 = device->coopmat2; + uint32_t HSK = op->src[1]->ne[0]; + uint32_t HSV = op->src[2]->ne[0]; + if ((HSK % 8) != 0 || (HSV % 8) != 0) { + return false; + } + if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) { + return false; + } + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + if (op->type != GGML_TYPE_F32) { + return false; + } + if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { + return false; + } + // It's straightforward to support different K/V dequant, but would + // significantly increase the number of pipelines + if (op->src[1]->type != op->src[2]->type) { + return false; + } + switch (op->src[1]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + // supported in scalar and coopmat2 paths + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently + //case GGML_TYPE_Q2_K: + //case GGML_TYPE_Q3_K: + //case GGML_TYPE_Q4_K: + //case GGML_TYPE_Q5_K: + //case GGML_TYPE_Q6_K: + //case GGML_TYPE_IQ1_S: + //case GGML_TYPE_IQ1_M: + //case GGML_TYPE_IQ2_XXS: + //case GGML_TYPE_IQ2_XS: + //case GGML_TYPE_IQ2_S: + //case GGML_TYPE_IQ3_XXS: + //case GGML_TYPE_IQ3_S: + //case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + // currently supported only in coopmat2 path + if (!coopmat2) { + return false; + } + break; + default: + return false; + } + if (!coopmat2 && !device->subgroup_shuffle) { + // scalar FA uses subgroupShuffle + return false; + } + return true; + } + case GGML_OP_GET_ROWS: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + return true; + default: + return false; + } + } + case GGML_OP_SET_ROWS: + { + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + } + } + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_DUP: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type; + + if (src0_type == GGML_TYPE_F32) { + switch (src1_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } + } + if (src1_type == GGML_TYPE_F32) { + switch (src0_type) { + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } + } + + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return true; + } + + if ( + (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) || + (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) + ) { + return true; + } + + // We can handle copying from a type to the same type if it's + // contiguous (memcpy). We use f16 or f32 shaders to do the copy, + // so the type/block size must be a multiple of 4. + if (src0_type == src1_type && + ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) && + (ggml_type_size(src0_type) % 2) == 0) { + return true; + } + return false; + } + case GGML_OP_REPEAT: + return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + case GGML_OP_REPEAT_BACK: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_RMS_NORM: + return true; + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_L2_NORM: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); + case GGML_OP_ADD_ID: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 && + op->type == GGML_TYPE_F32; + case GGML_OP_SILU_BACK: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_LEAKY_RELU: + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ARGSORT: + return op->ne[0] <= max_argsort_cols; + case GGML_OP_UPSCALE: + case GGML_OP_ACC: + case GGML_OP_CONCAT: + case GGML_OP_SCALE: + case GGML_OP_PAD: + case GGML_OP_ROLL: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + return true; + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_2D_DW: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + return true; + case GGML_OP_CONV_TRANSPOSE_1D: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: + { + // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + if (op->op == GGML_OP_CONV_TRANSPOSE_2D && + device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) { + return false; + } + // Channel-contiguous format is not supported yet. + return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op)); + } + default: + return false; + } + + UNUSED(dev); +} + +static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) { + return false; + } + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return buft_ctx->device->idx == ctx->device; +} + +static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + const int min_batch_size = 32; + + return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + + UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_vk_device_i = { + /* .get_name = */ ggml_backend_vk_device_get_name, + /* .get_description = */ ggml_backend_vk_device_get_description, + /* .get_memory = */ ggml_backend_vk_device_get_memory, + /* .get_type = */ ggml_backend_vk_device_get_type, + /* .get_props = */ ggml_backend_vk_device_get_props, + /* .init_backend = */ ggml_backend_vk_device_init, + /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_vk_device_supports_op, + /* .supports_buft = */ ggml_backend_vk_device_supports_buft, + /* .offload_op = */ ggml_backend_vk_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) { + UNUSED(reg); + return GGML_VK_NAME; +} + +static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) { + UNUSED(reg); + return ggml_backend_vk_get_device_count(); +} + +static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) { + static std::vector devices; + + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + std::vector vk_devices = vk_instance.instance.enumeratePhysicalDevices(); + + for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { + ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; + char desc[256]; + ggml_backend_vk_get_device_description(i, desc, sizeof(desc)); + ctx->device = i; + ctx->name = GGML_VK_NAME + std::to_string(i); + ctx->description = desc; + ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; + ctx->pci_id = ggml_backend_vk_get_device_pci_id(i); + ctx->id = ggml_backend_vk_get_device_id(i); + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, + /* .context = */ ctx, + }); + + // Gather additional information about the device + int dev_idx = vk_instance.device_indices[i]; + vk::PhysicalDeviceProperties props1; + vk_devices[dev_idx].getProperties(&props1); + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceIDProperties device_id_props; + vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props; + vk::PhysicalDeviceDriverProperties driver_props; + props2.pNext = &device_id_props; + device_id_props.pNext = &pci_bus_props; + pci_bus_props.pNext = &driver_props; + vk_devices[dev_idx].getProperties2(&props2); + std::ostringstream oss; + oss << std::hex << std::setfill('0'); + oss << "GPU-"; + int byteIdx = 0; + for (int i = 0; i < 16; ++i, ++byteIdx) { + oss << std::setw(2) << static_cast(device_id_props.deviceUUID[i]); + if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) { + oss << '-'; + } + } + ctx->uuid = oss.str(); + ctx->pci_bus_id = pci_bus_props.pciBus; + ctx->pci_device_id = pci_bus_props.pciDevice; + ctx->pci_domain_id = pci_bus_props.pciDomain; + ctx->id = std::to_string(i); + ctx->major = 0; + ctx->minor = 0; + // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string + ctx->driver_major = 0; + ctx->driver_minor = 0; + } + initialized = true; + } + } + + GGML_ASSERT(device < devices.size()); + return devices[device]; +} + +static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = { + /* .get_name = */ ggml_backend_vk_reg_get_name, + /* .get_device_count = */ ggml_backend_vk_reg_get_device_count, + /* .get_device = */ ggml_backend_vk_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_vk_reg() { + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_vk_reg_i, + /* .context = */ nullptr, + }; + try { + ggml_vk_instance_init(); + return ® + } catch (const vk::SystemError& e) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what()); + return nullptr; + } catch (const std::exception &e) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: " << e.what()); + return nullptr; + } catch (...) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: unknown exception during Vulkan init"); + return nullptr; + } +} + +// Extension availability +static bool ggml_vk_instance_validation_ext_available() { +#ifdef GGML_VULKAN_VALIDATE + // Check if validation layer provides the extension + const std::string layer_name = "VK_LAYER_KHRONOS_validation"; + for (const auto& layer : vk::enumerateInstanceLayerProperties()) { + if (layer_name == layer.layerName.data()) { + for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) { + if (strcmp("VK_EXT_validation_features", ext.extensionName.data()) == 0) { + return true; + } + } + } + } + + std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl; +#endif + return false; +} +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions) { +#ifdef __APPLE__ + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + return true; + } + } + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; +#endif + return false; + + UNUSED(instance_extensions); +} + +// Extension availability +static bool ggml_vk_instance_debug_utils_ext_available( + const std::vector & instance_extensions) { + // Check for portability enumeration extension for MoltenVK support + for (const auto & properties : instance_extensions) { + if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) { + return true; + } + } + + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl; + return false; + + UNUSED(instance_extensions); +} + +static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) { + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + vkGetPhysicalDeviceFeatures2(vkdev, &device_features2); + + return vk11_features.storageBuffer16BitAccess; +} + +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { + switch (props.vendorID) { + case VK_VENDOR_ID_INTEL: + // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost, + // while some older hardware (ex. Arc A770) has performance regressions + return arch == vk_device_architecture::INTEL_XE2; + case VK_VENDOR_ID_AMD: + if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { + // Workaround for AMD proprietary driver reporting support on all GPUs + return arch == vk_device_architecture::AMD_RDNA3; + } + return true; + default: + return true; + } +} + +// checks + +#ifdef GGML_VULKAN_CHECK_RESULTS +static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector& done, int level = 0) { + if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) { + return; + } + for (int j = 0; j < level; j++) { + std::cerr << " "; + } + std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl; + + done.push_back(tensor); + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] != nullptr) { + ggml_vk_print_graph_origin(tensor->src[i], done, level + 1); + } + } +} + +static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) { + if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + i3 = std::max(i3, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_I32) { + val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) { + void * tensor_data = tensor->data; + + const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer); + + if (is_gpu) { + const size_t tensor_size = ggml_nbytes(tensor); + tensor_data = malloc(tensor_size); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_buffer buffer_gpu = buf_ctx->dev_buffer; + ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size); + } + + std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; + std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl; + if (tensor->src[0] != nullptr) { + std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl; + } + if (tensor->src[1] != nullptr) { + std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl; + } + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + + if (is_gpu) { + free(tensor_data); + } +} + +void * comp_result; +size_t comp_size; +size_t comp_nb[GGML_MAX_DIMS]; +size_t check_counter = 0; +static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; + if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { + return; + } + + bool fused_rms_norm_mul = false; + int rms_norm_idx = -1; + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + fused_rms_norm_mul = true; + tensor = cgraph->nodes[tensor_idx + 1]; + } + + check_counter++; + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { + return; + } + + VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")"); + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + + struct ggml_init_params iparams = { + /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ggml_ctx = ggml_init(iparams); + + std::array src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + std::array src_size = {0, 0, 0, 0, 0, 0}; + std::array src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"}; + + struct ggml_tensor * tensor_clone = nullptr; + + for (int i = 0; i < 6; i++) { + ggml_tensor * srci = tensor->src[i]; + if (fused_rms_norm_mul) { + rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1; + ggml_tensor *rms_norm = tensor->src[rms_norm_idx]; + switch (i) { + case 0: srci = rms_norm->src[0]; break; + case 1: srci = tensor->src[1 - rms_norm_idx]; break; + default: continue; + } + } + if (srci == nullptr) { + continue; + } + ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci); + size_t srci_size = ggml_nbytes(srci); + + src_clone[i] = srci_clone; + src_size[i] = ggml_nbytes(srci); + src_buffer[i] = malloc(srci_size); + + srci_clone->data = src_buffer[i]; + if (ggml_backend_buffer_is_host(srci->buffer)) { + memcpy(srci_clone->data, srci->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(srci->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(srci) + srci->view_offs; + if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) { + for (int i3 = 0; i3 < srci->ne[3]; i3++) { + for (int i2 = 0; i2 < srci->ne[2]; i2++) { + const int idx = i3*srci->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]); + } + } + + srci_clone->nb[0] = srci->nb[0]; + srci_clone->nb[1] = srci->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1]; + } + } else { + if (offset + srci_size >= buffer_gpu->size) { + srci_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(srci, srci_name[i]); + } + } + + if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); + if (src_clone[4]) { + ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]); + } + } else if (tensor->op == GGML_OP_MUL_MAT) { + tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_MUL_MAT_ID) { + tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } else if (tensor->op == GGML_OP_SUB) { + tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_MUL) { + if (fused_rms_norm_mul) { + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params); + tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]); + } else { + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + } + } else if (tensor->op == GGML_OP_DIV) { + tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_CONCAT) { + tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_UPSCALE) { + tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); + } else if (tensor->op == GGML_OP_SCALE) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); + } else if (tensor->op == GGML_OP_SQR) { + tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SQRT) { + tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SIN) { + tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COS) { + tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_CLAMP) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); + } else if (tensor->op == GGML_OP_PAD) { + tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3], + tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]); + } else if (tensor->op == GGML_OP_REPEAT) { + tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_REPEAT_BACK) { + tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_ADD) { + tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ACC) { + tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + } else if (tensor->op == GGML_OP_NORM) { + tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_GROUP_NORM) { + const float * float_params = (const float *)tensor->op_params; + tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]); + } else if (tensor->op == GGML_OP_RMS_NORM) { + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); + } else if (tensor->op == GGML_OP_SILU_BACK) { + tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_L2_NORM) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); + } else if (tensor->op == GGML_OP_SOFT_MAX) { + if (src1 != nullptr) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]); + } else { + tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); + } + } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { + tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]); + } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; + const float freq_base = ((float *) tensor->op_params)[5]; + const float freq_scale = ((float *) tensor->op_params)[6]; + const float ext_factor = ((float *) tensor->op_params)[7]; + const float attn_factor = ((float *) tensor->op_params)[8]; + const float beta_fast = ((float *) tensor->op_params)[9]; + const float beta_slow = ((float *) tensor->op_params)[10]; + if (mode & GGML_ROPE_TYPE_MROPE) { + int32_t *sections = ((int32_t *) tensor->op_params) + 11; + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + } else { + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + } + } else if (tensor->op == GGML_OP_UNARY) { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_EXP: + tensor_clone = ggml_exp(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_SILU: + tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU: + tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU_ERF: + tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU_QUICK: + tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_RELU: + tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_TANH: + tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_SIGMOID: + tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_HARDSIGMOID: + tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_HARDSWISH: + tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]); + break; + default: + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + } else if (tensor->op == GGML_OP_GLU) { + if (src_clone[1] == nullptr) { + tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); + } else { + tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]); + } + ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2)); + ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3)); + } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { + if (src1 == nullptr) { + tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); + tensor_clone->type = tensor->type; + } else { + tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); + } + } else if (tensor->op == GGML_OP_CONT) { + tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_RESHAPE) { + tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_VIEW) { + tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); + } else if (tensor->op == GGML_OP_PERMUTE) { + int32_t * params = (int32_t *)tensor->op_params; + tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]); + } else if (tensor->op == GGML_OP_TRANSPOSE) { + tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_GET_ROWS) { + tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ARGSORT) { + tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_SUM) { + tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SUM_ROWS) { + tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_MEAN) { + tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_ARGMAX) { + tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COUNT_EQUAL) { + tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_IM2COL) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + + const bool is_2D = tensor->op_params[6] == 1; + tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_IM2COL_3D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s2 = tensor->op_params[2]; + const int32_t p0 = tensor->op_params[3]; + const int32_t p1 = tensor->op_params[4]; + const int32_t p2 = tensor->op_params[5]; + const int32_t d0 = tensor->op_params[6]; + const int32_t d1 = tensor->op_params[7]; + const int32_t d2 = tensor->op_params[8]; + const int32_t IC = tensor->op_params[9]; + + tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type); + } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { + const int32_t dim = tensor->op_params[0]; + const int32_t max_period = tensor->op_params[1]; + tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){ + const int32_t s0 = tensor->op_params[0]; + const int32_t p0 = tensor->op_params[1]; + const int32_t d0 = tensor->op_params[2]; + tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0); + } else if (tensor->op == GGML_OP_POOL_2D) { + enum ggml_op_pool op = static_cast(tensor->op_params[0]); + const int32_t k0 = tensor->op_params[1]; + const int32_t k1 = tensor->op_params[2]; + const int32_t s0 = tensor->op_params[3]; + const int32_t s1 = tensor->op_params[4]; + const int32_t p0 = tensor->op_params[5]; + const int32_t p1 = tensor->op_params[6]; + + tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); + } else if (tensor->op == GGML_OP_CONV_2D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) { + const int32_t s = tensor->op_params[0]; + tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s); + } else if (tensor->op == GGML_OP_LEAKY_RELU) { + const float * op_params = (const float *)tensor->op_params; + tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); + } else if (tensor->op == GGML_OP_RWKV_WKV6) { + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + } else if (tensor->op == GGML_OP_RWKV_WKV7) { + tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], + src_clone[4], src_clone[5], src_clone[6]); + } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { + src_clone[0]->flags = src0->flags; + tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4]); + } else if (tensor->op == GGML_OP_OPT_STEP_SGD) { + src_clone[0]->flags = src0->flags; + tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2]); + } else if (tensor->op == GGML_OP_ADD_ID) { + tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } + else { + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + + ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph_cpu, tensor_clone); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8); + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(tensor_clone, "tensor_clone"); + } + + comp_size = ggml_nbytes(tensor_clone); + + comp_result = malloc(comp_size); + memcpy(comp_result, tensor_clone->data, comp_size); + memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); + + for (int i = 0; i < 6; i++) { + if (src_buffer[i] != nullptr) { + free(src_buffer[i]); + } + } + + ggml_free(ggml_ctx); + + VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); +} + +static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; + if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { + return; + } + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + tensor = cgraph->nodes[tensor_idx + 1]; + } + + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { + return; + } + + VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")"); + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + ggml_tensor * src2 = tensor->src[2]; + ggml_tensor * src3 = tensor->src[3]; + + void * tensor_data = tensor->data; + + if (ggml_backend_buffer_is_vk(tensor->buffer)) { + size_t tensor_size = ggml_nbytes(tensor); + tensor_data = malloc(tensor_size); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs; + if (offset + tensor_size >= buffer_gpu->size) { + tensor_size = buffer_gpu->size - offset; + } + + ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size); + } + + float first_error_result = -1.0f; + float first_error_correct = -1.0f; + std::array first_error = { -1, -1, -1, -1 }; + double avg_err = 0.0; + size_t counter = 0; + + for (int i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size; + float correct = 0.0f; + float result = 0.0f; + + if (buffer_size_fit) { + if (tensor->type == GGML_TYPE_F32) { + correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); + result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_BF16) { + correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); + result = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_I32) { + correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_I64) { + correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else { + std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; + } + } else { + std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl; + GGML_ABORT("fatal error"); + } + + if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) { + std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } + const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f; + if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) { + first_error[0] = i0; + first_error[1] = i1; + first_error[2] = i2; + first_error[3] = i3; + first_error_result = result; + first_error_correct = correct; + } + + // Special case, value is infinite, avoid NaN result in avg_err + // NaN also appears in results, if both are nan error is 0 + if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) { + avg_err += std::fabs(correct - result) / denom; + } + counter++; + } + } + } + } + + avg_err /= counter; + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + } + + if (avg_err > 0.5 || std::isnan(avg_err)) { + std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } else { + std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl; + } + + free(comp_result); + comp_result = nullptr; + comp_size = 0; + + if (ggml_backend_buffer_is_vk(tensor->buffer)) { + free(tensor_data); + } + + VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")"); +} +#endif + +GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt new file mode 100644 index 00000000..e1f613fb --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -0,0 +1,31 @@ +cmake_minimum_required(VERSION 3.19) +project("vulkan-shaders-gen" C CXX) + +find_package (Threads REQUIRED) + +if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat glslc support") +endif() +if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 glslc support") +endif() +if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + message(STATUS "Enabling dot glslc support") +endif() +if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + message(STATUS "Enabling bfloat16 glslc support") +endif() +if (GGML_VULKAN_SHADER_DEBUG_INFO) + add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + message(STATUS "Enabling shader debug info") +endif() + +set(TARGET vulkan-shaders-gen) +add_executable(${TARGET} vulkan-shaders-gen.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_compile_features(${TARGET} PRIVATE cxx_std_17) +target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp new file mode 100644 index 00000000..5084a70e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + const uint offset = p.param3; + const uint src1_i = idx - offset; + const uint oz = src1_i / p.nb02; + const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; + const uint ox = src1_i % p.nb01; + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + } else { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); + } +} + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp new file mode 100644 index 00000000..3bcfe690 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -0,0 +1,69 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#if ADD_RMS +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#include "types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];}; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +#if ADD_RMS +// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant +shared FLOAT_TYPE sumsh[num_threads]; +#endif + +void main() { + uint idx = get_idx(); + uint orig_idx = idx; + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + FLOAT_TYPE sum_sq = 0; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]); + sum_sq += sum*sum; + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); + + idx += num_threads; + } + +#if ADD_RMS + if (p.param3 != 0) { + // reduce the sum within each subgroup, then across subgroups + const uint NumSubgroups = num_threads / gl_SubgroupSize; + sum_sq = subgroupAdd(sum_sq); + if (gl_SubgroupInvocationID == 0) { + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { + if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) { + sum_sq += sumsh[gl_SubgroupID + s]; + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + } + + if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { + partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; + } + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp new file mode 100644 index 00000000..495249d5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + uint ne0; + uint ne1; + uint s01; + uint s02; + uint s11; + uint s21; +} p; + +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) readonly buffer Z {int32_t data_c[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i1 = gl_WorkGroupID.x; + const uint i2 = gl_WorkGroupID.y; + + const uint i11 = data_c[i1 + i2 * p.s21]; + + const uint s1 = p.ne0; + const uint s2 = p.ne0 * p.ne1; + + const uint d0 = i1 * s1 + i2 * s2; + const uint a0 = i1 * p.s01 + i2 * p.s02; + const uint b0 = i11 * p.s11; + + for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) { + data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp new file mode 100644 index 00000000..7c128776 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -0,0 +1,60 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +#define FLT_MAX 3.402823466e+38F + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmpmax[BLOCK_SIZE]; +shared uint tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + if (row >= p.KY) { + return; + } + + A_TYPE amax = -FLT_MAX; + uint acol = col; + + if (col < p.KX) { + amax = data_a[row*p.KX + col]; + } + + for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) { + A_TYPE val = data_a[row*p.KX + i]; + if (val > amax) { + amax = val; + acol = i; + } + } + + tmp[col] = acol; + tmpmax[col] = amax; + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s && col + s < p.KX) { + if (tmpmax[col] < tmpmax[col + s]) { + tmpmax[col] = tmpmax[col + s]; + tmp[col] = tmp[col + s]; + } + } + barrier(); + } + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp new file mode 100644 index 00000000..c81b8445 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -0,0 +1,79 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; +#define ASC 0 + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint order; +} p; + +shared int dst_row[BLOCK_SIZE]; +shared A_TYPE a_sh[BLOCK_SIZE]; + +void swap(uint idx0, uint idx1) { + int tmp = dst_row[idx0]; + dst_row[idx0] = dst_row[idx1]; + dst_row[idx1] = tmp; +} + +void argsort(bool needs_bounds_check) { + // bitonic sort + const int col = int(gl_LocalInvocationID.x); + const uint row = gl_WorkGroupID.y; + + const uint row_offset = row * p.ncols; + + // initialize indices + dst_row[col] = col; + a_sh[col] = data_a[row_offset + col]; + barrier(); + + uint num_outer_loop_iters = BLOCK_SIZE_LOG2; + [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { + uint num_inner_loop_iters = outer_idx + 1; + [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { + const int ixj = int(col ^ j); + + int idx_0 = (col & k) == 0 ? col : ixj; + int idx_1 = (col & k) == 0 ? ixj : col; + + int sh_idx_0 = dst_row[idx_0]; + int sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) { + swap(idx_0, idx_1); + } + + barrier(); + } + } + + if (col < p.ncols) { + if (p.order == ASC) { + data_d[row_offset + col] = dst_row[col]; + } else { + data_d[row_offset + p.ncols - col - 1] = dst_row[col]; + } + } +} + +void main() { + if (p.ncols == BLOCK_SIZE) { + argsort(false); + } else { + argsort(true); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp new file mode 100644 index 00000000..65343189 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp new file mode 100644 index 00000000..e4046983 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + const int dim = p.param3; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne22*p.ne21*p.ne20); + const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20; + const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20); + const uint i2_offset = i2*p.ne21*p.ne20; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne20; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20; + + uint o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03)); + + const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10; + const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20; + + const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]); +#else + if (is_src0) { + data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx]; + } else { + data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx]; + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp new file mode 100644 index 00000000..ca1a3ac2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +#extension GL_EXT_control_flow_attributes : require + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + // fast path for when all four iterations are in-bounds + if (idx + (num_iter-1)*num_threads < p.ne) { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } else { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp new file mode 100644 index 00000000..70a30148 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp @@ -0,0 +1,105 @@ +#version 450 + +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + uint ne; + uint batches; + uint channels; + uint dst_w; + uint dst_h; + uint src_w; + uint src_h; + uint knl_w; + uint knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; +layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE conv_2d_dw_whcn(uint idx) { + uint i0 = idx / p.dst_w; + uint dst_x = idx - i0 * p.dst_w; + uint i1 = i0 / p.dst_h; + uint dst_y = i0 - i1 * p.dst_h; + uint n = i1 / p.channels; + uint c = i1 - n * p.channels; + + uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w; + uint knl_i = c * p.knl_h * p.knl_w; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]); + sum = fma(v, k, sum); + } + } + return sum; +} + +FLOAT_TYPE conv_2d_dw_cwhn(uint idx) { + uint i0 = idx / p.channels; + uint c = idx - i0 * p.channels; + uint i1 = i0 / p.dst_w; + uint dst_x = i0 - i1 * p.dst_w; + uint n = i1 / p.dst_h; + uint dst_y = i1 - n * p.dst_h; + + uint src_i = n * p.channels * p.src_h * p.src_w; + uint src_row = p.src_w * p.channels; + uint knl_row = p.knl_w * p.channels; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]); + sum = fma(v, k, sum); + } + } + return sum; +} + +void main() { + uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + FLOAT_TYPE result = +#ifdef WHCN + conv_2d_dw_whcn(idx); +#else + conv_2d_dw_cwhn(idx); +#endif + dst_data[idx] = D_TYPE(result); +} + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp new file mode 100644 index 00000000..0367e80b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -0,0 +1,349 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#ifdef COOPMAT2 +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#ifdef USE_COLLECTIVES +# extension GL_KHR_shader_subgroup_shuffle : enable +#endif + +#include "types.glsl" + +// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j +layout(binding = 0) readonly buffer A { + A_TYPE knl_data[]; +}; // src0 - kernel: [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d + +layout(binding = 1) readonly buffer B { + B_TYPE src_data[]; +}; // src1 - input: [W, H, Cin, N] -- channel_first format + +layout(binding = 2) writeonly buffer D { + D_TYPE dst_data[]; +}; // dst - result: [OW, OH, Cout, N] + +layout(push_constant) uniform parameter { + // I/O channels, batch size + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + // Tensor spatial sizes: kernel, input, output + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + // Parameters: stride, padding, dilation - 0=y, 1=x + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + // Strides in elements + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // fastdiv helper values + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; +#ifdef TRANSPOSE + uint32_t s0mp; uint32_t s0L; + uint32_t s1mp; uint32_t s1L; +#endif +} + +p; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +// Blocktile sizes +layout(constant_id = 1) const uint BS_K = 128; +layout(constant_id = 2) const uint BS_CRS = 16; +layout(constant_id = 3) const uint BS_NPQ = 128; +// Thread-tile sizes +layout(constant_id = 4) const uint TS_K = 8; +layout(constant_id = 5) const uint use_collectives = 1; +layout(constant_id = 6) const uint SHMEM_PAD = 4; + +uint32_t tid = gl_LocalInvocationID.x; +const uint32_t WG_SIZE = gl_WorkGroupSize.x; + +uint splitWork(uint work_size, uint block_size) { + return (block_size + work_size - 1) / block_size; +} + +uint32_t K = p.Cout; +uint32_t CRS = p.Cin * p.KH * p.KW; +uint32_t NPQ = p.N * p.OH * p.OW; + +uint32_t n_elems_out = K * NPQ; + +// Number of blocktiles per input +uint32_t NB_CRS = splitWork(CRS, BS_CRS); + +#ifdef COOPMAT2 +#define SHMEM_TYPE float16_t +#else +#define SHMEM_TYPE float +#endif + +const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; +const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; + +const uint32_t Ash_numel = BS_K * BS_CRS; +const uint32_t Bsh_numel = BS_CRS * BS_NPQ; + +const uint32_t Ash_len = BS_K * Ash_stride; +const uint32_t Bsh_len = BS_CRS * Bsh_stride; + +shared SHMEM_TYPE Ash[Ash_len]; // K x CRS +shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ + +// Threadtile sizes +const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; + +// Number of threadtiles per blocktile +const uint32_t NT_K = BS_K / TS_K; +const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; + +/* +Compute +KxCRS @ CRSxNPQ = K x NPQ +K=Cout +C=Cin +R,S=KH,KW +P,Q=OH,OW +*/ + +uint32_t B_idx_K = gl_WorkGroupID.x; +uint32_t B_idx_NPQ = gl_WorkGroupID.y; + +uint32_t T_y = tid / NT_NPQ; +uint32_t T_x = tid % NT_NPQ; + +uint32_t Ar = tid / BS_CRS; +uint32_t Ac = tid % BS_CRS; +const uint32_t ArpWg = WG_SIZE / BS_CRS; + +uint32_t Br = tid / BS_NPQ; +uint32_t Bc = tid % BS_NPQ; +const uint32_t BrpWg = WG_SIZE / BS_NPQ; + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +#ifdef COOPMAT2 +#define ACC_TYPE float16_t + +ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem) +{ + uint32_t K_idx = B_idx_K * BS_K + r; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (K_idx < K && NPQ_idx < NPQ) { + dst_data[dst_idx] = D_TYPE(elem); + } + return elem; +} +#endif + +void main() { +#ifdef COOPMAT2 + coopmat matC; + matC = coopmat(0.0); +#else + float regC[TS_K][TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = 0.0; + } + } +#endif + /* Advance block in CRS dim */ + for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { + uint32_t CRS_idx_a; + uint32_t Cin_idx_a; + uint32_t KH_idx_a; + uint32_t KW_idx_a; + +#ifdef USE_COLLECTIVES + uint32_t cached_CRS_idx; + uint32_t cached_Cin_idx; + uint32_t cached_KH_idx; + uint32_t cached_KW_idx; + if (use_collectives == 1) { + cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID; + cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH); + cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW; + + CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac); + Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac); + KH_idx_a = subgroupShuffle(cached_KH_idx, Ac); + KW_idx_a = subgroupShuffle(cached_KW_idx, Ac); + } else { + CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; + KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_a = CRS_remainder - KH_idx_a * p.KW; + } +#else + CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH); + CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; + KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_a = CRS_remainder - KH_idx_a * p.KW; +#endif + + /* Load kernel to A_block: (BS_K x BS_CRS)*/ + for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { + uint32_t B_ly = r_offset + Ar; + uint32_t B_lx = Ac; + uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ +#ifdef TRANSPOSE + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); +#else + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); +#endif + float val = knl_data[knl_idx]; + if (K_idx >= K || CRS_idx_a >= CRS) { + val = 0.0; + } + Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); + } + /* Load input to B_block: (BS_CRS x BS_NPQ) */ + UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { + uint32_t B_ly = r_offset + Br; /* Row index of B block */ + uint32_t B_lx = Bc; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW; + + uint32_t CRS_idx_b; + uint32_t Cin_idx_b; + uint32_t KH_idx_b; + uint32_t KW_idx_b; +#ifdef USE_COLLECTIVES + if (use_collectives == 1) { + CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br); + Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br); + KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br); + KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); + } else { + CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; + KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_b = CRS_remainder - KH_idx_b * p.KW; + } +#else + CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; + KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_b = CRS_remainder - KH_idx_b * p.KW; +#endif + +#ifdef TRANSPOSE + uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1; + uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0; + uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L); + uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L); +#else + uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1; + uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0; +#endif + uint32_t src_idx = + min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + float val = src_data[src_idx]; + if (CRS_idx_b >= CRS || NPQ_idx >= NPQ + || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case) +#ifdef TRANSPOSE + || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0) +#endif + ) { + val = 0.0; + } + Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); + } + barrier(); +#ifdef COOPMAT2 + coopmat matA; + coopmat matB; + + coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + matC = coopMatMulAdd(matA, matB, matC); +#else + if (T_y * TS_K < K) { + UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { + float regA[TS_K]; + float regB[TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; + } + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; + } + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + } + } + } + } +#endif + barrier(); + } + /* Save C* */ +#ifdef COOPMAT2 + coopMatPerElementNV(matC, matC, perElemOpStore); +#else + if (T_y * TS_K < K) { + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (K_idx < K && NPQ_idx < NPQ) { + dst_data[dst_idx] = regC[T_ly][T_lx]; + } + } + } + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp new file mode 100644 index 00000000..5217e18b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp @@ -0,0 +1,98 @@ +#version 450 + +#include "types.glsl" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin] +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin] +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout] + +layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +} p; + + +uint32_t Cout_idx = gl_WorkGroupID.x; +const uint32_t bs = gl_WorkGroupSize.x; +uint32_t tid = gl_LocalInvocationID.x; +// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K. +uint32_t tmp_len = bs*p.s0+p.K; +shared D_TYPE tmp[4096]; + +uint splitWork(uint workSize){ + return (bs + workSize -1) / bs; +} + +void main(){ + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx < tmp_len){ + tmp[idx] = 0.0; + } + } + + uint32_t L_blocks = splitWork(p.L); + for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){ + if(L_block_id > 0){ + barrier(); + // Shift values in tmp to the current processing window + for(int i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx >= bs*p.s0 && idx < tmp_len){ + tmp[idx-bs*p.s0] = tmp[idx]; + tmp[idx] = 0.0; + }else if(idx >= p.K && idx < bs*p.s0){ + tmp[idx] = 0.0; + } + } + } + barrier(); + + // Save contributions of the block to tmp + uint32_t L_idx = L_block_id*bs + tid; + for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){ + D_TYPE dp = 0.0; + for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){ + A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02]; + if(L_idx < p.L){ + B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11]; + dp = fma(elemKrn, elemInp, dp); + } + } + tmp[tid*p.s0 + K_idx] += dp; + barrier(); + } + + // Save the computed values except the last block that can have different size + uint32_t KLb_idx = L_block_id*bs*p.s0; + if(L_block_id < L_blocks-1){ + for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){ + uint32_t sh_idx = p.s0*tid+s0_idx; + uint32_t KL_idx = KLb_idx+sh_idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx]; + } + } + } + } + + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx]; + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp new file mode 100644 index 00000000..9f8bfd3c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); +#else + data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp new file mode 100644 index 00000000..06df5095 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" +#include "dequant_funcs.glsl" + +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) +// 16 invocations needed for init_iq_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = get_doffset() + dst_idx(idx); + uint src_idx = src0_idx_quant(idx, QUANT_K); + + const uint a_offset = 0; + const uint ib = src_idx; + const vec2 dm = get_dm(ib, a_offset); + + [[unroll]] for (int j = 0; j < QUANT_K; j += 4) { + vec4 v = dequantize4(ib, j / QUANT_R, a_offset); + v = v * dm.x + vec4(dm.y); + +#if QUANT_R == 2 + data_d[dst_idx + j/2 + 0] = v[0]; + data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1]; + data_d[dst_idx + j/2 + 1] = v[2]; + data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3]; +#else + data_d[dst_idx + j + 0] = v[0]; + data_d[dst_idx + j + 1] = v[1]; + data_d[dst_idx + j + 2] = v[2]; + data_d[dst_idx + j + 3] = v[3]; +#endif + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp new file mode 100644 index 00000000..b8c40eec --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -0,0 +1,296 @@ +#version 450 + +#include "rte.glsl" +#include "types.glsl" + +#if defined(SET_ROWS) && QUANT_K == 1 +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; +const uint BLOCK_SIZE = 512; +#else +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; +const uint BLOCK_SIZE = 32; +#endif + +layout (binding = 0) readonly buffer S {float data_s[];}; + +#if defined(SET_ROWS) +#include "generic_binary_head.glsl" +layout (binding = 1) readonly buffer C {B_TYPE data_i[];}; +layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; + +#if B_SIZE == 64 +#define DATA_I_SWIZZLE .x +#else +#define DATA_I_SWIZZLE +#endif + +#else +#include "generic_unary_head.glsl" +layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; +#endif + +#if defined(DATA_A_Q4_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id; + + const uint xi0 = min(15, int(x0 + 8.5)); + const uint xi1 = min(15, int(x1 + 8.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q4_1) +void quantize(uint dst_idx, uint src_idx) +{ + float vmin = 1.0/0.0; + float vmax = -vmin; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) { + const float v = data_s[src_idx + j]; + + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(vmin); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - vmin)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id; + + const uint xi0 = min(15, int(x0 + 0.5)); + const uint xi1 = min(15, int(x1 + 0.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q5_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id; + + const uint xi0 = min(31, int(x0 + 16.5)); + const uint xi1 = min(31, int(x1 + 16.5)); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2); + } + data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF); + data_q[dst_idx].qh[1] = uint16_t(qh >> 16); +} +#endif + +#if defined(DATA_A_Q5_1) +void quantize(uint dst_idx, uint src_idx) +{ + float min = data_s[src_idx + 0]; + float max = min; + + [[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) { + const float v = data_s[src_idx + j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = (d != 0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(min); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - min)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id; + + const uint xi0 = uint(x0 + 0.5); + const uint xi1 = uint(x1 + 0.5); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2); + } + data_q[dst_idx].qh = qh; +} +#endif + +#if defined(DATA_A_Q8_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; // absolute max + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) { + const float v = data_s[src_idx + j]; + amax = max(amax, abs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) { + const float x0 = data_s[src_idx + j]*id; + + data_q[dst_idx].qs[j] = int8_t(round(x0)); + } +} +#endif + +#if defined(DATA_A_IQ4_NL) +uint best_index(float x) { + if (x <= kvalues_iq4nl[0]) return 0; + if (x >= kvalues_iq4nl[15]) return 15; + int ml = 0, mu = 15; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav; + } + return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu; +} + +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + float sumqx = 0, sumq2 = 0; + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id; + const uint xi0 = best_index(x0); + const uint xi1 = best_index(x1); + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j]; + const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d); + +} +#endif + +#if defined(DATA_A_F32) || defined(DATA_A_F16) +void quantize(uint dst_idx, uint src_idx) +{ + data_q[dst_idx] = A_TYPE(data_s[src_idx]); +} +#endif + +#if defined(DATA_A_BF16) +void quantize(uint dst_idx, uint src_idx) +{ + data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx])); +} +#endif + +#if defined(SET_ROWS) + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + uint i12 = fastmod(i03, p.ne12); + uint i11 = fastmod(i02, p.ne11); + uint i10 = i01; + + uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE; + + uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset(); + uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset(); + + quantize(dst_idx, src0_idx); +} + +#else + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = dst_idx_quant(idx, QUANT_K); + uint src_idx = get_aoffset() + src0_idx(idx); + + quantize(dst_idx, src_idx); +} + +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp new file mode 100644 index 00000000..db6865db --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp new file mode 100644 index 00000000..e75df667 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp @@ -0,0 +1,31 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" +#include "generic_head.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +const uint CHUNK_SIZE = 512; + +void main() { + const uint base = gl_WorkGroupID.x * CHUNK_SIZE; + const uint col = gl_LocalInvocationID.x; + + uint count = 0; + [[unroll]] + for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) { + const uint idx = base + i + col; + if (idx >= p.KX) { + break; + } + count += uint(data_a[idx] == data_b[idx]); + } + + atomicAdd(data_d[0], D_TYPE(count)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp new file mode 100644 index 00000000..765afffa --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x * 16; + + if (i >= p.nel) { + return; + } + + [[unroll]] for (uint l = 0; l < 16; l++) { + data_b[i + l] = D_TYPE(data_a[i + l]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl new file mode 100644 index 00000000..0d98f5a9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -0,0 +1,616 @@ +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#endif + +#include "types.glsl" + +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + +#if defined(DATA_A_F32) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_F16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_BF16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); +} +#endif + +#if defined(DATA_A_Q4_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2(vui & 0xF, vui >> 4) - 8.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); +} +#endif + +#if defined(DATA_A_Q4_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(vui & 0xF, vui >> 4); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); +} +#endif + +#if defined(DATA_A_Q5_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0]; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f); +} +#endif + +#if defined(DATA_A_Q5_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = data_a[a_offset + ib].qh; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = data_a_packed16[a_offset + ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y); +} +#endif + +#if defined(DATA_A_Q8_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy; + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#if defined(DATA_A_IQ1_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1; + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + // Signed bitfield extract. + const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + +#if defined(DATA_A_IQ1_M) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. + const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + +#if defined(DATA_A_IQ2_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid[iqs % 4] * (sign0 ? -1.0 : 1.0), + grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint qs = data_a[a_offset + ib].qs[iqs / 4]; + const uint qh = data_a[a_offset + ib].qh[iqs / 32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[iqs / 64]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4)); + return db * vec2( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[ib32 / 2]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4)); + return db * vec4( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0), + int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0), + int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ4_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint iq = 16 * ib32 + (iqs % 16); + + const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; + const uint qshift = (iqs & 16) >> 2; + u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float dl = float(int(sl | (sh << 4)) - 32); + return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint iq = 16 * ib32 + (iqs % 16); + + const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; + const uint qshift = (iqs & 16) >> 2; + u8vec4 qs = u8vec4( + data_a[a_offset + ib].qs[iq + 0], + data_a[a_offset + ib].qs[iq + 1], + data_a[a_offset + ib].qs[iq + 2], + data_a[a_offset + ib].qs[iq + 3] + ); + qs = (qs >> qshift) & uint8_t(0xF); + + const float dl = float(int(sl | (sh << 4)) - 32); + return dl * vec4( + kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], + kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); +} +#endif + +#if defined(DATA_A_IQ4_NL) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]); +} +#endif + +#if defined(DATA_A_MXFP4) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + vec2 v0 = dequantize(ib, iqs, a_offset); + vec2 v1 = dequantize(ib, iqs + 1, a_offset); + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(0, 0); +} +#endif + +#if defined(DATA_A_IQ1_M) +vec2 get_dm(uint ib, uint a_offset) { + const uint16_t[4] scales = data_a[a_offset + ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + return vec2(d, 0); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), 0); +} +#endif + +#if defined(DATA_A_MXFP4) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); +} +#endif + +#if defined(DATA_A_Q2_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]); + const uint scales = data_a[a_offset + ib].scales[scalesi]; + const vec2 d = vec2(data_a[a_offset + ib].d); + + return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q3_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); + const float dl = float(data_a[a_offset + ib].d) * float(us - 32); + + return vec2(dl * float(int8_t((data_a[a_offset + ib].qs[qsi ] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi ] & m) != 0) ? 0 : 4)), + dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q4_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[a_offset + ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF), m), + fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q5_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[a_offset + ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), + fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q6_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[a_offset + ib].d) * float(data_a[a_offset + ib].scales[is]); + + return vec2(dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32), + dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl new file mode 100644 index 00000000..6a5bb457 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -0,0 +1,720 @@ + +#include "types.glsl" + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { + block_q4_0_packed16 block; +}; + +float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]); + qs >>= shift; + qs &= 0x0F0F; + qs = unpack8(qs)[idx & 1]; + float16_t ret = (float16_t(qs) - float16_t(8)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { + block_q4_1 block; +}; + +float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(qs) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { + block_q5_0 block; +}; + +float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0]; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { + block_q5_1 block; +}; + +float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = bl.block.qh; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = float16_t(qs | qh) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { + block_q8_0_packed16 block; +}; + +float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + // Load 16b and select the byte for this element + int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1]; + float16_t ret = float16_t(qs) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { + block_q2_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 { + block_q2_K_packed16 block; +}; + +float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); + const f16vec2 d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint scalesi = (idx & 0xF0) >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qs = (qs >> qsshift) & 0x0303; + qs = unpack8(qs)[idx & 1]; + + const uint scales = bl.block.scales[scalesi]; + float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { + block_q3_K block; +}; + +float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + const uint n = iqs / 128; // 0,1 + const uint qsi = n * 32 + (iqs % 32); // 0..63 + const uint hmi = (iqs % 32); // 0..31 + const uint j = (iqs % 128) / 8; // 0..15 + const uint is = iqs / 16; // 0..15 + const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + uint32_t scaleidx0 = (is < 8) ? is : (is-8); + uint32_t scaleidx0shift = (is < 8) ? 0 : 4; + uint32_t scaleidx1 = is + 8 - (is/4)*4; + uint32_t scaleidx1shift = (is/4)*2; + + const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); + + const float16_t dl = bl.block.d * float16_t(us - 32); + + float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4)); + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { + block_q4_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 { + block_q4_K_packed16 block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { + block_q4_K_packed128 block; +}; + +#if defined(IS_MUL_MM2) + +// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales +// into shared memory and then process the whole tile using those scales. +// There is a fetch function that loads into private variables and then a store +// function that stores into shared memory. +// Q4_K and Q5_K have the same encoding of scales, so everything is shared except +// the part that fetches from the structure (which has a different block layout). +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +const uint shAscales_stride = (BM + 2); +// 1 scale per 32 elements -> 8 scales per block, per row +shared vec2 shAscales[8 * shAscales_stride]; +uvec4 row_v; +#endif + +#if defined(DATA_A_Q4_K) +layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];}; + +void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q4_k_packed128[block_index].q4k[0]; + } +} +#endif +#if defined(DATA_A_Q5_K) +layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];}; + +void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q5_k_packed128[block_index].q5k[0]; + } +} +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +void store_scalesQ4_K(uint tid) +{ + barrier(); + + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) { + uint is = idx + is_start; + uvec4 v = row_v; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); + shAscales[is * shAscales_stride + tid_row] = vec2(d,m); + } + + barrier(); +} +#endif + +#endif + +float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q4k[0]; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); +#endif + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; + + float ret = d * float(qs) - m; + + return float16_t(ret); +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { + block_q5_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 { + block_q5_K_packed16 block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 { + block_q5_K_packed128 block; +}; + +float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q5k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); +#endif + + uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); + qh = ((qh >> is) & 0x101) << 4; + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4)) & 0x0F0F; + qs = unpack8(qs | qh)[idx & 1]; + + float ret = d * float(qs) - m; + + return float16_t(ret); +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { + block_q6_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 { + block_q6_K_packed16 block; +}; + +float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x40) >> 6; // 0,1 + const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 + const uint is = (idx & 0xF0) >> 4; // 0..15 + + const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); + + uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]); + ql = (ql >> (b * 4)) & 0x0F0F; + + uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qh = ((qh >> qhshift) & 0x0303) << 4; + + int q = unpack8(ql | qh)[idx & 1]; + + float16_t ret = dscale * float16_t(q - 32); + + return ret; +} + +#if defined(DATA_A_IQ1_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S { + block_iq1_s block; +}; + +float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; + const uint ib8 = (idx & 0xF8) >> 3; + + const uint qh = bl.block.qh[ib32]; + const uint qs = bl.block.qs[ib8]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]; + + float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta)); + return ret; +} +#endif + +#if defined(DATA_A_IQ1_M) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M { + block_iq1_m block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 { + block_iq1_m_packed64 block; +}; + +float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl); + const uint idx = coordInBlock[1]; + + uvec2 scales = unpack32(bl64.block.scales); + const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16))); + + const uint ib8 = (idx & 0xF8) >> 3; + const uint ib16 = (idx & 0xF0) >> 4; + const int i8 = int(idx % 8); + const uint sc = bl.block.scales[ib8 / 8]; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | ((qh & 7) << 8)]; + + float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta)); + return ret; +} +#endif + +#if defined(DATA_A_IQ2_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS { + block_iq2_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 { + block_iq2_xxs_packed16 block; +}; + +float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0x18) >> 3; // 0..3 + const uint iqs = 8 * ib32 + ib8; + + const uint qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + + uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); + + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ2_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS { + block_iq2_xs block; +}; + +float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint is = (idx & 0xE0) >> 5; // 0..8 + const uint sshift = (idx & 0x10) >> 2; // 0,4 + const uint iqs = (idx & 0xF8) >> 3; // 0..63 + + const uint16_t qs = bl.block.qs[iqs]; + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); + + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ2_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S { + block_iq2_s block; +}; + +float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0xF8) >> 3; // 0..31 + const uint qhshift = 2 * (ib8 % 4); + + const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib32]; + const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6); + + const float d = float(bl.block.d); + const float db = d * 0.25 * (0.5 + scale); + const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign)); + uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 v = db * vec2(sign01) * vec2(unpack8(g2)); + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS { + block_iq3_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 { + block_iq3_xxs_packed16 block; +}; + +float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); + uint idx = coordInBlock[1]; + + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint signs = pack32(u16vec2( + bl16.block.qs[is/2+0], + bl16.block.qs[is/2+1] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6); + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); + const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ3_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S { + block_iq3_s block; +}; + +float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint iqh = (idx & 0xE0) >> 5; + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint qh = bl.block.qh[iqh]; + const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6)); + const uint scale = bl.block.scales[iqs / 16]; + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ4_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS { + block_iq4_xs block; +}; + +float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + + const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 16) >> 2; + const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF; + + float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]); + return ret; +} +#endif + +#if defined(DATA_A_IQ4_NL) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { + block_iq4_nl block; +}; + +float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; + return ret; +} +#endif + +#if defined(DATA_A_MXFP4) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 { + block_mxfp4 block; +}; + +float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float d = e8m0_to_fp32(bl.block.e); + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(kvalues_mxfp4[qs] * d); + return ret; +} +#endif + +#if defined(DATA_A_Q4_0) +#define dequantFuncA dequantFuncQ4_0 +#elif defined(DATA_A_Q4_1) +#define dequantFuncA dequantFuncQ4_1 +#elif defined(DATA_A_Q5_0) +#define dequantFuncA dequantFuncQ5_0 +#elif defined(DATA_A_Q5_1) +#define dequantFuncA dequantFuncQ5_1 +#elif defined(DATA_A_Q8_0) +#define dequantFuncA dequantFuncQ8_0 +#elif defined(DATA_A_Q2_K) +#define dequantFuncA dequantFuncQ2_K +#elif defined(DATA_A_Q3_K) +#define dequantFuncA dequantFuncQ3_K +#elif defined(DATA_A_Q4_K) +#define dequantFuncA dequantFuncQ4_K +#define fetch_scales fetch_scalesQ4_K +#define store_scales store_scalesQ4_K +#elif defined(DATA_A_Q5_K) +#define dequantFuncA dequantFuncQ5_K +#define fetch_scales fetch_scalesQ5_K +#define store_scales store_scalesQ4_K +#elif defined(DATA_A_Q6_K) +#define dequantFuncA dequantFuncQ6_K +#elif defined(DATA_A_IQ1_S) +#define dequantFuncA dequantFuncIQ1_S +#elif defined(DATA_A_IQ1_M) +#define dequantFuncA dequantFuncIQ1_M +#elif defined(DATA_A_IQ2_XXS) +#define dequantFuncA dequantFuncIQ2_XXS +#elif defined(DATA_A_IQ2_XS) +#define dequantFuncA dequantFuncIQ2_XS +#elif defined(DATA_A_IQ2_S) +#define dequantFuncA dequantFuncIQ2_S +#elif defined(DATA_A_IQ3_XXS) +#define dequantFuncA dequantFuncIQ3_XXS +#elif defined(DATA_A_IQ3_S) +#define dequantFuncA dequantFuncIQ3_S +#elif defined(DATA_A_IQ4_XS) +#define dequantFuncA dequantFuncIQ4_XS +#elif defined(DATA_A_IQ4_NL) +#define dequantFuncA dequantFuncIQ4_NL +#elif defined(DATA_A_MXFP4) +#define dequantFuncA dequantFuncMXFP4 +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl new file mode 100644 index 00000000..addceafa --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl @@ -0,0 +1,13 @@ +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint M; + uint K; + uint stride_a; + uint stride_b; + uint nel; +} p; + +#include "types.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp new file mode 100644 index 00000000..637c95fa --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq1_m data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint ib64 = ib32 / 2; + const uint b_idx = 256 * ib + 32 * ib32; + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = data_a[ib].scales[ib64]; + [[unroll]] for (int l = 0; l < 4; ++l) { + const uint ib16 = 2 * ib32 + l / 2; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1)); + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp new file mode 100644 index 00000000..d1cbc5e9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq1_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + uint qh = data_a[ib].qh[ib32]; + const float d = float(data_a[ib].d); + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint hi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp new file mode 100644 index 00000000..78490162 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + uint qh = data_a[ib].qh[ib32]; + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l]; + qs |= (qh << (8 - 2 * l)) & 0x300; + const uvec2 grid = iq2s_grid[qs]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp new file mode 100644 index 00000000..9b8ce0a7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint16_t qs = data_a[ib].qs[4 * ib32 + l]; + const uint sign7 = qs >> 9; + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uvec2 grid = iq2xs_grid[qs & 511]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp new file mode 100644 index 00000000..aacf07d0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[8*is + 4], + data_a[ib].qs[8*is + 5], + data_a[ib].qs[8*is + 6], + data_a[ib].qs[8*is + 7] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uint qs = data_a[ib].qs[8 * is + l]; + const uvec2 grid = iq2xxs_grid[qs]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp new file mode 100644 index 00000000..f2c20b1d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -0,0 +1,40 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale nibble. + // Each block contains 4 scale bytes (8 scales) for 256 output values. + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf)); + + // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes. + uint qh = data_a[ib].qh[is]; + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = 8 * is + l; + const uint qs = data_a[ib].qs[iqs]; + const uint gidx = qs | ((qh << (8 - l)) & 256); + const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1)); + const u8vec4 grid = unpack8(iq3s_grid[gidx]); + data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp new file mode 100644 index 00000000..671c1f4a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // 8 threads handle 1 superblock + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + const uint s_idx = QUANT_K / 4 + 4 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[s_idx + 0], + data_a[ib].qs[s_idx + 1], + data_a[ib].qs[s_idx + 2], + data_a[ib].qs[s_idx + 3] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.5; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + // Restore parity bit. + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint qs0 = data_a[ib].qs[8 * is + 2 * l]; + const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1]; + const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]); + const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp new file mode 100644 index 00000000..8f7833ea --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = float(data_a[ib].d); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp new file mode 100644 index 00000000..a3136997 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (1 scale and 32 quantized values) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + + const float d = float(data_a[ib].d); + // Scales are 6 bits + const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF) + | (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4); + const float dl = d * (int(scale) - 32); + + const uint b_idx = 256 * ib + 32 * ib32; + const uint q_idx = 16 * ib32; + [[unroll]] for (uint l = 0; l < 16; ++l) { + data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp new file mode 100644 index 00000000..ffba5a77 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = e8m0_to_fp32(data_a[ib].e); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp new file mode 100644 index 00000000..58dc2e5d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; + if (i >= p.nel / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; + + const uint y_idx = i * QUANT_K + 128 * ip + il; + + const uint ql_idx = 32 * ip + il; + const uint8_t qs = data_a[i].qs[32 * ip + il]; + + FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); + FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); + data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); + data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); + data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); + data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp new file mode 100644 index 00000000..0c90be8b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = uint(gl_WorkGroupID.x * 256 + wgy); + if (i >= p.nel / QUANT_K) { + return; + } + + const uint r = gl_LocalInvocationID.x / 4; + const uint tid = r / 2; + const uint is0 = r % 2; + const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4); + const uint n = tid / 4; + const uint j = tid - 4*n; + + const uint8_t m = uint8_t(1 << (4*n + j)); + const uint is = 8*n + 2*j + is0; + const uint shift = 2*j; + + const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) : + (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4)); + const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); + const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); + + const uint y_idx = i * QUANT_K + 128 * n + 32 * j; + const uint qs_idx = 32*n; + + for (uint l = l0; l < l0 + 4; ++l) { + data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4))); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp new file mode 100644 index 00000000..b92b2921 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp @@ -0,0 +1,30 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = float(data_a[ib].d); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp new file mode 100644 index 00000000..6b63cbe5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp new file mode 100644 index 00000000..8b7be557 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -0,0 +1,68 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.nel / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 8; + const uint ir = tid % 8; + const uint is = 2 * il; + const uint n = 4; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + n * ir; + const uint qs_idx = 32*il + n * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + [[unroll]] for (uint l = 0; l < n; ++l) { + data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1); + data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp new file mode 100644 index 00000000..f1b0bac8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp new file mode 100644 index 00000000..c495b31f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint qh = data_a[ib].qh; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp new file mode 100644 index 00000000..6bc04670 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -0,0 +1,70 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.nel / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 16; + const uint ir = tid % 16; + const uint is = 2 * il; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; + const uint qs_idx = 32*il + 2 * ir; + const uint qh_idx = 2 * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + const uint8_t hm1 = uint8_t(1 << (2 * il )); + const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); + data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); + data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp new file mode 100644 index 00000000..c8d6fcb4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -0,0 +1,33 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; + if (i >= p.nel / QUANT_K) { + return; + } + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; + + const uint y_idx = i * QUANT_K + 128 * ip + il; + + const uint ql_idx = 64 * ip + il; + const uint8_t qh = data_a[i].qh[32 * ip + il]; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); + + data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); + data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); + data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); + data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp new file mode 100644 index 00000000..10844ddf --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp @@ -0,0 +1,31 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q8_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 16*il; + + const float d = float(data_a[ib].d); + + const uint q_idx = 16*il; + + [[unroll]] for (uint l = 0; l < 16; l += 2) { + data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]); + data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp new file mode 100644 index 00000000..9cef8a8e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -0,0 +1,34 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint ncols; + uint rows_per_channel; + uint n_past; +} p; + +#include "types.glsl" + +layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint col = gl_GlobalInvocationID.y; + const uint row = gl_GlobalInvocationID.x; + + if (col >= p.ncols) { + return; + } + + const uint i = row*p.ncols + col; + if (col > p.n_past + row % p.rows_per_channel) { + data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000)); + } else { + data_d[i] = D_TYPE(data_a[i]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp new file mode 100644 index 00000000..572472f8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp new file mode 100644 index 00000000..b69d4ddb --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "rte.glsl" +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(exp(float(data_a[i]))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp new file mode 100644 index 00000000..62acbf10 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -0,0 +1,383 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.glsl" +#include "flash_attn_base.glsl" + +const uint32_t HSK_per_thread = HSK / D_split; +const uint32_t HSV_per_thread = HSV / D_split; + +const uint32_t cols_per_iter = WorkGroupSize / D_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * HSV + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +shared FLOAT_TYPE tmpsh[WorkGroupSize]; +shared vec4 tmpshv4[WorkGroupSize]; + +shared float masksh[Bc][Br]; +shared vec4 Qf[Br][HSK / 4]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t r = (idx + tid) / (HSK / 4); + if (r < Br && d < HSK / 4 && + i * Br + r < N) { + Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + } + } + barrier(); + + vec4 Of[Br][HSV_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = vec4(0.0); + } + } + + float Lf[Br], Mf[Br]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + float slope[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = 1.0; + } + + // ALiBi + if (p.max_bias > 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + float Sf[Br][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = 0.0; + } + } + + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + } + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + // Compute sum across the D_split + [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); + } + } + } + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + } + } + } + + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + } else { + masksh[c][r] = float(0); + } + } + } + barrier(); + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float mvf = masksh[c * cols_per_iter + col_tid][r]; + + Sf[r][c] += slope[r]*mvf; + } + } + barrier(); + } + + float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + rowmaxf[r] = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + } + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Pf[r][c] = exp(Sf[r][c] - Mf[r]); + } + eMf[r] = exp(Moldf[r] - Mf[r]); + + // Compute sum across row of P + rowsumf[r] = 0.0; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + rowsumf[r] += Pf[r][c]; + } + + Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = eMf[r] * Of[r][d]; + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] += Pf[r][c] * Vf; + } + } + } + + barrier(); + } + + // reduce across threads + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float rowmaxf, eMf; + + tmpsh[tid] = Mf[r]; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + } + barrier(); + } + rowmaxf = tmpsh[d_tid]; + barrier(); + + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf = exp(Moldf - Mf[r]); + + Lf[r] = eMf*Lf[r]; + + tmpsh[tid] = Lf[r]; + + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + } + barrier(); + } + Lf[r] = tmpsh[d_tid]; + barrier(); + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + + Of[r][d] = eMf * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + Of[r][d] += tmpshv4[tid + s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[d_tid]; + barrier(); + } + } + + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > Mf[r]) { + ms = exp(Mf[r] - sink); + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] *= ms; + } + } else { + vs = exp(sink - Mf[r]); + } + + Lf[r] = Lf[r]*ms + vs; + } + } + + float Lfrcp[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] *= Lfrcp[r]; +#if defined(ACC_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX)); +#endif + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (i * Br + r < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl new file mode 100644 index 00000000..9b1f153b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -0,0 +1,202 @@ + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t HSK = 32; +layout (constant_id = 4) const uint32_t HSV = 32; +layout (constant_id = 5) const uint32_t Clamp = 0; +layout (constant_id = 6) const uint32_t D_split = 16; + +// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths +const uint32_t HSK_pad = (HSK + 15) & ~15; +const uint32_t HSV_pad = (HSV + 15) & ~15; + +const bool KV_bounds_check = Clamp != 0; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + uint32_t nem2; + uint32_t nem3; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask_n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +#define SINK_ENABLE_BIT (1<<24) +#define MASK_ENABLE_BIT (1<<16) +#define N_LOG2_MASK 0xFFFF + +layout (binding = 4) readonly buffer S {float data_s[];}; + +layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; +layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } else { + const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK; + + const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +// Load the sink value, indexed by Q's dimension 2. +ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + return ACC_TYPE(data_s[h]); +} + +uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, + iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + q_stride, k_stride, v_stride, m_stride; + +void init_indices() +{ + N = p.N; + KV = p.KV; + + i = gl_WorkGroupID.x; + split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + Tr = CEIL_DIV(N, Br); + + start_j = split_k_index * p.split_kv / Bc; + end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + iq2 = gl_WorkGroupID.y * p.gqa_ratio; + iq3 = gl_WorkGroupID.z; + + // broadcast factors + rk2 = p.neq2/p.nek2; + rk3 = p.neq3/p.nek3; + + rv2 = p.neq2/p.nev2; + rv3 = p.neq3/p.nev3; + + // k indices + ik3 = iq3 / rk3; + ik2 = iq2 / rk2; + + // v indices + iv3 = iq3 / rv3; + iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + k_stride = p.nb11; + v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp new file mode 100644 index 00000000..2066a05b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -0,0 +1,418 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable + +#include "types.glsl" +#include "flash_attn_base.glsl" + +const uint32_t HSK_per_thread = HSK / D_split; +const uint32_t HSV_per_thread = HSV / D_split; + +const uint32_t row_split = 4; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * HSV + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; + +const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 +shared f16vec4 Qf[Br * qstride]; + +// Avoid padding for hsk==256 to make it fit in 48KB shmem. +const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br; +shared ACC_TYPE sfsh[Bc * sfshstride]; + +const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4 +shared f16vec4 ksh[Bc * kshstride]; + +shared float slope[Br]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + const uint32_t tid = gl_LocalInvocationIndex; + + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + +#define tile_row(r) (row_tid * rows_per_thread + (r)) + + // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK). + if ((HSK % 16) != 0) { + [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { + if (i + tid < Br * qstride) { + Qf[i + tid] = f16vec4(0); + } + } + [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) { + if (i + tid < Bc * kshstride) { + ksh[i + tid] = f16vec4(0); + } + } + barrier(); + } + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t r = (idx + tid) / (HSK / 4); + if (r < Br && d < HSK / 4 && + i * Br + r < N) { + Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + } + } + barrier(); + + ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = ACC_TYPEV4(0.0); + } + } + + float Lf[rows_per_thread], Mf[rows_per_thread]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + // ALiBi + if (p.max_bias > 0.0f) { + if (tid < Br) { + uint r = tid; + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + barrier(); + } else { + if (tid < Br) { + uint r = tid; + slope[r] = 1.0; + } + barrier(); + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (c < Bc && d < HSK / 4) { + f16vec4 K_Tf = f16vec4(0); + if (!KV_bounds_check || j * Bc + c < KV) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + } + + ksh[c * kshstride + d] = K_Tf; + } + } + barrier(); + + // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br + // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 + // This is written transposed in order to allow for N being 8 if implementations need it + coopmat SfMat = coopmat(0); + coopmat KMat; + coopmat QMat; + + for (uint32_t d = 0; d < HSK_pad / 16; ++d) { + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + + SfMat = coopMatMulAdd(KMat, QMat, SfMat); + } + + uint coord = gl_SubgroupID * MatBc * sfshstride; + coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); + barrier(); + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / Br; + uint32_t r = (idx + tid) % Br; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + } + } + barrier(); + } + + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); + } + } + } + barrier(); + } + + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + } + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; + } + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + float Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf); + } + } + } + + barrier(); + } + + // reduce across threads + + float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE M = Mf[r]; + tmpsh[tid] = M; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + M = max(M, tmpsh[tid ^ s]); + barrier(); + tmpsh[tid] = M; + barrier(); + } + rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + eMf[r] = exp(Moldf[r] - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE L = Lf[r]; + tmpsh[tid] = L; + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + L += tmpsh[tid ^ s]; + barrier(); + tmpsh[tid] = L; + barrier(); + } + Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + + Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + Of[r][d] += tmpshv4[tid ^ s]; + barrier(); + tmpshv4[tid] = Of[r][d]; + barrier(); + } + Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > Mf[r]) { + ms = exp(Mf[r] - sink); + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] *= ACC_TYPE(ms); + } + } else { + vs = exp(sink - Mf[r]); + } + + Lf[r] = Lf[r]*ms + vs; + } + } + + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= ACC_TYPE(Lfrcp[r]); +#if defined(ACC_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); +#endif + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (i * Br + tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp new file mode 100644 index 00000000..910da1ab --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -0,0 +1,320 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable +#extension GL_EXT_null_initializer : enable + +#include "types.glsl" +#include "dequant_funcs_cm2.glsl" +#include "flash_attn_base.glsl" + +layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; +layout (binding = 1) readonly buffer K {uint8_t data_k[];}; +layout (binding = 2) readonly buffer V {uint8_t data_v[];}; +layout (binding = 3) readonly buffer M {uint8_t data_m[];}; + +ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return max(x, y); +} + +ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return x; +} + +// Replace matrix elements >= numRows or numCols with 'replace' +ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) { + if (row >= numRows || col >= numCols) { + return replace; + } + return elem; +} + +ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) +{ + return exp(elem); +} + +ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) +{ + return max(elem0, elem1); +} + +#if defined(BLOCK_SIZE) +#define DECODEFUNC , DEQUANTFUNC +#else +#define DECODEFUNC +#endif + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c < HSV) { + uint32_t offset = (iq2 + r) * HSV + c; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); + tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if defined(BLOCK_SIZE) + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); +#endif + + tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK); + tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK); + tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV); + + // hint to the compiler that strides are aligned for the aligned variant of the shader + if (Clamp != gl_CooperativeMatrixClampModeConstantNV) + { + q_stride &= ~7; +#if !defined(BLOCK_SIZE) + k_stride &= ~7; + v_stride &= ~7; +#endif + m_stride &= ~7; + } + tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); + tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); + tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); + + coopmat Q; + coopmat Qf16; + + uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); + + Qf16 = coopmat(Q); + Qf16 *= float16_t(p.scale); + + coopmat O = coopmat(0); + + coopmat L, M; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + L = coopmat(0); + M = coopmat(NEG_FLT_MAX_OVER_2); + + coopmat slopeMat = coopmat(1.0); + + // ALiBi + if (p.max_bias > 0.0f) { + coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); + } + + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + } + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + coopmat S = coopmat(0); + + coopmat K_T; + + uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); + S = coopMatMulAdd(Qf16, K_T, S); + + if (p.logit_softcap != 0.0f) { + [[unroll]] + for (int k = 0; k < S.length(); ++k) { + S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); + } + } + + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + if (nem1_bounds_check) { + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + + coopmat mv; + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slopeMat*coopmat(mv); + } else { + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); + // Don't clamp against nem1 when GQA is enabled + uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1; + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + + coopmat mv; + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slopeMat*coopmat(mv); + } + } + + // Clear padding elements to -inf, so they don't contribute to rowmax + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C); + } + + coopmat rowmax, P, rowsum, eM; + + coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); + + coopmat Mold = M; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + coopMatPerElementNV(M, rowmax, Max, Mold); + coopMatPerElementNV(P, S - M, Exp); + coopMatPerElementNV(eM, Mold - M, Exp); + + // Clear padding elements to 0, so they don't contribute to rowsum + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); + } + + coopmat P_A = coopmat(P); + + // compute rowsum by multiplying by matrix of all ones. + coopmat One = coopmat(1.0); + + rowsum = coopmat(0.0); + rowsum = coopMatMulAdd(P_A, One, rowsum); + + coopmat V; + uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC); + + L = eM*L + rowsum; + + // This is the "diagonal" matrix in the paper, but since we do componentwise + // multiply rather than matrix multiply it has the diagonal element smeared + // across the row + coopmat eMdiag; + + // resize eM by using smear/reduce + coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); + + // multiply with fp16 accumulation, then add to O. + coopmat PV = coopmat(0); + PV = coopMatMulAdd(P_A, V, PV); + + O = eMdiag * O + coopmat(PV); + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + coopmat O_D = coopmat(O); + + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + return; + } + + coopmat Ldiag; + + // resize L by using smear/reduce + coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); + + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + coopmat S; + coopMatPerElementNV(S, S, perElemOpGetSink, iq2); + + coopmat Mr; + + // resize M by using smear/reduce + coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce); + + // O, Ldiag, Mr all have the same type so all element locations match + [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) { + ACC_TYPE sink = S[i]; + + ACC_TYPE ms = ACC_TYPE(1.0f); + ACC_TYPE vs = ACC_TYPE(1.0f); + + if (sink > Mr[i]) { + ms = exp(Mr[i] - sink); + + O[i] *= ms; + } else { + vs = exp(sink - Mr[i]); + } + + Ldiag[i] = Ldiag[i]*ms + vs; + } + } + + [[unroll]] + for (int k = 0; k < Ldiag.length(); ++k) { + Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; + } + + O = Ldiag*O; + +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + + coopmat O_D = coopmat(O); + if (p.gqa_ratio > 1) { + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + } else { + tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV); + + // permute dimensions + tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); + + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp new file mode 100644 index 00000000..06e83822 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -0,0 +1,120 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) readonly buffer B {float data_s[];}; +layout (binding = 2) writeonly buffer D {float data_d[];}; + +layout (push_constant) uniform parameter { + uint D; + uint N; + uint ne3; + uint k_num; + uint sinks; +} p; + +shared float tmpsh[BLOCK_SIZE]; + +void main() { + // Each workgroup handles a row + const uint n = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + const uint iq3 = gl_WorkGroupID.z; + + uint D = p.D; + uint N = p.N; + uint k_num = p.k_num; + + uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; + uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; + uint lm_stride = N * 2; + + // Compute the max m value for the row + float m_max = -1.0/0.0; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float m = data_a[m_offset + (k + tid) * lm_stride]; + m_max = max(m_max, m); + } + + // reduce across the workgroup + tmpsh[tid] = m_max; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + m_max = max(m_max, tmpsh[tid + s]); + tmpsh[tid] = m_max; + } + barrier(); + } + m_max = tmpsh[0]; + + barrier(); + + // Compute L based on m_max + float L = 0; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float l = data_a[l_offset + (k + tid) * lm_stride]; + float m = data_a[m_offset + (k + tid) * lm_stride]; + L += exp(m - m_max) * l; + } + + // reduce across the workgroup + tmpsh[tid] = L; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + L += tmpsh[tid + s]; + tmpsh[tid] = L; + } + barrier(); + } + L = tmpsh[0]; + + float sink; + if (p.sinks != 0) { + sink = data_s[n]; + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > m_max) { + ms = exp(m_max - sink); + } else { + vs = exp(sink - m_max); + } + + L = L*ms + vs; + } + + L = 1.0 / L; + + // D dimension is split across workgroups in the y dimension + uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE; + // Scale and sum the O contributions based on m_max and store the result to memory + if (d < D) { + float O = 0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; + float m = data_a[m_offset + k * lm_stride]; + O += exp(m - m_max) * data_a[o_offset]; + } + if (p.sinks != 0) { + if (sink > m_max) { + float ms = 1.0f; + ms = exp(m_max - sink); + O *= ms; + } + } + O *= L; + + const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); + O = clamp(O, -FLT_MAX, FLT_MAX); + + data_d[iq3 * D * N + D * n + d] = O; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp new file mode 100644 index 00000000..e017b503 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp @@ -0,0 +1,13 @@ +#version 450 + +#include "glu_head.glsl" + +const float GELU_COEF_A = 0.044715f; +const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +float op(float a, float b) { + const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a); + return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp new file mode 100644 index 00000000..759a1848 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "glu_head.glsl" + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +const float p_erf = 0.3275911f; +const float a1_erf = 0.254829592f; +const float a2_erf = -0.284496736f; +const float a3_erf = 1.421413741f; +const float a4_erf = -1.453152027f; +const float a5_erf = 1.061405429f; + +const float SQRT_2_INV = 0.70710678118654752440084436210484f; + +float op(float a, float b) { + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + return 0.5f * a * (1.0f + erf_approx) * b; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp new file mode 100644 index 00000000..c4032ab2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp @@ -0,0 +1,11 @@ +#version 450 + +#include "glu_head.glsl" + +const float GELU_QUICK_COEF = -1.702f; + +float op(float a, float b) { + return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp new file mode 100644 index 00000000..a95c2525 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp @@ -0,0 +1,25 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float xi = float(data_a[i]); + const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); + data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp new file mode 100644 index 00000000..58375aba --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation + // ref: https://www.johndcook.com/blog/python_erf/ + const float p_erf = 0.3275911f; + const float a1_erf = 0.254829592f; + const float a2_erf = -0.284496736f; + const float a3_erf = 1.421413741f; + const float a4_erf = -1.453152027f; + const float a5_erf = 1.061405429f; + + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float a = float(data_a[i]); + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp new file mode 100644 index 00000000..bfdfe218 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_QUICK_COEF = -1.702f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x)))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl new file mode 100644 index 00000000..99595fc6 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -0,0 +1,51 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +#include "rte.glsl" +#include "utils.glsl" + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; + uint misalign_offsets; + float param1; float param2; int param3; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +// true if src0/src1 are the same shape and the indices can be reused without additional modulus +layout(constant_id = 0) const bool norepeat = false; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; } +uint get_doffset() { return p.misalign_offsets & 0xFF; } + + +void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) { + get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03); +} + +uint src0_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint src1_idx(uint i00, uint i01, uint i02, uint i03) { + if (norepeat) { + return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10; + } else { + return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10; + } +} + +uint dst_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl new file mode 100644 index 00000000..66e46ae6 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl @@ -0,0 +1,9 @@ +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + float param1; + float param2; +} p; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl new file mode 100644 index 00000000..8dc9d360 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl @@ -0,0 +1,76 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + float param1; float param2; + + uint ne0_012mp; uint ne0_012L; + uint ne0_01mp; uint ne0_01L; + uint ne0_0mp; uint ne0_0L; + uint ne1_012mp; uint ne1_012L; + uint ne1_01mp; uint ne1_01L; + uint ne1_0mp; uint ne1_0L; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +uint src0_idx(uint idx) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint dst_idx(uint idx) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; +} + +uint src0_idx_quant(uint idx, uint qk) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00; +} + +uint dst_idx_quant(uint idx, uint qk) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp new file mode 100644 index 00000000..76d83041 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = gl_GlobalInvocationID.x; + + if (i00 >= p.ne00) { + return; + } + + uint gid_z = gl_GlobalInvocationID.z; + while (gid_z < p.ne11 * p.ne12) { + uint gid_y = gl_GlobalInvocationID.y; + while (gid_y < p.ne10) { + const uint i10 = gid_y; + const uint i11 = gid_z / p.ne12; + const uint i12 = gid_z % p.ne12; + + const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + +#if defined(DATA_A_BF16) + FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); +#else + FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); +#endif +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[d_offset + i00] = D_TYPE(v); +#else + data_d[d_offset + i00] = D_TYPE(v); +#endif + gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp new file mode 100644 index 00000000..9dba437e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -0,0 +1,51 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" +#include "generic_binary_head.glsl" +#include "dequant_funcs.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = (gl_GlobalInvocationID.x)*2; + +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + if (i00 >= p.ne00) { + return; + } + + uint gid_z = gl_GlobalInvocationID.z; + while (gid_z < p.ne11 * p.ne12) { + uint gid_y = gl_GlobalInvocationID.y; + while (gid_y < p.ne10) { + const uint i10 = gid_y; + const uint i11 = gid_z / p.ne12; + const uint i12 = gid_z % p.ne12; + + const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + + const uint ib = a_offset + i00/QUANT_K; // block index + const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index + const uint iybs = i00 - i00%QUANT_K; // dst block start index + const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + + vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + v = v * dm.x + dm.y; + + data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); + data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); + + gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl new file mode 100644 index 00000000..21689893 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -0,0 +1,19 @@ +#extension GL_EXT_shader_16bit_storage : require + +#include "rte.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {A_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter +{ + uint N; + uint ne00; + uint ne20; + uint mode; + float alpha; + float limit; +} p; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl new file mode 100644 index 00000000..85cf65a9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl @@ -0,0 +1,29 @@ +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.N) { + return; + } + + const uint row = i / p.ne20; + const uint col = i - row * p.ne20; + + if (p.mode == 0) { + // Default + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + } else if (p.mode == 1) { + // Swapped + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + } else { + // Split + const uint idx = row * p.ne00 + col; + + data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp new file mode 100644 index 00000000..bdf97dbb --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp @@ -0,0 +1,66 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared float tmp[BLOCK_SIZE]; + +void main() { + const uint group_size = p.KX; + const float eps = p.param1; + + const uint tid = gl_LocalInvocationID.x; + const uint start = gl_WorkGroupID.x * group_size + tid; + const uint end = (gl_WorkGroupID.x + 1) * group_size; + + tmp[tid] = 0.0f; + + // Calculate mean + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + tmp[tid] += float(data_a[col]); + } + + // tmp up partial tmps and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float mean = tmp[0] / group_size; + barrier(); + tmp[tid] = 0.0f; + + // Calculate variance + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + const float xi = float(data_a[col]) - mean; + data_d[col] = D_TYPE(xi); + tmp[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float variance = tmp[0] / group_size; + const float scale = inversesqrt(variance + eps); + + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + data_d[col] *= D_TYPE(scale); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp new file mode 100644 index 00000000..b4dbdf31 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp new file mode 100644 index 00000000..1ec31591 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp new file mode 100644 index 00000000..1827d647 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -0,0 +1,103 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +#include "rte.glsl" +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + BDA_STORAGE_T dst_addr; + uint batch_offset; uint offset_delta; + uint IC; + uint IW; uint IH; + uint OW; uint OH; + uint KW; uint KH; + uint pelements; + uint CHW; + int s0; int s1; + int p0; int p1; + int d0; int d1; +} p; + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +const uint NUM_ITER = 512 / BLOCK_SIZE; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +#if BDA +layout (buffer_reference) buffer D_ptr {D_TYPE d;}; +#endif + +void main() { + const uint gidx = gl_GlobalInvocationID.x; + + const uint oh = gl_GlobalInvocationID.y; + const uint batch = gl_GlobalInvocationID.z / p.IC; + const uint ic = gl_GlobalInvocationID.z % p.IC; + + const uint src_base = ic * p.offset_delta + batch * p.batch_offset; + const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); + const int oh_s1 = int(oh) * p.s1; + const uint ksize = p.OW * p.KH; + + const uint base_linear_idx = gidx * NUM_ITER; + + uint current_kx = base_linear_idx / ksize; + const uint rem = base_linear_idx - (current_kx * ksize); + uint current_ky = rem / p.OW; + uint current_ix = rem % p.OW; + + A_TYPE values[NUM_ITER]; + BDA_OFFSET_T offset_dst[NUM_ITER]; + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + values[idx] = A_TYPE(0); + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint linear_idx = base_linear_idx + idx; + + if (linear_idx >= p.pelements) { + continue; + } + + const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; + const uint iih = oh_s1 + current_ky * p.d1 - p.p1; + + offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx; + + if ((iih < p.IH) && (iiw < p.IW)) { + values[idx] = data_a[src_base + iih * p.IW + iiw]; + } + + if (++current_ix == p.OW) { + current_ix = 0; + if (++current_ky == p.KH) { + current_ky = 0; + current_kx++; + } + } + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint linear_idx = base_linear_idx + idx; + + if (linear_idx >= p.pelements) { + continue; + } + +#if BDA + D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]); + dst_addr.d = D_TYPE(values[idx]); +#else + data_d[offset_dst[idx]] = D_TYPE(values[idx]); +#endif + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp new file mode 100644 index 00000000..4bf8b4ca --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -0,0 +1,125 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "rte.glsl" +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + BDA_STORAGE_T dst_addr; + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +#if BDA +layout (buffer_reference) buffer D_ptr {D_TYPE d;}; +#endif + +void main() { + const uint32_t i = gl_GlobalInvocationID.x; + + uint32_t nb10 = p.nb10; + uint32_t nb11 = p.nb11; + uint32_t nb12 = p.nb12; + uint32_t nb13 = p.nb13; + uint32_t s0 = p.s0; + uint32_t s1 = p.s1; + uint32_t s2 = p.s2; + uint32_t p0 = p.p0; + uint32_t p1 = p.p1; + uint32_t p2 = p.p2; + uint32_t d0 = p.d0; + uint32_t d1 = p.d1; + uint32_t d2 = p.d2; + uint32_t IW = p.IW; + uint32_t IH = p.IH; + uint32_t ID = p.ID; + uint32_t IC = p.IC; + uint32_t KW = p.KW; + uint32_t OH = p.OH; + uint32_t KD_KH_KW = p.KD_KH_KW; + uint32_t KH_KW = p.KH_KW; + uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW; + uint32_t N_OD_OH = p.N_OD_OH; + uint32_t OD_OH = p.OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW; + + if (i >= IC_KD_KH_KW) { + return; + } + + const uint32_t iic = i / KD_KH_KW; + const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const uint32_t ikw = i % KW; + + const uint32_t iow = gl_GlobalInvocationID.y; + for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) { + const uint32_t in_ = iz / OD_OH; + const uint32_t iod = (iz - in_*OD_OH) / OH; + const uint32_t ioh = iz % OH; + + const uint32_t iiw = iow * s0 + ikw * d0 - p0; + const uint32_t iih = ioh * s1 + ikh * d1 - p1; + const uint32_t iid = iod * s2 + ikd * d2 - p2; + + const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10; +#if BDA + D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst); + if (iih >= IH || iiw >= IW || iid >= ID) { + dst_addr.d = D_TYPE(0.0f); + } else { + dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]); + } +#else + if (iih >= IH || iiw >= IW || iid >= ID) { + data_d[offset_dst + get_doffset()] = D_TYPE(0.0f); + } else { + data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]); + } +#endif + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp new file mode 100644 index 00000000..83ef2f87 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp new file mode 100644 index 00000000..b281e855 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float val = float(data_a[i]); + data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp new file mode 100644 index 00000000..02ef1eac --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp new file mode 100644 index 00000000..4c64fd47 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp @@ -0,0 +1,48 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 0) readonly buffer A4 {vec4 data_a4[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; +layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];}; + +layout (push_constant) uniform parameter { + uint ne; + uint k_num; +} p; + +void main() { + // Each invocation handles four consecutive components + const uint idx = gl_GlobalInvocationID.x * 4; + + if (idx >= p.ne) { + return; + } + + // Check if all four components are in bounds and aligned, + // then use vector loads + if (idx + 3 < p.ne && (p.ne % 4) == 0) { + vec4 result = vec4(0.0f); + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a4[(i * p.ne + idx) / 4]; + } + + data_d4[idx / 4] = result; + } else { + [[unroll]] for (uint j = 0; j < 4; ++j) { + if (idx + j < p.ne) { + float result = 0.0f; + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a[i * p.ne + idx + j]; + } + + data_d[idx + j] = result; + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp new file mode 100644 index 00000000..9a03925c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -0,0 +1,169 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) +#define K_PER_ITER 8 +#else +#define K_PER_ITER 2 +#endif + + +uint a_offset, b_offset, d_offset, y_offset; + +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = (col%QUANT_K)/QUANT_R; // quant index + const uint iybs = col - col%QUANT_K; // y block start index + +#if K_PER_ITER == 8 +#if QUANT_R == 2 + const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); + const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); + const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); +#else + const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); +#endif +#else + // Check if the second of the pair of elements is OOB, and don't fetch B or + // accumulate it. We still fetch a pair of elements for A, which is fine for + // quantized formats since they'll be within the same block. We should + // probably skip fetching the second element for F16/F32, but as of now we + // still do. + const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); + + FLOAT_TYPE b0 = 0, b1 = 0; + b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); + if (!OOB) { + b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); + } +#endif + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + +#if K_PER_ITER == 8 + vec4 v = dequantize4(ib, iqs, a_offset); + vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); + + const vec2 dm = get_dm(ib, a_offset); + if (dm.y != 0) { // quant has min component + v = v * dm.x + dm.y; + v2 = v2 * dm.x + dm.y; + } + + // matrix multiplication + FLOAT_TYPE rowtmp = dot(bv0, v); + rowtmp += dot(bv1, v2); + + if (dm.y == 0) + rowtmp *= dm.x; + + temp[j][n] += rowtmp; +#else + const vec2 v = dequantize(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); + if (!OOB) { + temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + } +#endif + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; + + get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; + + y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); + if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + num_iters++; + } + int unroll_count = 4; + uint unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + // If the K dimension is odd, we need lastiter==true on the last iteration + // so OOB is computed correctly. Skip some unrolling to make that happen. + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + + unroll_count = 2; + unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); + i++; + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl new file mode 100644 index 00000000..450dee04 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -0,0 +1,182 @@ +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_8bit_storage : require + +#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#endif + +#ifdef MUL_MAT_ID +#define EXPERT_COUNT 8 +#endif + +#include "types.glsl" + +#ifndef MMQ +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#else +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#endif + +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +#ifdef B_TYPE_VEC2 +layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; +#endif +#ifdef B_TYPE_VEC4 +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; +#endif + +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +#include "dequant_funcs.glsl" + +layout (push_constant) uniform parameter +{ + uint ncols; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint ne11; +#else + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.y; +#else + const uint batch_idx = gl_GlobalInvocationID.y; +#endif + +#ifndef MUL_MAT_ID + uint batch_idx_a = 0; + if (batch_idx != 0) { + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + batch_idx_a = i03 * p.ne02 + i02; + } +#else + const uint expert_id = data_ids[expert_idx]; +#endif + + a_offset = +#ifdef MUL_MAT_ID + expert_id * p.batch_stride_a; +#else + batch_idx_a * p.batch_stride_a; +#endif + b_offset = +#ifdef MUL_MAT_ID + (expert_idx % p.ne11) * p.stride_b; +#else + batch_idx * p.batch_stride_b; +#endif + d_offset = +#ifdef MUL_MAT_ID + expert_idx * p.stride_d; +#else + batch_idx * p.batch_stride_d; +#endif +} + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (constant_id = 1) const uint NUM_ROWS = 1; +layout (constant_id = 2) const uint NUM_COLS = 1; + +#ifdef USE_SUBGROUP_ADD_NO_SHMEM +void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = subgroupAdd(temp[j][n]); + } + } + + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); + } + } + } +} +#else +shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; + +void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + // subgroupAdd is probably faster on devices that support it, + // particularly when the workgroup has more than one subgroup +#if USE_SUBGROUP_ADD + // sum up partial sums within a subgroup + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = subgroupAdd(temp[j][n]); + } + } + + // Go through shared memory to sum partials across subgroups + if (gl_SubgroupInvocationID == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][gl_SubgroupID] = temp[j][n]; + } + } + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = FLOAT_TYPE(0); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + temp[j][n] += tmpsh[j][n][s]; + } + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); + } + } + } +#else + // sum up partial sums and write back result + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] = temp[j][n]; + } + } + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] += tmpsh[j][n][tid + s]; + } + } + } + barrier(); + } + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); + } + } + } +#endif +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp new file mode 100644 index 00000000..4cb29238 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -0,0 +1,82 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint16_t[4] scales = data_a[ibi].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1)); + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1)); + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1); + + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp new file mode 100644 index 00000000..0b74b332 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -0,0 +1,79 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint qh = data_a[ibi].qh[ib32]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp new file mode 100644 index 00000000..e424af12 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + const uint qh = data_a[ibi].qh[ib32]; + const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147 + const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy; + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint8_t sign = sign16[l]; + const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300); + const uvec2 grid = iq2s_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp new file mode 100644 index 00000000..0cd906db --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[2 * itid + l]; + const uint sign = qs >> 9; + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x)); + const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp new file mode 100644 index 00000000..71bd72d1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[4 * ib32 + 2], + data_a_packed16[ibi].qs[4 * ib32 + 3])); + const float db = d * 0.25 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x)); + const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp new file mode 100644 index 00000000..a4b9ab1f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const float dscale = d * (1 + 2 * scale); + const uint qh = data_a[ibi].qh[ib32]; + FLOAT_TYPE sum[NUM_COLS]; + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + sum[j] = 0.0; + } + [[unroll]] for (uint l = 0; l < 4; ++l) { + const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147 + const uint sign = data_a[ibi].signs[4 * ib32 + l]; + const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)])); + const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + sum[j] = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + sum[j])))))))); + } + } + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + temp[j][n] = fma(dscale, sum[j], temp[j][n]); + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp new file mode 100644 index 00000000..40849c69 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -0,0 +1,88 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32], + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32 + 1])); + const float db = d * 0.5 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs0 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l]; + const uint qs1 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l + 1]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq3xxs_grid[qs0])); + const vec4 grid1 = vec4(unpack8(iq3xxs_grid[qs1])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp new file mode 100644 index 00000000..638878d9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -0,0 +1,122 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#define BLOCK_SIZE 32 +#define FLOAT_TYPE float + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; + +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout (push_constant) uniform parameter +{ + uint ncols_x; + uint nrows_x; + uint row_stride_x; + uint channel_stride_x; + uint channel_stride_y; + uint channel_x_divisor; + uint ne12; + uint b_offset; + uint d_offset; + uint nb03; + uint nb13; + uint nb23; +} p; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint row_x = gl_GlobalInvocationID.y; + const uint channel = gl_GlobalInvocationID.z; + const uint i3 = gl_WorkGroupID.x; + const uint channel_x = channel / p.channel_x_divisor; + const uint channel_y = channel % p.ne12; + + const uint nrows_y = p.ncols_x; + const uint nrows_dst = p.nrows_x; + const uint row_dst = row_x; + + const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst; + + FLOAT_TYPE temp = 0.0f; + + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0; + + for (uint col_x0 = 0; col_x0 < p.ncols_x;) { + + // Unroll 2x and do vec4 loads if aligned + const uint unroll_count = 2; + if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) { + [[unroll]] for (uint i = 0; i < unroll_count; ++i) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } + // do vec4 loads if aligned + } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + if (col_x >= p.ncols_x) { + break; + } + + const uint row_y = col_x; + + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; + + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp); + col_x0 += BLOCK_SIZE; + } + } + + tmp[tid] = temp; + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + if (tid == 0) { + dst[idst] = tmp[0]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp new file mode 100644 index 00000000..7aa070ee --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp @@ -0,0 +1,154 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif + +#define FLOAT_TYPE float + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; + +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout(constant_id = 0) const int BLOCK_SIZE = 32; +// gqa_ratio is in the range [1,8] +layout(constant_id = 1) const uint gqa_ratio = 1; + +layout (push_constant) uniform parameter +{ + uint ncols_x; + uint nrows_x; + uint nchannels_x; + uint nchannels_y; + uint b_offset; + uint d_offset; +} p; + +#if !USE_SUBGROUP_ADD +shared FLOAT_TYPE tmp[8][BLOCK_SIZE]; +#endif + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint row_x = gl_GlobalInvocationID.y; + + uint channel, channel_x; + + // When gqa_ratio > 1, each invocation does multiple rows. + // The row in the A matrix is starting from channel / gqa_ratio and the + // rows in the B matrix are [channel, channel+gqa_ratio). + // When gpa_ratio is 1, each invocation does one row. + if (gqa_ratio > 1) { + channel_x = gl_GlobalInvocationID.z; + channel = channel_x * gqa_ratio; + } else { + channel = gl_GlobalInvocationID.z; + channel_x = channel / (p.nchannels_y / p.nchannels_x);; + } + + const uint nrows_y = p.ncols_x; + const uint nrows_dst = p.nrows_x; + const uint row_dst = row_x; + + FLOAT_TYPE temp[8]; + [[unroll]] for (uint i = 0; i < 8; ++i) { + temp[i] = FLOAT_TYPE(0.0f); + } + + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0; + + for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { + + // Use vec4 loads if aligned + if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + + uint col_x = col_x0 + 4*tid; + const uint row_y = col_x; + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const vec4 av4 = vec4(data_a_v4[ix / 4]); + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + vec4 bv4 = data_b_v4[iy / 4]; + temp[c] += dot(av4, bv4); + } + + col_x0 += 3*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + const uint row_y = col_x; + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]); + } + } + } + +#if USE_SUBGROUP_ADD + // reduce vec4 at a time + vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]); + t = subgroupAdd(t); + temp[0] = t[0]; + temp[1] = t[1]; + temp[2] = t[2]; + temp[3] = t[3]; + if (gqa_ratio > 4) { + t = vec4(temp[4], temp[5], temp[6], temp[7]); + t = subgroupAdd(t); + temp[4] = t[0]; + temp[5] = t[1]; + temp[6] = t[2]; + temp[7] = t[3]; + } +#else + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + tmp[c][tid] = temp[c]; + } + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] += tmp[c][tid + s]; + tmp[c][tid] = temp[c]; + } + } + barrier(); + } + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] = tmp[c][tid]; + } +#endif + + if (tid == 0) { + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // dst is not transposed and not permuted + const uint idst = (channel + c)*nrows_dst + row_dst; + dst[idst] = temp[c]; + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp new file mode 100644 index 00000000..03ed25d3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -0,0 +1,130 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + } + barrier(); + + if (i >= num_blocks_per_row) + continue; + } else { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + barrier(); + } + + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im], + fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im], + fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im], + fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im], + fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im], + fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im], + fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp new file mode 100644 index 00000000..528f224d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -0,0 +1,132 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); + + // 0, 1, 16, 17 + uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); + qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + if (all_threads) { + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + const uint itid8 = itid%8; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_im4 = v_im*4; + const uint v_in = itid - 8*v_im; // 0...7 + + const uint32_t m = 0x01010101 << (4 * v_im); + uint32_t hm_m[4]; + [[unroll]] for (uint j = 0; j < 4; ++j) + hm_m[j] = m << j; + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint s_shift = v_im4 + 2*(itid8/4); + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp new file mode 100644 index 00000000..21d07d2e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -0,0 +1,136 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; + + const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); + const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); + const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); + const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_lo4.x; + const FLOAT_TYPE q4_1 = qs0_lo4.y; + const FLOAT_TYPE q4_2 = qs0_lo4.z; + const FLOAT_TYPE q4_3 = qs0_lo4.w; + const FLOAT_TYPE q4_4 = qs0_hi4.x; + const FLOAT_TYPE q4_5 = qs0_hi4.y; + const FLOAT_TYPE q4_6 = qs0_hi4.z; + const FLOAT_TYPE q4_7 = qs0_hi4.w; + const FLOAT_TYPE q4_8 = qs64_lo4.x; + const FLOAT_TYPE q4_9 = qs64_lo4.y; + const FLOAT_TYPE q4_10 = qs64_lo4.z; + const FLOAT_TYPE q4_11 = qs64_lo4.w; + const FLOAT_TYPE q4_12 = qs64_hi4.x; + const FLOAT_TYPE q4_13 = qs64_hi4.y; + const FLOAT_TYPE q4_14 = qs64_hi4.z; + const FLOAT_TYPE q4_15 = qs64_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); + vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); + vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); + vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 + const uint n = 4; + + const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = n * (2 * ir + v_in); // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp new file mode 100644 index 00000000..9e46c89a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -0,0 +1,167 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + + uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; + uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; + uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); + const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + + const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); + const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); + const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); + const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_16_lo4.x; + const FLOAT_TYPE q4_1 = qs0_16_lo4.y; + const FLOAT_TYPE q4_2 = qs0_16_lo4.z; + const FLOAT_TYPE q4_3 = qs0_16_lo4.w; + const FLOAT_TYPE q4_4 = qs0_16_hi4.x; + const FLOAT_TYPE q4_5 = qs0_16_hi4.y; + const FLOAT_TYPE q4_6 = qs0_16_hi4.z; + const FLOAT_TYPE q4_7 = qs0_16_hi4.w; + const FLOAT_TYPE q4_8 = qs64_80_lo4.x; + const FLOAT_TYPE q4_9 = qs64_80_lo4.y; + const FLOAT_TYPE q4_10 = qs64_80_lo4.z; + const FLOAT_TYPE q4_11 = qs64_80_lo4.w; + const FLOAT_TYPE q4_12 = qs64_80_hi4.x; + const FLOAT_TYPE q4_13 = qs64_80_hi4.y; + const FLOAT_TYPE q4_14 = qs64_80_hi4.z; + const FLOAT_TYPE q4_15 = qs64_80_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); + vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); + vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); + vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); + vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); + vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); + vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); + vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 + + const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = 4*ir + 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp new file mode 100644 index 00000000..d7a7f642 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -0,0 +1,130 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) + sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + const uint32_t qh4_u32 = (qh_u32 & 0x30303030); + const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + const vec4 q0 = vec4(unpack8(q0_u32)) - 32; + const vec4 q1 = vec4(unpack8(q1_u32)) - 32; + const vec4 q2 = vec4(unpack8(q2_u32)) - 32; + const vec4 q3 = vec4(unpack8(q3_u32)) - 32; + + if (all_threads) { + sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]); + vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]); + vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); + vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); + + FLOAT_TYPE sum[4] = {0, 0, 0, 0}; + [[unroll]] for (uint l = 0; l < 4; ++l) { + sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]); + sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]); + sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); + sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); + } + temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]); + } + } +} + +void compute_outputs(const uint first_row, const uint num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 + + const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 + const uint is = v_in / 4; + + const uint ql_offset = 64*v_im + l0; + const uint qh_offset = 32*v_im + l0; + const uint s_offset = 8*v_im + is; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp new file mode 100644 index 00000000..64293f6e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -0,0 +1,140 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_integer_dot_product : require + +#define MMQ +#define B_TYPE block_q8_1_x4 + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#define K_PER_ITER 8 + +#include "mul_mmq_funcs.glsl" + +uint a_offset, b_offset, d_offset; + +int32_t cache_b_qs[2]; +vec2 cache_b_ds; + +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + tid*K_PER_ITER; + + // Preload data_b block + const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset; + const uint b_qs_idx = tid % 4; + const uint b_block_idx_outer = b_block_idx / 4; + const uint b_block_idx_inner = b_block_idx % 4; + cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]); + +#if QUANT_R == 2 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4]; +#else + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1]; +#endif + + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint a_block_idx = (ibi + col)/QUANT_K + a_offset; + ibi += p.ncols; + + int32_t q_sum = 0; +#if QUANT_R == 2 + const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx); + q_sum += dotPacked4x8EXT(data_a_qs.x, + cache_b_qs[0]); + q_sum += dotPacked4x8EXT(data_a_qs.y, + cache_b_qs[1]); +#else + int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[0]); + data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[1]); +#endif + +#if QUANT_AUXF == 1 + temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4); +#else + temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4); +#endif + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; + + get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; + b_offset /= QUANT_K_Q8_1; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = FLOAT_TYPE(0.0f); + } + } + + uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); + if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + num_iters++; + } + int unroll_count = 4; + uint unrolled_iters = num_iters & ~(unroll_count - 1); + + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + + unroll_count = 2; + unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp new file mode 100644 index 00000000..85400ac5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -0,0 +1,481 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif +#if defined(DATA_A_IQ1_M) +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#if defined(DATA_A_BF16) && defined(COOPMAT) +#extension GL_EXT_bfloat16 : enable +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS) +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.glsl" + +#ifndef LOAD_VEC_A +#define LOAD_VEC_A 1 +#endif +#ifndef LOAD_VEC_B +#define LOAD_VEC_B 1 +#endif + +// Load 2 values at once without affecting index calculations through LOAD_VEC +#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED) +#define LOAD_VEC_BATCH_A 2 +#else +#define LOAD_VEC_BATCH_A 1 +#endif +#if !defined(ALIGNED) +#define LOAD_VEC_BATCH_B 2 +#else +#define LOAD_VEC_BATCH_B 1 +#endif + +#if !defined(TO_FLOAT_TYPE) +#define TO_FLOAT_TYPE FLOAT_TYPE +#endif + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK / 2 + 4) +#else +#define SHMEM_STRIDE (BK / 2 + 1) +#endif + +shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[BN]; +uint _ne1; + +#ifdef MUL_MAT_ID_USE_SUBGROUPS +shared uvec4 ballots_sh[NUM_WARPS]; + +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} +#endif // MUL_MAT_ID_USE_SUBGROUPS +#endif // MUL_MAT_ID + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +#include "mul_mm_funcs.glsl" + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); + + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK; + +#ifdef MUL_MAT_ID +#ifdef MUL_MAT_ID_USE_SUBGROUPS + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); + } +#else + _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + if (_ne1 >= ic * BN) { + row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1); + } + _ne1++; + } + } + } + + barrier(); +#endif + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; +#ifdef MUL_MAT_ID + uint pos_b = 0; +#else + uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; +#endif + +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; + FLOAT_TYPE_VEC2 cache_b[TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { + load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k); + } + [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { +#if !defined(MUL_MAT_ID) + load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k); +#else + load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k); +#endif + } + + barrier(); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + +#ifdef COOPMAT + [[unroll]] for (uint i = 0; i < BK; i += TK) { + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + // Load from shared into cache + coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } +#else + [[unroll]] for (uint i = 0; i < BK / 2; i++) { + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint j = 0; j < TM; j++) { + cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; + } + } + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint j = 0; j < TN; j++) { + cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; + } + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx])); + } + } + } + } + } +#endif + + barrier(); + } + +#if defined(ACC_TYPE_MAX) +#ifdef COOPMAT + [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) { + [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) { + sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + } + } +#else + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + } +#endif +#endif + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i - ic * BN]; + + if (dr + cm_row * TM + store_r < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i - ic * BN]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + if (dr_warp + cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp new file mode 100644 index 00000000..2e04baa4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -0,0 +1,609 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable +#ifdef DATA_A_BF16 +#extension GL_EXT_bfloat16 : enable +#endif + +#include "types.glsl" +#include "utils.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#define IS_MUL_MM2 1 + +layout (constant_id = 0) const uint BLOCK_SIZE = 256; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant + +layout (constant_id = 4) const bool enable_smaller_matrices = false; +const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN; +const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN; + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif + // N dimension for the B matrix can be >= p.N + uint padded_N; +} p; + + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#if QUANT_K > 1 +#define DECODEFUNCA , dequantFuncA + +#include "dequant_funcs_cm2.glsl" + +#else +#define DECODEFUNCA +#endif + +#if !defined(fetch_scales) +#define fetch_scales(a, b, c, d, e, f) +#endif +#if !defined(store_scales) +#define store_scales(a) +#endif + +#if defined(DATA_A_BF16) +#define MAT_TYPE bfloat16_t +#else +#define MAT_TYPE FLOAT_TYPE +#endif + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; + +shared u16vec4 row_ids[BN]; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { + B_TYPE b[]; +}; + +uint _ne1; +layout (constant_id = 5) const uint subgroup_size = 32; +shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size]; + +B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint row_i = blockCoords[0]; + + if (row_i >= _ne1) { + return B_TYPE(0.0); + } + + const u16vec4 row_idx = row_ids[row_i & (BN - 1)]; + B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; + + return ret; +} + +D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) +{ + uint dr = ir * BM + r; + uint dc = ic * BN + c; + + if (dr < p.M && dc < _ne1) { + uint row_i = c; + const u16vec4 row_idx = row_ids[row_i]; + data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; + } + return elem; +} + +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} +#endif + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint tid = gl_LocalInvocationIndex; + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + +#ifdef MUL_MAT_ID + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); + } + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + uint start_k = 0; + const uint end_k = p.K; +#else + uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + +#ifdef MUL_MAT_ID + uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; + uint pos_b = 0; +#else + uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; + uint pos_b = batch_idx * p.batch_stride_b; + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + + uint stride_a = p.stride_a / QUANT_K; + uint stride_b = p.stride_b; + + // Hint to the compiler that values are aligned (want 16B alignment). + // Quants are always block-aligned, no alignment needed. +#if ALIGNED +#if QUANT_K == 1 + stride_a &= ~7; +#endif + stride_b &= ~7; +#endif + + // Create layouts for both clamped and unclamped accesses + tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + +#if QUANT_K > 1 + tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); + tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); +#endif + + // Use end_k rather than p.K as the dimension because that's what + // we need to bound check against when using split_k. + // Bounds check B against padded_N, but bounds check D against N. + tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); + tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); + tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); + tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k); + + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if !defined(MUL_MAT_ID) + + const uint START_ALIGN_K = 256; + // For Qi_K (block size 256), unroll whole 256 element tiles. + // For legacy quants (block size 32), unroll 8x. + const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8); + const uint unroll_count = UNROLL_K / BK; + + // Detect a fast path where all loads are entirely in bounds and no clamping is required + if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 && +#if QUANT_K == 1 + (stride_a % 8) == 0 && +#endif + (stride_b % 8) == 0) { + // Hint to the compiler that values are aligned (want 16B alignment) + start_k &= ~(START_ALIGN_K-1); + stride_b &= ~7; +#if QUANT_K == 1 + stride_a &= ~7; +#endif + + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + uint k_iters = (end_k - start_k) / UNROLL_K; + uint block_k = start_k; + + // fetch scale values for a tile of quants. These will be copied into shared memory. + // The fetches and stores are pipelined to hide the latency. + fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true); + + if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); + return; + } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); + return; + } else { + coopmat sum = coopmat(0.0); + + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); + return; + } + } else +#endif // !defined(MUL_MAT_ID) + { + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + + tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1); + + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); + + uint k_iters = (end_k - start_k + BK - 1) / BK; + + fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false); + store_scales(tid); + +#ifdef MUL_MAT_ID + if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) { + coopmat sum; + sum = coopmat(0.0); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + return; + } + if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) { + coopmat sum; + sum = coopmat(0.0); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + return; + } +#endif + coopmat sum; + sum = coopmat(0.0); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + +#ifdef MUL_MAT_ID + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); +#else + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); +#endif + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl new file mode 100644 index 00000000..0ebfbd64 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -0,0 +1,556 @@ +void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) { +#if defined(DATA_A_F32) || defined(DATA_A_F16) +#if LOAD_VEC_A == 8 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]); + buf_a[buf_idx ] = aa[0].xy; + buf_a[buf_idx + 1] = aa[0].zw; + buf_a[buf_idx + 2] = aa[1].xy; + buf_a[buf_idx + 3] = aa[1].zw; +#elif LOAD_VEC_A == 4 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; +#else // LOAD_VEC_BATCH_A == 2 + const uint idx = pos_a + col * p.stride_a + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_m < p.M && block + row * 2 + 1 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], + data_a[idx + 1]); + } else if (idx_m < p.M && block + row * 2 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f); + } else { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +#elif defined(DATA_A_BF16) +#if LOAD_VEC_A == 4 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx])); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; +#else // LOAD_VEC_BATCH_A == 2 + const uint idx = pos_a + col * p.stride_a + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_m < p.M && block + row * 2 + 1 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), + TO_FLOAT_TYPE(data_a[idx + 1])); + } else if (idx_m < p.M && block + row * 2 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + } else { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +#elif defined(DATA_A_Q4_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; + const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy); + buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); +#elif defined(DATA_A_Q4_1) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); + buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); + buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy); + buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); +#elif defined(DATA_A_Q5_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); +#elif defined(DATA_A_Q5_1) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint uint_qh = data_a_packed16[ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); +#elif defined(DATA_A_Q8_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; + const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); +#elif defined(DATA_A_Q2_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const uint scales = data_a[ib].scales[scalesi]; + const vec2 d = vec2(data_a[ib].d); + + const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); +#elif defined(DATA_A_Q3_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); + const float dl = float(data_a[ib].d) * float(us - 32); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)), + dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); +#elif defined(DATA_A_Q4_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m), + fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +#elif defined(DATA_A_Q5_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), + fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); +#elif defined(DATA_A_Q6_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32), + dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ1_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; + + const float d = float(data_a[ib].d); + const uint qh = data_a[ib].qh[ib32]; + const uint qs = data_a[ib].qs[ib8]; + const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + + [[unroll]] for (int k = 0; k < 4; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + } +#elif defined(DATA_A_IQ1_M) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; + const uint ib16 = ib8 / 2; + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + const uint sc = scales[ib8 / 8]; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + + [[unroll]] for (int k = 0; k < 4; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + } +#elif defined(DATA_A_IQ2_XXS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[8 * ib32 + ib8]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[8*ib32 + 4], + data_a[ib].qs[8*ib32 + 5], + data_a[ib].qs[8*ib32 + 6], + data_a[ib].qs[8*ib32 + 7] + )); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); + const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); +#elif defined(DATA_A_IQ2_XS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 + + const float d = float(data_a[ib].d); + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uint qs = data_a[ib].qs[4 * ib32 + ib8]; + const uint sign7 = qs >> 9; + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); +#elif defined(DATA_A_IQ2_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 + + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; + + const float d = float(data_a[ib].d); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); +#elif defined(DATA_A_IQ3_XXS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 + const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[is+0], + data_a[ib].qs[is+1], + data_a[ib].qs[is+2], + data_a[ib].qs[is+3] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); +#elif defined(DATA_A_IQ3_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 + const uint iqh = iqs / 8; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint qh = data_a[ib].qh[iqh]; + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); + const uint scale = data_a[ib].scales[iqs / 16]; + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); +#elif defined(DATA_A_IQ4_XS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint iq = 16 * ib32 + 2 * (idx % 8); + + const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 8) >> 1; + u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float d = float(data_a[ib].d); + const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); +#elif defined(DATA_A_IQ4_NL) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + + buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF], + kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); + buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], + kvalues_iq4nl[vui >> 12]); +#elif defined(DATA_A_MXFP4) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = (idx & 0x07) * 2; + + const float d = e8m0_to_fp32(data_a[ib].e); + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); +#endif +} + +#if !defined(MUL_MAT_ID) +void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) { +#if LOAD_VEC_B == 8 + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; +#elif LOAD_VEC_B == 4 + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; +#if defined(DATA_B_BF16) + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); +#else + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; +#else // LOAD_VEC_BATCH_B == 2 + const uint idx = pos_b + col * p.stride_b + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_n < p.N && block + row * 2 + 1 < end_k) { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); + } else if (idx_n < p.N && block + row * 2 < end_k) { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + } else { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +} +#else +void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) { +#if LOAD_VEC_B == 8 + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; +#elif LOAD_VEC_B == 4 + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; +#if defined(DATA_B_BF16) + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); +#else + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; +#else // LOAD_VEC_BATCH_B == 2 + const uint row_i = ic * BN + col; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (row_i < _ne1 && block + row * 2 + 1 < end_k) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); + } else if (row_i < _ne1 && block + row * 2 < end_k) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + } else { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp new file mode 100644 index 00000000..b5d761c0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -0,0 +1,449 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#extension GL_EXT_integer_dot_product : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif +layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +// layout (constant_id = 3) const uint BK = 32; +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#define BK 32 + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK / 4 + 4) +#else +#define SHMEM_STRIDE (BK / 4 + 1) +#endif + +shared int32_t buf_a_qs[BM * SHMEM_STRIDE]; + +#ifndef COOPMAT +#if QUANT_AUXF == 1 +shared FLOAT_TYPE buf_a_dm[BM]; +#else +shared FLOAT_TYPE_VEC2 buf_a_dm[BM]; +#endif +#endif + +shared int32_t buf_b_qs[BN * SHMEM_STRIDE]; +#ifndef COOPMAT +shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; +#endif + +#define LOAD_VEC_A (4 * QUANT_R) +#define LOAD_VEC_B 16 + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[4096]; +#endif // MUL_MAT_ID + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +#include "mul_mmq_funcs.glsl" + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + + const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK; + const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK; + +#ifdef MUL_MAT_ID + uint _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + row_ids[_ne1] = u16vec2(ii0, ii1); + _ne1++; + } + } + } + + barrier(); + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a_ib = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / BK; +#ifdef MUL_MAT_ID + uint pos_b_ib = 0; +#else + uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; +#endif + +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat cm_result; + + coopmat factors[cms_per_row * cms_per_col]; + + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + int32_t cache_a_qs[WMITER * TM * BK / 4]; + + int32_t cache_b_qs[TN * BK / 4]; + + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + +#if QUANT_AUXF == 1 + FLOAT_TYPE cache_a_dm[WMITER * TM]; +#else + FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; +#endif + + FLOAT_TYPE_VEC2 cache_b_ds[TN]; + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { + const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK; + const uint iqs = loadr_a; + const uint buf_ib = loadc_a + l; + + if (iqs == 0) { +#if QUANT_AUXF == 1 + buf_a_dm[buf_ib] = get_d(ib); +#else + buf_a_dm[buf_ib] = get_dm(ib); +#endif + } +#if QUANT_R == 1 + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs); +#else + const i32vec2 vals = repack(ib, iqs); + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x; + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y; +#endif + } + [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) { +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; + const uint ib = idx / 8; + const uint iqs = idx & 0x7; +#else + const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; + + const uint iqs = loadr_b; +#endif + + const uint buf_ib = loadc_b + l; + + if (iqs == 0) { + buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + } + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w; + } + + barrier(); + + pos_a_ib += 1; + pos_b_ib += 1; + +#ifdef COOPMAT + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint ib_a = warp_r * WM + cm_row * TM; + // Load from shared into cache + coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { + cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; + } + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint ib_b = warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { + cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; + } + + cm_result = coopmat(0); + cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); + } + + coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); + } + } +#else + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; + cache_a_dm[wsir * TM + cr] = buf_a_dm[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + } + + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; + cache_b_ds[cc] = buf_b_ds[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint cache_a_idx = wsir * TM + cr; + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + int32_t q_sum = 0; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k], + cache_b_qs[cc * (BK / 4) + idx_k]); + } + + sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1); + } + } + } + } +#endif + + barrier(); + } + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < BN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; + + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl new file mode 100644 index 00000000..fe71eb13 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -0,0 +1,105 @@ +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#include "types.glsl" + +// Each iqs value maps to a 32-bit integer + +#if defined(DATA_A_Q4_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q4_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q5_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q5_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q8_0) +int32_t repack(uint ib, uint iqs) { + // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4 + return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1])); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * da * dsb.x); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +FLOAT_TYPE get_d(uint ib) { + return FLOAT_TYPE(data_a[ib].d); +} +#endif + +#if defined(DATA_A_MXFP4) +FLOAT_TYPE get_d(uint ib) { + return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp new file mode 100644 index 00000000..1e8f694a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -0,0 +1,111 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_nonuniform_qualifier : enable +#extension GL_EXT_control_flow_attributes : require +#if ADD_RMS +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#include "rte.glsl" +#include "types.glsl" +#include "utils.glsl" + +layout (push_constant) uniform parameter2 +{ + // shape for dst + uint ne20; uint ne21; uint ne22; uint ne23; + + // strides for srcs+dst + uint nb[12][4]; + + uint rms_partials; +} p; + +// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498 +// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[]; +// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[]; +layout (binding = 0) buffer A {A_TYPE data_a[];} a[]; +layout (binding = 0) buffer D {D_TYPE data_d[];} d[]; + +layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[]; + +layout(constant_id = 0) const uint num_srcs = 2; + +uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0]; +} + +uint dst_idx(uint i00, uint i01, uint i02, uint i03) { + uint nb20 = p.nb[num_srcs][0]; + uint nb21 = p.nb[num_srcs][1]; + uint nb22 = p.nb[num_srcs][2]; + uint nb23 = p.nb[num_srcs][3]; + return i03*nb23 + i02*nb22 + i01*nb21 + i00*nb20; +} + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +#if ADD_RMS +// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant +shared FLOAT_TYPE sumsh[num_threads]; +#endif + +void main() { + uint idx = get_idx(); + uint orig_idx = idx; + + uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23; + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + FLOAT_TYPE sum_sq = 0; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23); + + FLOAT_TYPE sum = FLOAT_TYPE(0); + [[unroll]] for (uint s = 0; s < num_srcs; ++s) { + sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]); + } + sum_sq += sum*sum; + d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); + + idx += num_threads; + } + +#if ADD_RMS + if (p.rms_partials != 0) { + // reduce the sum within each subgroup, then across subgroups + const uint NumSubgroups = num_threads / gl_SubgroupSize; + sum_sq = subgroupAdd(sum_sq); + if (gl_SubgroupInvocationID == 0) { + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { + if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) { + sum_sq += sumsh[gl_SubgroupID + s]; + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + } + + if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { + partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; + } + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp new file mode 100644 index 00000000..cc3ea0b7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared vec2 sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = vec2(0.0f, 0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const float xi = float(data_a[row*p.KX + col]); + sum[tid].x += xi; + sum[tid].y += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const float mean = sum[0].x / p.KX; + const float var = sum[0].y / p.KX - mean * mean; + const float inv_std = inversesqrt(var + p.param1); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp new file mode 100644 index 00000000..1f05f922 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) buffer X {A_TYPE x[];}; +layout (binding = 1) readonly buffer G {A_TYPE grad[];}; +layout (binding = 2) buffer GM {A_TYPE gradm[];}; +layout (binding = 3) buffer GV {A_TYPE gradv[];}; +layout (binding = 4) readonly buffer P {float params[7];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float alpha = params[0]; + const float beta1 = params[1]; + const float beta2 = params[2]; + const float eps = params[3]; + const float wd = params[4]; + const float beta1h = params[5]; + const float beta2h = params[6]; + + const float gi = grad[i]; + const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1); + const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2); + + gradm[i] = gmi; + gradv[i] = gvi; + + const float mh = gmi*beta1h; + const float vh = sqrt(gvi*beta2h) + eps; + + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp new file mode 100644 index 00000000..1251f9cc --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) buffer X {A_TYPE data_x[];}; +layout (binding = 1) readonly buffer G {A_TYPE data_grad[];}; +layout (binding = 2) readonly buffer P {float data_params[2];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float alpha = data_params[0]; + const float keep = 1.f - alpha * data_params[1]; + + data_x[i] = data_x[i] * keep - alpha * data_grad[i]; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp new file mode 100644 index 00000000..f3c81768 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + + uint lp0; uint rp0; + uint lp1; uint rp1; + uint lp2; uint rp2; + uint lp3; uint rp3; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne12*p.ne11*p.ne10); + const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; + const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10); + const uint i2_offset = i2*p.ne11*p.ne10; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; + + const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00; + const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; + + const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 && + i1 >= p.lp1 && i1 < p.ne11 - p.rp1 && + i2 >= p.lp2 && i2 < p.ne12 - p.rp2 && + i3 >= p.lp3 && i3 < p.ne13 - p.rp3; + + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp new file mode 100644 index 00000000..d9d7166e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp @@ -0,0 +1,74 @@ +#version 450 + +#include "types.glsl" + +#extension GL_EXT_shader_16bit_storage : require + +layout(push_constant) uniform parameter { + uint IW; uint IH; + uint OW; uint OH; + uint OC; + uint pelements; + uint op; + int k0; int k1; + int s0; int s1; + int p0; int p1; +} p; + +#define BLOCK_SIZE 512 +#define FLT_MAX 3.402823466e+38F +#define OP_POOL_MAX 0u +#define OP_POOL_AVG 1u + +layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout(binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.pelements) { + return; + } + + const uint O_HW = p.OW * p.OH; + + const uint nc = idx / O_HW; + const uint cur_oh = (idx % O_HW) / p.OW; + const uint cur_ow = (idx % O_HW) % p.OW; + + const int start_h = int(cur_oh) * p.s0 - p.p0; + const uint bh = max(start_h, 0); + const uint eh = min(start_h + p.k0, p.IH); + + const int start_w = int(cur_ow) * p.s1 - p.p1; + const uint bw = max(start_w, 0); + const uint ew = min(start_w + p.k1, p.IW); + + const float scale = 1.0 / float(p.k0 * p.k1); + float res; + + if (p.op == OP_POOL_AVG) { + res = 0.0; + } else if (p.op == OP_POOL_MAX) { + res = -FLT_MAX; + } else { + return; + } + + #pragma unroll + for (uint i = bh; i < eh; i++) { + #pragma unroll + for (uint j = bw; j < ew; j++) { + const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]); + + if (p.op == OP_POOL_AVG) { + res += cur * scale; + } else if (p.op == OP_POOL_MAX) { + res = max(res, cur); + } + } + } + + data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp new file mode 100644 index 00000000..0f3c6ca8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -0,0 +1,127 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_16bit_storage : require + +#ifdef USE_SUBGROUPS +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_clustered : require + +#define INVOCATION_ID gl_SubgroupInvocationID.x +#else +#define INVOCATION_ID gl_LocalInvocationID.x +#endif + +layout (push_constant) uniform parameter +{ + uint ne; +} p; + +#include "types.glsl" + +layout(constant_id = 0) const uint GROUP_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {vec4 data_a[];}; +#ifndef QBLOCK_X4 +layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];}; +#else +layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];}; +#endif + +#ifndef USE_SUBGROUPS +shared float shmem[GROUP_SIZE]; +#endif + +void quantize() { + const uint wgid = gl_WorkGroupID.x; + const uint tid = INVOCATION_ID; + + // Each thread handles a vec4, so 8 threads handle a block + const uint blocks_per_group = GROUP_SIZE / 8; + + const uint block_in_wg = tid / 8; + + const uint ib = wgid * blocks_per_group + block_in_wg; + const uint iqs = tid % 8; + +#ifndef QBLOCK_X4 + if (ib >= gl_NumWorkGroups.x * blocks_per_group) { + return; + } +#else + const uint ibx4_outer = ib / 4; + const uint ibx4_inner = ib % 4; + + const uint required_x4_blocks = (p.ne + 127) / 128; + if (ibx4_outer >= required_x4_blocks) { + return; + } +#endif + + const uint a_idx = ib * 8 + iqs; + + vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f); + const vec4 abs_vals = abs(vals); + + // Find absolute max for each block + const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); +#ifndef USE_SUBGROUPS + shmem[tid] = thread_max; + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] = max(shmem[tid], shmem[tid + s]); + } + barrier(); + } + + const float amax = shmem[block_in_wg * 8]; +#else + const float amax = subgroupClusteredMax(thread_max, 8); +#endif + + const float d = amax / 127.0; + const float d_inv = d != 0.0 ? 1.0 / d : 0.0; + vals = round(vals * d_inv); + +#ifndef QBLOCK_X4 + data_b[ib].qs[iqs] = pack32(i8vec4(round(vals))); +#else + data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals))); +#endif + +#ifndef USE_SUBGROUPS + barrier(); +#endif + + // Calculate the sum for each block + const float thread_sum = vals.x + vals.y + vals.z + vals.w; +#ifndef USE_SUBGROUPS + shmem[tid] = thread_sum; + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] += shmem[tid + s]; + } + barrier(); + } +#else + const float sum = subgroupClusteredAdd(thread_sum, 8); +#endif + if (iqs == 0) { +#ifndef USE_SUBGROUPS + const float sum = shmem[tid]; +#endif + +#ifndef QBLOCK_X4 + data_b[ib].ds = f16vec2(vec2(d, sum * d)); +#else + data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d)); +#endif + } +} + +void main() { + quantize(); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp new file mode 100644 index 00000000..86be2669 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.glsl" + +float op(float a, float b) { + return max(a, 0.0f) * b; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp new file mode 100644 index 00000000..5725cef2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + data_d[i] = D_TYPE(max(float(data_a[i]), 0)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp new file mode 100644 index 00000000..8f4b9a86 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +uint src0_idx_mod(uint idx) { + const uint i13 = idx / (p.ne12*p.ne11*p.ne10); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00; +} + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp new file mode 100644 index 00000000..87df7829 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + // Destination multi-index (inlined dst_idx) + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; + + // Accumulate from sources + A_TYPE acc = A_TYPE(0); + for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) { + for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) { + for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) { + for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) { + acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00]; + } + } + } + } + + data_d[get_doffset() + d_idx] = D_TYPE(acc); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp new file mode 100644 index 00000000..d5b211ff --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -0,0 +1,105 @@ +#version 450 + +#include "generic_binary_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout (constant_id = 1) const bool do_multiply = false; + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void rms_norm(uint num_iters) { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = gl_WorkGroupID.x; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + const uint tid = gl_LocalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + FLOAT_TYPE xi = FLOAT_TYPE(0); + if (col < ncols) { + xi = FLOAT_TYPE(data_a[a_offset + col]); + } + sum += xi * xi; + } + + sumsh[tid] = sum; + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum += sumsh[tid + s]; + sumsh[tid] = sum; + } + barrier(); + } + sum = sumsh[0]; + + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + if (do_multiply) { + if (ncols > p.ne10) { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } + } else { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } + } else { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } + } +} + +void main() { + // instantiate the rms_norm function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + rms_norm(num_blocks); + } else if (num_blocks > 16) { + rms_norm(32); + } else if (num_blocks > 8) { + rms_norm(16); + } else if (num_blocks > 4) { + rms_norm(8); + } else if (num_blocks == 4) { + rms_norm(4); + } else if (num_blocks == 3) { + rms_norm(3); + } else if (num_blocks == 2) { + rms_norm(2); + } else if (num_blocks == 1) { + rms_norm(1); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp new file mode 100644 index 00000000..87707fc1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp @@ -0,0 +1,55 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer G {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum_xx[BLOCK_SIZE]; +shared FLOAT_TYPE sum_xg[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + // Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5 + + // partial sums for thread in warp + sum_xx[tid] = FLOAT_TYPE(0.0f); + sum_xg[tid] = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]); + const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]); + sum_xx[tid] += xi * xi; + sum_xg[tid] += xi * gi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum_xx[tid] += sum_xx[tid + s]; + sum_xg[tid] += sum_xg[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE eps = FLOAT_TYPE(p.param1); + const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE scale_g = inversesqrt(mean + eps); + const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE( + scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) + + scale_x * FLOAT_TYPE(data_b[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp new file mode 100644 index 00000000..4618b2c7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp @@ -0,0 +1,65 @@ +#version 450 + +#include "generic_binary_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +#define BLOCK_SIZE 128 + +layout (constant_id = 1) const bool do_multiply = false; + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];}; + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void main() { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = 0; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + // The work is split across multiple workgroups in the x dimension. Each invocation + // processes one element + const uint tid = gl_GlobalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + uint32_t num_partials = p.param3; + for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) { + sum += partial_sums[i]; + } + sum = subgroupAdd(sum); + + uint col = tid; + if (col >= ncols) { + return; + } + + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + if (do_multiply) { + if (ncols > p.ne10) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp new file mode 100644 index 00000000..68fbd0c7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp @@ -0,0 +1,46 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +uint wrap_idx(int i, uint ne) { + if (i < 0) { + return i + ne; + } else if (i >= ne) { + return i - ne; + } + return i; +} + +void main() { + const uint idx = get_idx(); + if (idx >= p.ne) { + return; + } + + const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; + const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L); + const uint i2_offset = i2*p.ne11*p.ne10; + const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L); + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; + + const uint p1 = floatBitsToUint(p.param1); + const uint p2 = floatBitsToUint(p.param2); + const int s0 = int(p1 >> 16) - 0x8000; + const int s1 = int(p1 & 0xFFFF) - 0x8000; + const int s2 = int(p2 >> 16) - 0x8000; + const int s3 = int(p2 & 0xFFFF) - 0x8000; + + const uint i00 = wrap_idx(int(i0) - s0, p.ne10); + const uint i01 = wrap_idx(int(i1) - s1, p.ne11); + const uint i02 = wrap_idx(int(i2) - s2, p.ne12); + const uint i03 = wrap_idx(int(i3) - s3, p.ne13); + + const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; + const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10; + + data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl new file mode 100644 index 00000000..50fc1f1e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -0,0 +1,55 @@ +#include "types.glsl" + +#extension GL_EXT_shader_16bit_storage : require + +#include "rte.glsl" + +layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {int data_pos[];}; +layout (binding = 2) readonly buffer Z {float data_ff[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint n_dims; + float freq_scale; + uint p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint has_ff; + uint ne02; + uint s1; + uint s2; + int sections[4]; + uint is_back; +} p; + +float rope_yarn_ramp(const float low, const float high, const uint i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { + float mscale = p.attn_factor; + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = p.freq_scale * theta_extrap; + float theta = theta_interp; + if (p.ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); + } + // Backprogagation uses inverted rotation + if (p.is_back != 0) { + theta = -theta; + } + cos_theta = cos(theta) * mscale; + sin_theta = sin(theta) * mscale; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp new file mode 100644 index 00000000..111286b4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -0,0 +1,58 @@ +#version 450 + +#include "rope_head.glsl" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp new file mode 100644 index 00000000..06e095be --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "rope_head.glsl" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp new file mode 100644 index 00000000..6ba95754 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "rope_head.glsl" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; + + if (i0 >= p.n_dims) { + data_d[idst + 0] = data_a[ix + 0]; + data_d[idst + 1] = data_a[ix + 1]; + + return; + } + + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + 1]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp new file mode 100644 index 00000000..d37d1c10 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -0,0 +1,47 @@ +#version 450 + +#include "rope_head.glsl" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const int sect_dims = p.sections[0] + p.sections[1]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + const uint p0 = sector; + theta_base = data_pos[channel_x]*pow(p.theta_scale, p0); + } + else if (sector >= p.sections[0] && sector < sec_w) { + const uint p0 = sector - p.sections[0]; + theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0); + } + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl new file mode 100644 index 00000000..ad51c1e8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl @@ -0,0 +1,5 @@ + +#if RTE16 +#extension GL_EXT_spirv_intrinsics : enable +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif // RTE16 diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp new file mode 100644 index 00000000..35ec726a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -0,0 +1,24 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2)); + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp new file mode 100644 index 00000000..32298d43 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i])))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp new file mode 100644 index 00000000..7d1cc6f4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float xi = float(data_a[i]); + data_d[i] = D_TYPE(xi / (1.0f + exp(-xi))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp new file mode 100644 index 00000000..e5d949ff --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_x[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // Compute derivative of SiLU(x): 1/(1+exp(-x)) - x*exp(-x)/(1+exp(-x))^2 + + const float xi = float(data_x[i]); + const float s = 1.0f / (1.0f + exp(-xi)); + data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp new file mode 100644 index 00000000..61f17b2f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp new file mode 100644 index 00000000..dca0d896 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -0,0 +1,195 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + uint ne00; + uint ne01; + uint ne02; + uint ne12; + uint ne13; + uint nb11; + uint nb12; + uint nb13; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + uint nrows_x; + uint has_sinks; +} p; + +#include "types.glsl" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) readonly buffer Z {float data_c[];}; +layout (binding = 3) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate +// over all the columns. The main function tries to pass a constant here, +// as if it were a template function, to allow unrolling. +void soft_max(uint num_iters) { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + float slope = 1.0f; + + // ALiBi + if (p.max_bias > 0.0f) { + const uint h = (rowx / p.ne01) % p.ne02; // head index + + const float base = h < p.n_head_log2 ? p.m0 : p.m1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + + slope = pow(base, exp); + } + + // Find max + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + + // Cache values while we compute the max, so we don't need to read them + // again when we're ready to compute exp(x-max). + const uint DATA_CACHE_SIZE = 16; + FLOAT_TYPE data_cache[DATA_CACHE_SIZE]; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + FLOAT_TYPE a = FLOAT_TYPE(0); + if (col < p.KX) { + a = data_a[rowx * p.KX + col]; + } + + FLOAT_TYPE b = FLOAT_TYPE(0); + if (p.KY > 0 && col < p.KX) { + b = data_b[rowy_start + col]; + } + + FLOAT_TYPE v = a * p.scale + slope * b; + + if (col < p.KX) { + max_val = max(max_val, v); + } + + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = v; + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + max_val = vals[0]; + barrier(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + // compute exp(a*scale+b*slope), add it to sum, and cache the new value + // in data_cache if possible. + const uint i = rowx * p.KX + col; + FLOAT_TYPE val; + if (idx < DATA_CACHE_SIZE) { + val = exp(data_cache[idx] - max_val); + } else { + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); + } + sum += val; + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = val; + } else { + data_d[i] = D_TYPE(val); + } + } + + // reduce across the workgroup + vals[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + sum = vals[0]; + + if (p.has_sinks != 0) { + sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val)); + } + + FLOAT_TYPE rcpdivisor = 1.0/sum; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + continue; + } + + if (idx < DATA_CACHE_SIZE) { + data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor); + } else { + data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); + } + } +} + +void main() { + // instantiate the soft_max function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + soft_max(num_blocks); + } else if (num_blocks > 16) { + soft_max(32); + } else if (num_blocks > 8) { + soft_max(16); + } else if (num_blocks > 4) { + soft_max(8); + } else if (num_blocks == 4) { + soft_max(4); + } else if (num_blocks == 3) { + soft_max(3); + } else if (num_blocks == 2) { + soft_max(2); + } else if (num_blocks == 1) { + soft_max(1); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp new file mode 100644 index 00000000..d873332e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -0,0 +1,54 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "generic_head.glsl" +#include "types.glsl" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// In this shader Y = softmax(X) and X is not provided as input. + +layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_y[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum_yg[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + if (row >= p.KY) { + return; + } + + FLOAT_TYPE scale = p.param1; + + // partial sums for thread in warp + sum_yg[tid] = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]); + const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]); + sum_yg[tid] += yi * gi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum_yg[tid] += sum_yg[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE dot_yg = sum_yg[0]; + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale + * (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg) + * FLOAT_TYPE(data_y[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp new file mode 100644 index 00000000..70daad6c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp new file mode 100644 index 00000000..4eb56afc --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp new file mode 100644 index 00000000..bc924b52 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp @@ -0,0 +1,29 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp new file mode 100644 index 00000000..bc22aa7b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -0,0 +1,70 @@ +#version 450 + +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +layout (push_constant) uniform parameter +{ + uint n_cols; + uint ne01, ne02; + uint nb01, nb02, nb03; + uint nb11, nb12, nb13; + float weight; + uint misalign_offsets; + uint ne0_12mp, ne0_12L; + uint ne0_1mp, ne0_1L; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + const float weight = p.weight; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + tmp[col] = FLOAT_TYPE(0.0); + + for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) { + tmp[col] += FLOAT_TYPE(data_a[src_idx + i]); + } + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s) { + tmp[col] += tmp[col + s]; + } + barrier(); + } + + if (col == 0) { + data_d[dst_idx] = D_TYPE(tmp[0] * weight); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp new file mode 100644 index 00000000..4fee433a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.glsl" + +float op(float a, float b) { + return a / (1.0f + exp(-a)) * b; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp new file mode 100644 index 00000000..bda9dea2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp @@ -0,0 +1,14 @@ +#version 450 + +#include "glu_head.glsl" + +float op(float a, float b) { + float xi = min(a, p.limit); + float gi = max(min(b, p.limit), -p.limit); + + float out_glu = xi / (1.0f + exp(-xi * p.alpha)); + out_glu = out_glu * (1.0f + gi); + return out_glu; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp new file mode 100644 index 00000000..7b5eb413 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp new file mode 100644 index 00000000..16055654 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint nb1; + uint dim; + uint max_period; +} p; + +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 256 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_WorkGroupID.y; + const uint j = gl_GlobalInvocationID.x; + const uint d_offset = i * p.nb1; + + const uint half_dim = p.dim / 2; + + if (p.dim % 2 != 0 && j == half_dim) { + data_d[d_offset + 2 * half_dim] = 0.f; + } + + if (j >= half_dim) { + return; + } + + const float timestep = float(data_a[i]); + const float freq = float(exp(-log(p.max_period) * j / half_dim)); + const float arg = timestep * freq; + data_d[d_offset + j] = D_TYPE(cos(arg)); + data_d[d_offset + j + half_dim] = D_TYPE(sin(arg)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl new file mode 100644 index 00000000..2fa54ce5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -0,0 +1,1465 @@ +#if !defined(GGML_TYPES_COMP) +#define GGML_TYPES_COMP + +#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_16bit_storage : require + +#if defined(DATA_A_F32) +#define QUANT_K 1 +#define QUANT_R 1 + +#if LOAD_VEC_A == 4 +#define A_TYPE vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE mat2x4 +#else +#define A_TYPE float +#endif +#endif + +#if defined(DATA_A_F16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if LOAD_VEC_A == 4 +#define A_TYPE f16vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE f16mat2x4 +#else +#define A_TYPE float16_t +#endif +#endif + +#if defined(DATA_A_BF16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if LOAD_VEC_A == 4 +#define A_TYPE u16vec4 +#elif LOAD_VEC_A == 8 +#error unsupported +#else +#define A_TYPE uint16_t +#endif +#endif + +#define QUANT_K_Q4_0 32 +#define QUANT_R_Q4_0 2 + +struct block_q4_0 +{ + float16_t d; + uint8_t qs[16]; +}; +struct block_q4_0_packed16 +{ + float16_t d; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q4_0) +#define QUANT_K QUANT_K_Q4_0 +#define QUANT_R QUANT_R_Q4_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q4_0 +#define A_TYPE_PACKED16 block_q4_0_packed16 +#endif + +#define QUANT_K_Q4_1 32 +#define QUANT_R_Q4_1 2 + +struct block_q4_1 +{ + float16_t d; + float16_t m; + uint8_t qs[16]; +}; + +struct block_q4_1_packed16 +{ + float16_t d; + float16_t m; + uint16_t qs[16/2]; +}; + +struct block_q4_1_packed32 +{ + f16vec2 dm; + uint32_t qs[16/4]; +}; + +#if defined(DATA_A_Q4_1) +#define QUANT_K QUANT_K_Q4_1 +#define QUANT_R QUANT_R_Q4_1 +#define QUANT_AUXF 2 +#define A_TYPE block_q4_1 +#define A_TYPE_PACKED16 block_q4_1_packed16 +#define A_TYPE_PACKED32 block_q4_1_packed32 +#endif + +#define QUANT_K_Q5_0 32 +#define QUANT_R_Q5_0 2 + +struct block_q5_0 +{ + float16_t d; + uint16_t qh[2]; + uint8_t qs[16]; +}; + +struct block_q5_0_packed16 +{ + float16_t d; + uint16_t qh[2]; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q5_0) +#define QUANT_K QUANT_K_Q5_0 +#define QUANT_R QUANT_R_Q5_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q5_0 +#define A_TYPE_PACKED16 block_q5_0_packed16 +#endif + +#define QUANT_K_Q5_1 32 +#define QUANT_R_Q5_1 2 + +struct block_q5_1 +{ + float16_t d; + float16_t m; + uint qh; + uint8_t qs[16]; +}; + +struct block_q5_1_packed16 +{ + float16_t d; + float16_t m; + uint qh; + uint16_t qs[16/2]; +}; + +struct block_q5_1_packed32 +{ + f16vec2 dm; + uint qh; + uint32_t qs[16/4]; +}; + +#if defined(DATA_A_Q5_1) +#define QUANT_K QUANT_K_Q5_1 +#define QUANT_R QUANT_R_Q5_1 +#define QUANT_AUXF 2 +#define A_TYPE block_q5_1 +#define A_TYPE_PACKED16 block_q5_1_packed16 +#define A_TYPE_PACKED32 block_q5_1_packed32 +#endif + +#define QUANT_K_Q8_0 32 +#define QUANT_R_Q8_0 1 + +struct block_q8_0 +{ + float16_t d; + int8_t qs[32]; +}; +struct block_q8_0_packed16 +{ + float16_t d; + int16_t qs[32/2]; +}; +struct block_q8_0_packed32 +{ + float16_t d; + int32_t qs[32/4]; +}; + +#if defined(DATA_A_Q8_0) +#define QUANT_K QUANT_K_Q8_0 +#define QUANT_R QUANT_R_Q8_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q8_0 +#define A_TYPE_PACKED16 block_q8_0_packed16 +#define A_TYPE_PACKED32 block_q8_0_packed32 +#endif + +#define QUANT_K_Q8_1 32 +#define QUANT_R_Q8_1 1 + +struct block_q8_1 +{ + f16vec2 ds; + int8_t qs[32]; +}; +struct block_q8_1_packed16 +{ + f16vec2 ds; + int16_t qs[16]; +}; +struct block_q8_1_packed32 +{ + f16vec2 ds; + int32_t qs[8]; +}; + +// 4 blocks in one to allow 16-byte/128-bit alignment and loads +struct block_q8_1_x4 +{ + f16vec2 ds[4]; + int32_t qs[32]; +}; +struct block_q8_1_x4_packed128 +{ + f16vec2 ds[4]; + ivec4 qs[8]; +}; + +// K-quants +#define QUANT_K_Q2_K 256 + +struct block_q2_K +{ + uint8_t scales[QUANT_K_Q2_K/16]; + uint8_t qs[QUANT_K_Q2_K/4]; + f16vec2 d; +}; + +struct block_q2_K_packed16 +{ + uint16_t scales[QUANT_K_Q2_K/16/2]; + uint16_t qs[QUANT_K_Q2_K/4/2]; + f16vec2 d; +}; + +struct block_q2_K_packed32 +{ + uint32_t scales[QUANT_K_Q2_K/16/4]; + uint32_t qs[QUANT_K_Q2_K/4/4]; + f16vec2 d; +}; + +#if defined(DATA_A_Q2_K) +#define QUANT_K QUANT_K_Q2_K +#define QUANT_R 1 +#define A_TYPE block_q2_K +#define A_TYPE_PACKED16 block_q2_K_packed16 +#define A_TYPE_PACKED32 block_q2_K_packed32 +#endif + +#define QUANT_K_Q3_K 256 + +struct block_q3_K +{ + uint8_t hmask[QUANT_K_Q3_K/8]; + uint8_t qs[QUANT_K_Q3_K/4]; + uint8_t scales[12]; + float16_t d; +}; + +struct block_q3_K_packed16 +{ + uint16_t hmask[QUANT_K_Q3_K/8/2]; + uint16_t qs[QUANT_K_Q3_K/4/2]; + uint16_t scales[12/2]; + float16_t d; +}; + +#if defined(DATA_A_Q3_K) +#define QUANT_K QUANT_K_Q3_K +#define QUANT_R 1 +#define A_TYPE block_q3_K +#define A_TYPE_PACKED16 block_q3_K_packed16 +#endif + +#define QUANT_K_Q4_K 256 + +struct block_q4_K +{ + f16vec2 d; + uint8_t scales[3*QUANT_K_Q4_K/64]; + uint8_t qs[QUANT_K_Q4_K/2]; +}; + +struct block_q4_K_packed16 +{ + f16vec2 d; + uint16_t scales[3*QUANT_K_Q4_K/64/2]; + uint16_t qs[QUANT_K_Q4_K/2/2]; +}; + +struct block_q4_K_packed32 +{ + f16vec2 d; + uint32_t scales[3*QUANT_K_Q4_K/64/4]; + uint32_t qs[QUANT_K_Q4_K/2/4]; +}; + +struct block_q4_K_packed128 +{ + uvec4 q4k[9]; +}; + +#if defined(DATA_A_Q4_K) +#define QUANT_K QUANT_K_Q4_K +#define QUANT_R 1 +#define A_TYPE block_q4_K +#define A_TYPE_PACKED16 block_q4_K_packed16 +#define A_TYPE_PACKED32 block_q4_K_packed32 +#endif + +#define QUANT_K_Q5_K 256 + +struct block_q5_K +{ + f16vec2 d; + uint8_t scales[12]; + uint8_t qh[QUANT_K_Q5_K/8]; + uint8_t qs[QUANT_K_Q5_K/2]; +}; + +struct block_q5_K_packed16 +{ + f16vec2 d; + uint16_t scales[12/2]; + uint16_t qh[QUANT_K_Q5_K/8/2]; + uint16_t qs[QUANT_K_Q5_K/2/2]; +}; + +struct block_q5_K_packed128 +{ + uvec4 q5k[11]; +}; + +#if defined(DATA_A_Q5_K) +#define QUANT_K QUANT_K_Q5_K +#define QUANT_R 1 +#define A_TYPE block_q5_K +#define A_TYPE_PACKED16 block_q5_K_packed16 +#endif + +#define QUANT_K_Q6_K 256 + +struct block_q6_K +{ + uint8_t ql[QUANT_K_Q6_K/2]; + uint8_t qh[QUANT_K_Q6_K/4]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +struct block_q6_K_packed16 +{ + uint16_t ql[QUANT_K_Q6_K/2/2]; + uint16_t qh[QUANT_K_Q6_K/4/2]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +#if defined(DATA_A_Q6_K) +#define QUANT_K QUANT_K_Q6_K +#define QUANT_R 1 +#define A_TYPE block_q6_K +#define A_TYPE_PACKED16 block_q6_K_packed16 +#endif + +// IQuants + +#define QUANT_K_IQ1_S 256 +#define QUANT_R_IQ1_S 1 + +struct block_iq1_s { + float16_t d; + uint8_t qs[QUANT_K_IQ1_S/8]; + uint16_t qh[QUANT_K_IQ1_S/32]; +}; + +#define QUANT_K_IQ1_M 256 +#define QUANT_R_IQ1_M 1 + +struct block_iq1_m { + uint8_t qs[QUANT_K_IQ1_M/8]; + uint8_t qh[QUANT_K_IQ1_M/16]; + uint16_t scales[QUANT_K_IQ1_M/64]; +}; + +struct block_iq1_m_packed64 { + uint64_t qs[QUANT_K_IQ1_M/8/8]; + uint64_t qh[QUANT_K_IQ1_M/16/8]; + uint64_t scales; +}; + +#if defined(DATA_A_IQ1_S) +#define QUANT_K QUANT_K_IQ1_S +#define QUANT_R QUANT_R_IQ1_S +#define A_TYPE block_iq1_s +#endif + +#if defined(DATA_A_IQ1_M) +#define QUANT_K QUANT_K_IQ1_M +#define QUANT_R QUANT_R_IQ1_M +#define A_TYPE block_iq1_m +#endif + +#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f + +// Packed IQ1S grid where every 2 vec8 are encoded on 32 bits (2 bits per coordinate). +const uint[1024] iq1s_grid_const = { + 0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01, + 0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4, + 0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41, + 0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f, + 0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334, + 0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f, + 0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040, + 0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f, + 0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5, + 0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3, + 0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff, + 0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570, + 0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f, + 0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf, + 0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f, + 0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07, + 0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc, + 0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374, + 0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0, + 0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001, + 0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043, + 0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc, + 0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117, + 0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f, + 0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5, + 0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474, + 0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d, + 0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd, + 0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50, + 0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10, + 0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30, + 0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1, + 0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c, + 0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074, + 0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134, + 0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7, + 0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3, + 0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450, + 0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577, + 0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c, + 0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5, + 0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c, + 0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00, + 0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300, + 0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc, + 0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034, + 0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077, + 0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5, + 0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117, + 0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f, + 0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5, + 0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404, + 0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1, + 0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd, + 0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71, + 0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7, + 0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00, + 0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44, + 0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00, + 0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0, + 0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303, + 0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343, + 0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd, + 0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031, + 0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011, + 0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c, + 0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4, + 0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c, + 0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174, + 0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7, + 0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d, + 0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4, + 0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c, + 0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7, + 0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510, + 0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33, + 0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4, + 0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73, + 0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f, + 0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337, + 0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343, + 0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030, + 0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075, + 0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4, + 0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170, + 0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705, + 0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c, + 0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c, + 0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514, + 0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c, + 0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3, + 0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70, + 0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03, + 0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c, + 0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c, + 0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074, + 0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104, + 0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7, + 0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757, + 0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c, + 0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c, + 0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4, + 0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc, + 0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03, + 0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc, + 0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54, + 0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f, + 0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf, + 0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c, + 0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c, + 0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4, + 0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174, + 0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700, + 0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7, + 0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d, + 0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531, + 0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf, + 0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57, + 0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13, + 0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01, + 0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f, + 0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7, + 0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074, + 0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107, + 0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd, + 0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0, + 0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7, + 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 +}; + +shared uint16_t iq1s_grid[2048]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq1s_grid_const.length(); i += wgsize.x) { + uint idx = i + gl_LocalInvocationIndex.x; + if (iq1s_grid_const.length() % wgsize.x == 0 || idx < iq1s_grid_const.length()) { + u16vec2 g = unpack16(iq1s_grid_const[idx]); + iq1s_grid[2*idx+0] = g.x; + iq1s_grid[2*idx+1] = g.y; + } + } + barrier(); +} +#endif + +#define QUANT_K_IQ2_XXS 256 +#define QUANT_R_IQ2_XXS 1 + +struct block_iq2_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_XXS/4]; +}; + +struct block_iq2_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XXS/8]; +}; + +#if defined(DATA_A_IQ2_XXS) + +const uvec2[256] iq2xxs_grid_const = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x082b0808, 0x08080808), + uvec2(0x082b082b, 0x08080808), uvec2(0x082b2b08, 0x08080808), uvec2(0x082b2b2b, 0x08080808), uvec2(0x19080819, 0x08080808), + uvec2(0x19081908, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), + uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b082b2b, 0x08080808), + uvec2(0x2b2b082b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x2b081908, 0x08080819), uvec2(0x2b192b08, 0x08080819), + uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x082b082b, 0x0808082b), uvec2(0x2b08082b, 0x0808082b), + uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x082b0819, 0x08081908), + uvec2(0x082b1908, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x192b0808, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), + uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), uvec2(0x08082b08, 0x08081919), + uvec2(0x082b0808, 0x08081919), uvec2(0x1908192b, 0x08081919), uvec2(0x192b2b19, 0x08081919), uvec2(0x2b080808, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b2b1908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x08081919, 0x08082b08), + uvec2(0x08082b08, 0x08082b08), uvec2(0x08191908, 0x08082b08), uvec2(0x082b2b08, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x2b082b08, 0x08082b08), + uvec2(0x08081908, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x0808082b, 0x08082b2b), uvec2(0x08191908, 0x08082b2b), + uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x082b0819, 0x08190808), + uvec2(0x19080808, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), + uvec2(0x2b191919, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x082b0808, 0x08190819), + uvec2(0x19190808, 0x08190819), uvec2(0x19192b2b, 0x08190819), uvec2(0x2b080808, 0x08190819), uvec2(0x082b1908, 0x0819082b), + uvec2(0x19081919, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x08082b08, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x082b1919, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08192b08, 0x08191919), + uvec2(0x192b082b, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x0819192b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x2b2b0808, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x19081908, 0x082b0808), + uvec2(0x192b0819, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b08082b, 0x082b0808), uvec2(0x082b2b19, 0x082b0819), + uvec2(0x19082b08, 0x082b0819), uvec2(0x08080808, 0x082b082b), uvec2(0x0808082b, 0x082b082b), uvec2(0x08080819, 0x082b1908), + uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x19080808, 0x082b1908), uvec2(0x1919192b, 0x082b1908), + uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x192b1908, 0x082b1919), uvec2(0x2b190808, 0x082b192b), + uvec2(0x08082b08, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), uvec2(0x2b191908, 0x082b2b08), uvec2(0x19081908, 0x082b2b2b), + uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x08192b08, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x19080808, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x1919192b, 0x19080808), uvec2(0x192b0808, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), + uvec2(0x2b190808, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x192b0819, 0x19080819), + uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x19082b08, 0x1908082b), uvec2(0x1919192b, 0x1908082b), uvec2(0x192b2b08, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b192b19, 0x19081908), + uvec2(0x0819082b, 0x19081919), uvec2(0x082b1908, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08080819, 0x19082b08), + uvec2(0x08081908, 0x19082b08), uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x08080808, 0x19082b19), uvec2(0x19192b08, 0x19082b19), uvec2(0x192b0819, 0x19082b19), uvec2(0x2b08082b, 0x19082b19), + uvec2(0x19081919, 0x19082b2b), uvec2(0x2b190808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x08082b08, 0x19190808), + uvec2(0x08190819, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x2b080808, 0x19190808), + uvec2(0x2b082b08, 0x19190808), uvec2(0x08081908, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x2b2b1908, 0x19190819), + uvec2(0x2b190819, 0x1919082b), uvec2(0x2b190808, 0x19191908), uvec2(0x2b19082b, 0x19191908), uvec2(0x08082b2b, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x19191908, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x08190819, 0x19192b08), + uvec2(0x08192b19, 0x19192b08), uvec2(0x192b1908, 0x19192b08), uvec2(0x19080808, 0x19192b19), uvec2(0x08082b08, 0x19192b2b), + uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x192b2b08, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x19191919, 0x192b0819), uvec2(0x08192b08, 0x192b082b), uvec2(0x192b0808, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x08081919, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x0819082b, 0x192b1919), + uvec2(0x2b081908, 0x192b1919), uvec2(0x1908082b, 0x192b2b08), uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), + uvec2(0x08082b2b, 0x2b080808), uvec2(0x19080819, 0x2b080808), uvec2(0x2b08082b, 0x2b080808), uvec2(0x08081908, 0x2b080819), + uvec2(0x08192b08, 0x2b080819), uvec2(0x19080808, 0x2b080819), uvec2(0x08190819, 0x2b08082b), uvec2(0x08080819, 0x2b081908), + uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x192b0808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x1908192b, 0x2b081919), uvec2(0x2b191908, 0x2b081919), + uvec2(0x08082b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x192b0808, 0x2b08192b), uvec2(0x0808082b, 0x2b082b08), + uvec2(0x08081908, 0x2b082b19), uvec2(0x08190819, 0x2b082b2b), uvec2(0x08081908, 0x2b190808), uvec2(0x08190808, 0x2b190808), + uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x2b2b0819, 0x2b190808), uvec2(0x0819192b, 0x2b190819), + uvec2(0x2b080808, 0x2b190819), uvec2(0x19081919, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x082b082b, 0x2b191908), + uvec2(0x19081908, 0x2b191908), uvec2(0x19190819, 0x2b191919), uvec2(0x2b080819, 0x2b192b08), uvec2(0x082b0808, 0x2b192b19), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b081919, 0x2b2b0808), uvec2(0x08082b19, 0x2b2b0819), + uvec2(0x08080808, 0x2b2b082b), uvec2(0x08192b08, 0x2b2b1908), uvec2(0x19190808, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19) +}; + +shared uvec2 iq2xxs_grid[256]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq2xxs_grid.length(); i += wgsize.x) { + if (iq2xxs_grid_const.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xxs_grid_const.length()) { + iq2xxs_grid[i + gl_LocalInvocationIndex.x] = iq2xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XXS +#define QUANT_R QUANT_R_IQ2_XXS +#define A_TYPE block_iq2_xxs +#define A_TYPE_PACKED16 block_iq2_xxs_packed16 +#endif + +#define QUANT_K_IQ2_XS 256 +#define QUANT_R_IQ2_XS 1 + +struct block_iq2_xs +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint8_t scales[QUANT_K_IQ2_XS/32]; +}; + +struct block_iq2_xs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint16_t scales[QUANT_K_IQ2_XS/64]; +}; + +#if defined(DATA_A_IQ2_XS) + +const uvec2 iq2xs_grid_const[512] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), + uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), + uvec2(0x2b191908, 0x08080808), uvec2(0x2b192b19, 0x08080808), uvec2(0x2b2b0808, 0x08080808), uvec2(0x08080819, 0x08080819), + uvec2(0x08081908, 0x08080819), uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x0819082b, 0x08080819), uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x08192b2b, 0x08080819), + uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), + uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), uvec2(0x2b081908, 0x08080819), + uvec2(0x2b190808, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x08081919, 0x0808082b), + uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), uvec2(0x082b0808, 0x0808082b), + uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), + uvec2(0x0808192b, 0x08081908), uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), + uvec2(0x08191919, 0x08081908), uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), + uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), uvec2(0x1919192b, 0x08081908), uvec2(0x192b0808, 0x08081908), + uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x08080808, 0x08081919), + uvec2(0x0808082b, 0x08081919), uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x082b0808, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x19190808, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x2b080808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x082b192b, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x1908082b, 0x0808192b), uvec2(0x2b081908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08082b2b, 0x08082b08), uvec2(0x08190819, 0x08082b08), + uvec2(0x08191908, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x19192b08, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b2b0808, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), uvec2(0x08081908, 0x08082b19), + uvec2(0x08190808, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x2b080819, 0x08082b19), uvec2(0x2b082b19, 0x08082b19), + uvec2(0x08080808, 0x08082b2b), uvec2(0x082b0808, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x2b19192b, 0x08082b2b), + uvec2(0x2b2b0808, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x0808192b, 0x08190808), + uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), uvec2(0x08191919, 0x08190808), + uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), uvec2(0x19080808, 0x08190808), + uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), + uvec2(0x19191908, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b2b2b, 0x08190808), uvec2(0x2b080819, 0x08190808), + uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), + uvec2(0x08081919, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x082b0808, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), uvec2(0x19190808, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x2b19192b, 0x08190819), uvec2(0x08080819, 0x0819082b), + uvec2(0x08081908, 0x0819082b), uvec2(0x0808192b, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x19080808, 0x0819082b), + uvec2(0x192b0808, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), + uvec2(0x08082b08, 0x08191908), uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x19080819, 0x08191908), uvec2(0x19081908, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x08191908, 0x0819192b), + uvec2(0x19082b19, 0x0819192b), uvec2(0x08080819, 0x08192b08), uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x0819082b, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x19191908, 0x08192b08), uvec2(0x2b08192b, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x192b192b, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x2b2b2b19, 0x08192b2b), uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), + uvec2(0x08082b08, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x19080819, 0x082b0808), uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), + uvec2(0x2b080808, 0x082b0808), uvec2(0x2b2b0808, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), + uvec2(0x08190808, 0x082b0819), uvec2(0x19080808, 0x082b0819), uvec2(0x19082b08, 0x082b0819), uvec2(0x192b1919, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x2b080808, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), + uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x082b2b19, 0x082b1908), + uvec2(0x19080808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x1919082b, 0x082b1919), + uvec2(0x2b192b19, 0x082b1919), uvec2(0x08080819, 0x082b192b), uvec2(0x08192b2b, 0x082b192b), uvec2(0x2b2b192b, 0x082b192b), + uvec2(0x08080808, 0x082b2b08), uvec2(0x08082b08, 0x082b2b08), uvec2(0x08082b2b, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), + uvec2(0x19191919, 0x082b2b08), uvec2(0x2b082b08, 0x082b2b08), uvec2(0x2b2b082b, 0x082b2b08), uvec2(0x192b2b08, 0x082b2b19), + uvec2(0x2b190808, 0x082b2b19), uvec2(0x08082b08, 0x082b2b2b), uvec2(0x082b0808, 0x082b2b2b), uvec2(0x2b08082b, 0x082b2b2b), + uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), + uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x0819082b, 0x19080808), + uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), + uvec2(0x19080808, 0x19080808), uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x19082b2b, 0x19080808), uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x192b0808, 0x19080808), + uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), + uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), + uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x19080819, 0x19080819), + uvec2(0x19081908, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), + uvec2(0x2b2b082b, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x0819082b, 0x1908082b), uvec2(0x082b2b19, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), uvec2(0x08082b08, 0x19081908), uvec2(0x08190819, 0x19081908), + uvec2(0x08191908, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x19080819, 0x19081908), + uvec2(0x19081908, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b191908, 0x19081908), + uvec2(0x08080819, 0x19081919), uvec2(0x08081908, 0x19081919), uvec2(0x08190808, 0x19081919), uvec2(0x082b1908, 0x19081919), + uvec2(0x19080808, 0x19081919), uvec2(0x2b192b2b, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08082b2b, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), + uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), uvec2(0x19191908, 0x19082b08), + uvec2(0x192b082b, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x19081908, 0x19082b19), + uvec2(0x19190808, 0x19082b19), uvec2(0x192b2b19, 0x19082b19), uvec2(0x08081908, 0x19082b2b), uvec2(0x08080808, 0x19190808), + uvec2(0x0808082b, 0x19190808), uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), + uvec2(0x08191908, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), + uvec2(0x19081908, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x2b080808, 0x19190808), uvec2(0x08080819, 0x19190819), + uvec2(0x08081908, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x08191919, 0x19190819), uvec2(0x19080808, 0x19190819), + uvec2(0x1908082b, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x2b2b2b2b, 0x1919082b), + uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x082b0819, 0x19191908), + uvec2(0x19080808, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b2b0819, 0x19191908), + uvec2(0x08080808, 0x19191919), uvec2(0x08082b08, 0x19191919), uvec2(0x2b080808, 0x19191919), uvec2(0x2b082b08, 0x19191919), + uvec2(0x082b0819, 0x1919192b), uvec2(0x192b2b08, 0x1919192b), uvec2(0x2b2b0819, 0x1919192b), uvec2(0x08080808, 0x19192b08), + uvec2(0x08191908, 0x19192b08), uvec2(0x19080819, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x2b192b19, 0x19192b08), + uvec2(0x08192b2b, 0x19192b19), uvec2(0x19080808, 0x19192b19), uvec2(0x1908082b, 0x19192b19), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), + uvec2(0x19191908, 0x192b0808), uvec2(0x192b082b, 0x192b0808), uvec2(0x2b08192b, 0x192b0808), uvec2(0x2b2b2b19, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x082b1908, 0x192b082b), uvec2(0x19082b2b, 0x192b082b), uvec2(0x2b19082b, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x0819192b, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x19080808, 0x192b1919), + uvec2(0x19081919, 0x192b1919), uvec2(0x2b2b1908, 0x192b1919), uvec2(0x08080819, 0x192b2b08), uvec2(0x192b2b2b, 0x192b2b08), + uvec2(0x082b1919, 0x192b2b19), uvec2(0x0808192b, 0x192b2b2b), uvec2(0x19191908, 0x192b2b2b), uvec2(0x192b082b, 0x192b2b2b), + uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), + uvec2(0x08190819, 0x2b080808), uvec2(0x08191908, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b2b2b, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b08082b, 0x2b080808), uvec2(0x2b2b2b08, 0x2b080808), uvec2(0x2b2b2b2b, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x0808192b, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x19080808, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19192b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x082b0808, 0x2b08082b), + uvec2(0x2b080808, 0x2b08082b), uvec2(0x2b08082b, 0x2b08082b), uvec2(0x2b2b0808, 0x2b08082b), uvec2(0x2b2b2b08, 0x2b08082b), + uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b082b19, 0x2b081908), + uvec2(0x08080808, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x2b2b1919, 0x2b081919), uvec2(0x08192b08, 0x2b08192b), + uvec2(0x192b2b2b, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08082b08, 0x2b082b08), uvec2(0x082b1919, 0x2b082b08), + uvec2(0x19192b2b, 0x2b082b08), uvec2(0x2b080808, 0x2b082b08), uvec2(0x2b08082b, 0x2b082b08), uvec2(0x2b2b2b08, 0x2b082b08), + uvec2(0x0808192b, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x2b080808, 0x2b082b2b), uvec2(0x2b082b08, 0x2b082b2b), + uvec2(0x2b19192b, 0x2b082b2b), uvec2(0x2b2b2b08, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), uvec2(0x08081908, 0x2b190808), + uvec2(0x08190808, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x1919192b, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x08080808, 0x2b190819), uvec2(0x082b082b, 0x2b190819), uvec2(0x192b1908, 0x2b190819), uvec2(0x1919192b, 0x2b19082b), + uvec2(0x2b082b19, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x08081919, 0x2b191908), uvec2(0x19081908, 0x2b191908), + uvec2(0x19190808, 0x2b191908), uvec2(0x19192b08, 0x2b191908), uvec2(0x082b2b19, 0x2b191919), uvec2(0x2b190808, 0x2b191919), + uvec2(0x2b19082b, 0x2b191919), uvec2(0x19080819, 0x2b19192b), uvec2(0x19190819, 0x2b192b08), uvec2(0x2b2b192b, 0x2b192b08), + uvec2(0x19082b19, 0x2b192b19), uvec2(0x08191919, 0x2b192b2b), uvec2(0x192b0808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x08082b08, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), uvec2(0x082b0808, 0x2b2b0808), + uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x2b2b0808, 0x2b2b0808), uvec2(0x19190819, 0x2b2b0819), uvec2(0x19192b19, 0x2b2b0819), + uvec2(0x2b2b192b, 0x2b2b0819), uvec2(0x08080808, 0x2b2b082b), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b08, 0x2b2b082b), + uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b080808, 0x2b2b082b), uvec2(0x2b2b0808, 0x2b2b082b), uvec2(0x19080808, 0x2b2b1908), + uvec2(0x2b191919, 0x2b2b1908), uvec2(0x192b1919, 0x2b2b192b), uvec2(0x2b192b08, 0x2b2b192b), uvec2(0x08082b2b, 0x2b2b2b08), + uvec2(0x082b0808, 0x2b2b2b08), uvec2(0x082b082b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b0808, 0x2b2b2b08), + uvec2(0x2b2b2b08, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19), uvec2(0x2b081908, 0x2b2b2b19), uvec2(0x2b08192b, 0x2b2b2b19), + uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x082b2b2b, 0x2b2b2b2b), uvec2(0x2b190819, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b), +}; + +shared uvec2 iq2xs_grid[512]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq2xs_grid.length(); i += wgsize.x) { + if (iq2xs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xs_grid_const.length()) { + iq2xs_grid[i + gl_LocalInvocationIndex.x] = iq2xs_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XS +#define QUANT_R QUANT_R_IQ2_XS +#define A_TYPE block_iq2_xs +#define A_TYPE_PACKED16 block_iq2_xs_packed16 +#endif + +#define QUANT_K_IQ2_S 256 +#define QUANT_R_IQ2_S 1 + +struct block_iq2_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_S/4]; + uint8_t qh[QUANT_K_IQ2_S/32]; + uint8_t scales[QUANT_K_IQ2_S/32]; +}; + +struct block_iq2_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_S/8]; + uint16_t qh[QUANT_K_IQ2_S/64]; + uint16_t scales[QUANT_K_IQ2_S/64]; +}; + +#if defined(DATA_A_IQ2_S) + +const uvec2 iq2s_grid_const[1024] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x192b192b, 0x08080808), + uvec2(0x192b2b19, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), + uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), uvec2(0x2b191908, 0x08080808), uvec2(0x2b2b0808, 0x08080808), + uvec2(0x2b2b1919, 0x08080808), uvec2(0x2b2b2b2b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), + uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), uvec2(0x0819082b, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), + uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), + uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), uvec2(0x1919192b, 0x08080819), uvec2(0x19192b19, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b1919, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), + uvec2(0x2b081908, 0x08080819), uvec2(0x2b190808, 0x08080819), uvec2(0x2b19082b, 0x08080819), uvec2(0x2b191919, 0x08080819), + uvec2(0x2b2b0819, 0x08080819), uvec2(0x2b2b1908, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), + uvec2(0x08081919, 0x0808082b), uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), + uvec2(0x082b0808, 0x0808082b), uvec2(0x082b2b2b, 0x0808082b), uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), + uvec2(0x1908192b, 0x0808082b), uvec2(0x19082b19, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b081919, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x2b191908, 0x0808082b), + uvec2(0x2b2b082b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x0808192b, 0x08081908), + uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), uvec2(0x08191919, 0x08081908), + uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), uvec2(0x082b192b, 0x08081908), + uvec2(0x082b2b19, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), + uvec2(0x19082b08, 0x08081908), uvec2(0x19082b2b, 0x08081908), uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), + uvec2(0x1919192b, 0x08081908), uvec2(0x19192b19, 0x08081908), uvec2(0x192b0808, 0x08081908), uvec2(0x192b082b, 0x08081908), + uvec2(0x192b1919, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b08192b, 0x08081908), + uvec2(0x2b082b19, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x2b191919, 0x08081908), uvec2(0x2b192b08, 0x08081908), + uvec2(0x2b2b0819, 0x08081908), uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), + uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08082b2b, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x0819192b, 0x08081919), uvec2(0x08192b19, 0x08081919), uvec2(0x082b0808, 0x08081919), + uvec2(0x082b1919, 0x08081919), uvec2(0x082b2b08, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x1908192b, 0x08081919), uvec2(0x19082b19, 0x08081919), uvec2(0x19190808, 0x08081919), uvec2(0x1919082b, 0x08081919), + uvec2(0x19191919, 0x08081919), uvec2(0x19192b08, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x192b1908, 0x08081919), + uvec2(0x2b080808, 0x08081919), uvec2(0x2b08082b, 0x08081919), uvec2(0x2b081919, 0x08081919), uvec2(0x2b082b08, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x2b191908, 0x08081919), uvec2(0x2b2b0808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x0808192b, 0x0808192b), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), + uvec2(0x08191919, 0x0808192b), uvec2(0x19080808, 0x0808192b), uvec2(0x19081919, 0x0808192b), uvec2(0x19082b08, 0x0808192b), + uvec2(0x19190819, 0x0808192b), uvec2(0x19191908, 0x0808192b), uvec2(0x192b0808, 0x0808192b), uvec2(0x2b080819, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b190808, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08190819, 0x08082b08), uvec2(0x08191908, 0x08082b08), + uvec2(0x0819192b, 0x08082b08), uvec2(0x08192b19, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), + uvec2(0x082b2b2b, 0x08082b08), uvec2(0x19080819, 0x08082b08), uvec2(0x19081908, 0x08082b08), uvec2(0x1908192b, 0x08082b08), + uvec2(0x19082b19, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x19191919, 0x08082b08), + uvec2(0x19192b08, 0x08082b08), uvec2(0x192b0819, 0x08082b08), uvec2(0x192b1908, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b081919, 0x08082b08), uvec2(0x2b191908, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), + uvec2(0x08081908, 0x08082b19), uvec2(0x08190808, 0x08082b19), uvec2(0x0819082b, 0x08082b19), uvec2(0x08191919, 0x08082b19), + uvec2(0x08192b08, 0x08082b19), uvec2(0x082b0819, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x19081919, 0x08082b19), + uvec2(0x19082b08, 0x08082b19), uvec2(0x19190819, 0x08082b19), uvec2(0x19191908, 0x08082b19), uvec2(0x192b0808, 0x08082b19), + uvec2(0x2b080819, 0x08082b19), uvec2(0x2b190808, 0x08082b19), uvec2(0x08080808, 0x08082b2b), uvec2(0x08190819, 0x08082b2b), + uvec2(0x08191908, 0x08082b2b), uvec2(0x082b082b, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x082b2b2b, 0x08082b2b), + uvec2(0x19190808, 0x08082b2b), uvec2(0x2b192b19, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), + uvec2(0x0808192b, 0x08190808), uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), + uvec2(0x08191919, 0x08190808), uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), + uvec2(0x082b192b, 0x08190808), uvec2(0x19080808, 0x08190808), uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), + uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), uvec2(0x19191908, 0x08190808), uvec2(0x1919192b, 0x08190808), + uvec2(0x19192b19, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b082b, 0x08190808), uvec2(0x192b1919, 0x08190808), + uvec2(0x192b2b08, 0x08190808), uvec2(0x2b080819, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b08192b, 0x08190808), + uvec2(0x2b190808, 0x08190808), uvec2(0x2b191919, 0x08190808), uvec2(0x2b192b08, 0x08190808), uvec2(0x2b2b0819, 0x08190808), + uvec2(0x2b2b1908, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), uvec2(0x08081919, 0x08190819), + uvec2(0x08082b08, 0x08190819), uvec2(0x08082b2b, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x0819192b, 0x08190819), uvec2(0x08192b19, 0x08190819), uvec2(0x082b0808, 0x08190819), uvec2(0x082b082b, 0x08190819), + uvec2(0x082b1919, 0x08190819), uvec2(0x082b2b08, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), + uvec2(0x1908192b, 0x08190819), uvec2(0x19082b19, 0x08190819), uvec2(0x19190808, 0x08190819), uvec2(0x1919082b, 0x08190819), + uvec2(0x19191919, 0x08190819), uvec2(0x19192b08, 0x08190819), uvec2(0x192b0819, 0x08190819), uvec2(0x192b1908, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b08082b, 0x08190819), uvec2(0x2b081919, 0x08190819), uvec2(0x2b082b08, 0x08190819), + uvec2(0x2b190819, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x08080819, 0x0819082b), uvec2(0x08081908, 0x0819082b), + uvec2(0x08082b19, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x08191919, 0x0819082b), uvec2(0x082b0819, 0x0819082b), + uvec2(0x082b1908, 0x0819082b), uvec2(0x19080808, 0x0819082b), uvec2(0x19081919, 0x0819082b), uvec2(0x19190819, 0x0819082b), + uvec2(0x19191908, 0x0819082b), uvec2(0x2b080819, 0x0819082b), uvec2(0x2b081908, 0x0819082b), uvec2(0x2b190808, 0x0819082b), + uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), uvec2(0x08082b08, 0x08191908), + uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x0819192b, 0x08191908), uvec2(0x08192b19, 0x08191908), + uvec2(0x082b0808, 0x08191908), uvec2(0x082b1919, 0x08191908), uvec2(0x082b2b08, 0x08191908), uvec2(0x19080819, 0x08191908), + uvec2(0x19081908, 0x08191908), uvec2(0x1908192b, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x1919082b, 0x08191908), uvec2(0x19191919, 0x08191908), uvec2(0x19192b08, 0x08191908), uvec2(0x192b0819, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x2b08082b, 0x08191908), uvec2(0x2b081919, 0x08191908), + uvec2(0x2b082b08, 0x08191908), uvec2(0x2b190819, 0x08191908), uvec2(0x2b191908, 0x08191908), uvec2(0x2b2b0808, 0x08191908), + uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), uvec2(0x0808192b, 0x08191919), uvec2(0x08082b19, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x0819082b, 0x08191919), uvec2(0x08191919, 0x08191919), uvec2(0x08192b08, 0x08191919), + uvec2(0x082b0819, 0x08191919), uvec2(0x082b1908, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x1908082b, 0x08191919), + uvec2(0x19081919, 0x08191919), uvec2(0x19082b08, 0x08191919), uvec2(0x19190819, 0x08191919), uvec2(0x19191908, 0x08191919), + uvec2(0x192b0808, 0x08191919), uvec2(0x2b080819, 0x08191919), uvec2(0x2b081908, 0x08191919), uvec2(0x2b190808, 0x08191919), + uvec2(0x08080808, 0x0819192b), uvec2(0x08081919, 0x0819192b), uvec2(0x08082b08, 0x0819192b), uvec2(0x08190819, 0x0819192b), + uvec2(0x08191908, 0x0819192b), uvec2(0x082b0808, 0x0819192b), uvec2(0x19080819, 0x0819192b), uvec2(0x19081908, 0x0819192b), + uvec2(0x19190808, 0x0819192b), uvec2(0x2b080808, 0x0819192b), uvec2(0x2b2b2b2b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x0808192b, 0x08192b08), uvec2(0x08082b19, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x08191919, 0x08192b08), uvec2(0x08192b08, 0x08192b08), uvec2(0x082b0819, 0x08192b08), uvec2(0x19080808, 0x08192b08), + uvec2(0x1908082b, 0x08192b08), uvec2(0x19081919, 0x08192b08), uvec2(0x19082b08, 0x08192b08), uvec2(0x19190819, 0x08192b08), + uvec2(0x19191908, 0x08192b08), uvec2(0x192b0808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), uvec2(0x2b081908, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x0808082b, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x08082b08, 0x08192b19), + uvec2(0x08190819, 0x08192b19), uvec2(0x08191908, 0x08192b19), uvec2(0x082b0808, 0x08192b19), uvec2(0x19080819, 0x08192b19), + uvec2(0x19081908, 0x08192b19), uvec2(0x19190808, 0x08192b19), uvec2(0x192b2b19, 0x08192b19), uvec2(0x2b2b082b, 0x08192b19), + uvec2(0x08081908, 0x08192b2b), uvec2(0x08190808, 0x08192b2b), uvec2(0x19080808, 0x08192b2b), uvec2(0x1919192b, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), uvec2(0x08082b08, 0x082b0808), + uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), uvec2(0x0819192b, 0x082b0808), uvec2(0x08192b19, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x082b1919, 0x082b0808), uvec2(0x082b2b2b, 0x082b0808), uvec2(0x19080819, 0x082b0808), + uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), uvec2(0x1919082b, 0x082b0808), uvec2(0x19191919, 0x082b0808), + uvec2(0x192b1908, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b082b2b, 0x082b0808), uvec2(0x2b191908, 0x082b0808), + uvec2(0x2b2b2b2b, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), uvec2(0x08190808, 0x082b0819), + uvec2(0x0819082b, 0x082b0819), uvec2(0x08191919, 0x082b0819), uvec2(0x082b0819, 0x082b0819), uvec2(0x19080808, 0x082b0819), + uvec2(0x1908082b, 0x082b0819), uvec2(0x19081919, 0x082b0819), uvec2(0x19190819, 0x082b0819), uvec2(0x19191908, 0x082b0819), + uvec2(0x192b0808, 0x082b0819), uvec2(0x2b080819, 0x082b0819), uvec2(0x2b081908, 0x082b0819), uvec2(0x2b190808, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x08082b2b, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x082b2b08, 0x082b082b), + uvec2(0x082b2b2b, 0x082b082b), uvec2(0x19081908, 0x082b082b), uvec2(0x19190808, 0x082b082b), uvec2(0x2b082b08, 0x082b082b), + uvec2(0x2b082b2b, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), + uvec2(0x0808192b, 0x082b1908), uvec2(0x08082b19, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x08191919, 0x082b1908), + uvec2(0x08192b08, 0x082b1908), uvec2(0x082b0819, 0x082b1908), uvec2(0x082b1908, 0x082b1908), uvec2(0x19080808, 0x082b1908), + uvec2(0x1908082b, 0x082b1908), uvec2(0x19081919, 0x082b1908), uvec2(0x19082b08, 0x082b1908), uvec2(0x19190819, 0x082b1908), + uvec2(0x19191908, 0x082b1908), uvec2(0x192b0808, 0x082b1908), uvec2(0x2b080819, 0x082b1908), uvec2(0x2b081908, 0x082b1908), + uvec2(0x2b190808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x08081919, 0x082b1919), uvec2(0x08082b08, 0x082b1919), + uvec2(0x08190819, 0x082b1919), uvec2(0x08191908, 0x082b1919), uvec2(0x082b0808, 0x082b1919), uvec2(0x19080819, 0x082b1919), + uvec2(0x19081908, 0x082b1919), uvec2(0x19190808, 0x082b1919), uvec2(0x192b192b, 0x082b1919), uvec2(0x2b080808, 0x082b1919), + uvec2(0x08080819, 0x082b192b), uvec2(0x08081908, 0x082b192b), uvec2(0x08190808, 0x082b192b), uvec2(0x19080808, 0x082b192b), + uvec2(0x19192b19, 0x082b192b), uvec2(0x08080808, 0x082b2b08), uvec2(0x08081919, 0x082b2b08), uvec2(0x08190819, 0x082b2b08), + uvec2(0x08191908, 0x082b2b08), uvec2(0x19080819, 0x082b2b08), uvec2(0x19081908, 0x082b2b08), uvec2(0x19190808, 0x082b2b08), + uvec2(0x2b082b2b, 0x082b2b08), uvec2(0x2b2b2b2b, 0x082b2b08), uvec2(0x08080819, 0x082b2b19), uvec2(0x08081908, 0x082b2b19), + uvec2(0x08190808, 0x082b2b19), uvec2(0x2b191919, 0x082b2b19), uvec2(0x08082b2b, 0x082b2b2b), uvec2(0x082b082b, 0x082b2b2b), + uvec2(0x192b1908, 0x082b2b2b), uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), + uvec2(0x08081908, 0x19080808), uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), + uvec2(0x0819082b, 0x19080808), uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x08192b2b, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x082b192b, 0x19080808), uvec2(0x19080808, 0x19080808), + uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), uvec2(0x19082b2b, 0x19080808), + uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x1919192b, 0x19080808), uvec2(0x19192b19, 0x19080808), + uvec2(0x192b0808, 0x19080808), uvec2(0x192b082b, 0x19080808), uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), + uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), uvec2(0x2b191919, 0x19080808), uvec2(0x2b192b08, 0x19080808), + uvec2(0x2b2b0819, 0x19080808), uvec2(0x2b2b1908, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), + uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), + uvec2(0x0819192b, 0x19080819), uvec2(0x08192b19, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x082b082b, 0x19080819), + uvec2(0x082b1919, 0x19080819), uvec2(0x19080819, 0x19080819), uvec2(0x19081908, 0x19080819), uvec2(0x1908192b, 0x19080819), + uvec2(0x19082b19, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x1919082b, 0x19080819), uvec2(0x19191919, 0x19080819), + uvec2(0x19192b08, 0x19080819), uvec2(0x192b0819, 0x19080819), uvec2(0x192b1908, 0x19080819), uvec2(0x2b080808, 0x19080819), + uvec2(0x2b08082b, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x2b082b08, 0x19080819), uvec2(0x2b190819, 0x19080819), + uvec2(0x2b191908, 0x19080819), uvec2(0x2b2b0808, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), + uvec2(0x08190808, 0x1908082b), uvec2(0x0819082b, 0x1908082b), uvec2(0x08191919, 0x1908082b), uvec2(0x08192b08, 0x1908082b), + uvec2(0x082b1908, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x19081919, 0x1908082b), uvec2(0x19082b08, 0x1908082b), + uvec2(0x19190819, 0x1908082b), uvec2(0x19191908, 0x1908082b), uvec2(0x192b0808, 0x1908082b), uvec2(0x2b080819, 0x1908082b), + uvec2(0x2b081908, 0x1908082b), uvec2(0x08080808, 0x19081908), uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x08082b2b, 0x19081908), uvec2(0x08190819, 0x19081908), uvec2(0x08191908, 0x19081908), + uvec2(0x0819192b, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x082b082b, 0x19081908), + uvec2(0x082b1919, 0x19081908), uvec2(0x082b2b08, 0x19081908), uvec2(0x19080819, 0x19081908), uvec2(0x19081908, 0x19081908), + uvec2(0x1908192b, 0x19081908), uvec2(0x19082b19, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x1919082b, 0x19081908), + uvec2(0x19191919, 0x19081908), uvec2(0x19192b08, 0x19081908), uvec2(0x192b0819, 0x19081908), uvec2(0x192b1908, 0x19081908), + uvec2(0x2b080808, 0x19081908), uvec2(0x2b08082b, 0x19081908), uvec2(0x2b081919, 0x19081908), uvec2(0x2b082b08, 0x19081908), + uvec2(0x2b190819, 0x19081908), uvec2(0x2b191908, 0x19081908), uvec2(0x2b2b0808, 0x19081908), uvec2(0x08080819, 0x19081919), + uvec2(0x08081908, 0x19081919), uvec2(0x0808192b, 0x19081919), uvec2(0x08082b19, 0x19081919), uvec2(0x08190808, 0x19081919), + uvec2(0x0819082b, 0x19081919), uvec2(0x08191919, 0x19081919), uvec2(0x08192b08, 0x19081919), uvec2(0x082b0819, 0x19081919), + uvec2(0x082b1908, 0x19081919), uvec2(0x19080808, 0x19081919), uvec2(0x1908082b, 0x19081919), uvec2(0x19081919, 0x19081919), + uvec2(0x19082b08, 0x19081919), uvec2(0x19190819, 0x19081919), uvec2(0x19191908, 0x19081919), uvec2(0x192b0808, 0x19081919), + uvec2(0x192b2b2b, 0x19081919), uvec2(0x2b080819, 0x19081919), uvec2(0x2b081908, 0x19081919), uvec2(0x2b190808, 0x19081919), + uvec2(0x08080808, 0x1908192b), uvec2(0x0808082b, 0x1908192b), uvec2(0x08081919, 0x1908192b), uvec2(0x08082b08, 0x1908192b), + uvec2(0x08190819, 0x1908192b), uvec2(0x08191908, 0x1908192b), uvec2(0x082b0808, 0x1908192b), uvec2(0x19080819, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x2b080808, 0x1908192b), uvec2(0x2b2b1919, 0x1908192b), + uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), uvec2(0x08082b19, 0x19082b08), uvec2(0x08190808, 0x19082b08), + uvec2(0x0819082b, 0x19082b08), uvec2(0x08191919, 0x19082b08), uvec2(0x08192b08, 0x19082b08), uvec2(0x082b0819, 0x19082b08), + uvec2(0x082b1908, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x1908082b, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x19082b08, 0x19082b08), uvec2(0x19190819, 0x19082b08), uvec2(0x19191908, 0x19082b08), uvec2(0x192b0808, 0x19082b08), + uvec2(0x2b081908, 0x19082b08), uvec2(0x2b190808, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x0808082b, 0x19082b19), + uvec2(0x08081919, 0x19082b19), uvec2(0x08082b08, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x08191908, 0x19082b19), + uvec2(0x082b0808, 0x19082b19), uvec2(0x19080819, 0x19082b19), uvec2(0x19081908, 0x19082b19), uvec2(0x19190808, 0x19082b19), + uvec2(0x2b080808, 0x19082b19), uvec2(0x2b19192b, 0x19082b19), uvec2(0x08080819, 0x19082b2b), uvec2(0x08081908, 0x19082b2b), + uvec2(0x08190808, 0x19082b2b), uvec2(0x19080808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x0808082b, 0x19190808), + uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), uvec2(0x08191908, 0x19190808), + uvec2(0x0819192b, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b082b, 0x19190808), + uvec2(0x082b1919, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), uvec2(0x19081908, 0x19190808), + uvec2(0x1908192b, 0x19190808), uvec2(0x19082b19, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x1919082b, 0x19190808), + uvec2(0x19191919, 0x19190808), uvec2(0x19192b08, 0x19190808), uvec2(0x192b0819, 0x19190808), uvec2(0x192b1908, 0x19190808), + uvec2(0x2b080808, 0x19190808), uvec2(0x2b08082b, 0x19190808), uvec2(0x2b081919, 0x19190808), uvec2(0x2b082b08, 0x19190808), + uvec2(0x2b190819, 0x19190808), uvec2(0x2b191908, 0x19190808), uvec2(0x08080819, 0x19190819), uvec2(0x08081908, 0x19190819), + uvec2(0x0808192b, 0x19190819), uvec2(0x08082b19, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x0819082b, 0x19190819), + uvec2(0x08191919, 0x19190819), uvec2(0x08192b08, 0x19190819), uvec2(0x082b0819, 0x19190819), uvec2(0x082b1908, 0x19190819), + uvec2(0x19080808, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x19081919, 0x19190819), uvec2(0x19082b08, 0x19190819), + uvec2(0x19190819, 0x19190819), uvec2(0x19191908, 0x19190819), uvec2(0x192b0808, 0x19190819), uvec2(0x2b080819, 0x19190819), + uvec2(0x2b081908, 0x19190819), uvec2(0x2b190808, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x08081919, 0x1919082b), + uvec2(0x08082b08, 0x1919082b), uvec2(0x08190819, 0x1919082b), uvec2(0x08191908, 0x1919082b), uvec2(0x082b0808, 0x1919082b), + uvec2(0x19080819, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x19190808, 0x1919082b), uvec2(0x192b2b19, 0x1919082b), + uvec2(0x2b080808, 0x1919082b), uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x0808192b, 0x19191908), + uvec2(0x08082b19, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x0819082b, 0x19191908), uvec2(0x08191919, 0x19191908), + uvec2(0x08192b08, 0x19191908), uvec2(0x082b0819, 0x19191908), uvec2(0x082b1908, 0x19191908), uvec2(0x19080808, 0x19191908), + uvec2(0x1908082b, 0x19191908), uvec2(0x19081919, 0x19191908), uvec2(0x19082b08, 0x19191908), uvec2(0x19190819, 0x19191908), + uvec2(0x19191908, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b081908, 0x19191908), + uvec2(0x2b190808, 0x19191908), uvec2(0x08080808, 0x19191919), uvec2(0x0808082b, 0x19191919), uvec2(0x08081919, 0x19191919), + uvec2(0x08082b08, 0x19191919), uvec2(0x08190819, 0x19191919), uvec2(0x08191908, 0x19191919), uvec2(0x082b0808, 0x19191919), + uvec2(0x19080819, 0x19191919), uvec2(0x19081908, 0x19191919), uvec2(0x19190808, 0x19191919), uvec2(0x2b080808, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x08081908, 0x1919192b), uvec2(0x08190808, 0x1919192b), uvec2(0x082b192b, 0x1919192b), + uvec2(0x19080808, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x0808082b, 0x19192b08), uvec2(0x08081919, 0x19192b08), + uvec2(0x08082b08, 0x19192b08), uvec2(0x08190819, 0x19192b08), uvec2(0x08191908, 0x19192b08), uvec2(0x082b0808, 0x19192b08), + uvec2(0x19080819, 0x19192b08), uvec2(0x19081908, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x19192b2b, 0x19192b08), + uvec2(0x2b080808, 0x19192b08), uvec2(0x08080819, 0x19192b19), uvec2(0x08081908, 0x19192b19), uvec2(0x08190808, 0x19192b19), + uvec2(0x19080808, 0x19192b19), uvec2(0x08080808, 0x19192b2b), uvec2(0x08192b19, 0x19192b2b), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x2b2b2b08, 0x19192b2b), uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x0808192b, 0x192b0808), + uvec2(0x08190808, 0x192b0808), uvec2(0x0819082b, 0x192b0808), uvec2(0x08191919, 0x192b0808), uvec2(0x08192b08, 0x192b0808), + uvec2(0x082b0819, 0x192b0808), uvec2(0x082b1908, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x19081919, 0x192b0808), + uvec2(0x19082b08, 0x192b0808), uvec2(0x19190819, 0x192b0808), uvec2(0x19191908, 0x192b0808), uvec2(0x192b0808, 0x192b0808), + uvec2(0x2b081908, 0x192b0808), uvec2(0x2b190808, 0x192b0808), uvec2(0x08080808, 0x192b0819), uvec2(0x0808082b, 0x192b0819), + uvec2(0x08081919, 0x192b0819), uvec2(0x08082b08, 0x192b0819), uvec2(0x08190819, 0x192b0819), uvec2(0x08191908, 0x192b0819), + uvec2(0x082b0808, 0x192b0819), uvec2(0x19080819, 0x192b0819), uvec2(0x19081908, 0x192b0819), uvec2(0x19190808, 0x192b0819), + uvec2(0x2b080808, 0x192b0819), uvec2(0x2b192b19, 0x192b0819), uvec2(0x08081908, 0x192b082b), uvec2(0x08190808, 0x192b082b), + uvec2(0x19080808, 0x192b082b), uvec2(0x1919192b, 0x192b082b), uvec2(0x2b2b0819, 0x192b082b), uvec2(0x08080808, 0x192b1908), + uvec2(0x08081919, 0x192b1908), uvec2(0x08082b08, 0x192b1908), uvec2(0x08190819, 0x192b1908), uvec2(0x08191908, 0x192b1908), + uvec2(0x082b0808, 0x192b1908), uvec2(0x19080819, 0x192b1908), uvec2(0x19081908, 0x192b1908), uvec2(0x19190808, 0x192b1908), + uvec2(0x2b080808, 0x192b1908), uvec2(0x08080819, 0x192b1919), uvec2(0x08081908, 0x192b1919), uvec2(0x08190808, 0x192b1919), + uvec2(0x19080808, 0x192b1919), uvec2(0x19082b2b, 0x192b1919), uvec2(0x192b2b08, 0x192b1919), uvec2(0x2b19082b, 0x192b1919), + uvec2(0x08080808, 0x192b192b), uvec2(0x2b191908, 0x192b192b), uvec2(0x08080819, 0x192b2b08), uvec2(0x08081908, 0x192b2b08), + uvec2(0x08190808, 0x192b2b08), uvec2(0x192b1919, 0x192b2b08), uvec2(0x2b192b08, 0x192b2b08), uvec2(0x08080808, 0x192b2b19), + uvec2(0x082b2b2b, 0x192b2b19), uvec2(0x1908082b, 0x192b2b2b), uvec2(0x2b2b0819, 0x192b2b2b), uvec2(0x08080808, 0x2b080808), + uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), uvec2(0x08190819, 0x2b080808), + uvec2(0x08191908, 0x2b080808), uvec2(0x08192b19, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b1919, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x1919082b, 0x2b080808), + uvec2(0x19191919, 0x2b080808), uvec2(0x19192b08, 0x2b080808), uvec2(0x192b0819, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b081919, 0x2b080808), uvec2(0x2b190819, 0x2b080808), uvec2(0x2b191908, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x08082b19, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x0819082b, 0x2b080819), + uvec2(0x08191919, 0x2b080819), uvec2(0x08192b08, 0x2b080819), uvec2(0x082b0819, 0x2b080819), uvec2(0x082b1908, 0x2b080819), + uvec2(0x19080808, 0x2b080819), uvec2(0x1908082b, 0x2b080819), uvec2(0x19081919, 0x2b080819), uvec2(0x19082b08, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19191908, 0x2b080819), uvec2(0x2b080819, 0x2b080819), uvec2(0x2b081908, 0x2b080819), + uvec2(0x2b190808, 0x2b080819), uvec2(0x2b2b2b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x08081919, 0x2b08082b), + uvec2(0x08082b2b, 0x2b08082b), uvec2(0x08190819, 0x2b08082b), uvec2(0x08191908, 0x2b08082b), uvec2(0x19080819, 0x2b08082b), + uvec2(0x19081908, 0x2b08082b), uvec2(0x19190808, 0x2b08082b), uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), + uvec2(0x0808192b, 0x2b081908), uvec2(0x08082b19, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x08192b08, 0x2b081908), uvec2(0x082b0819, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x1908082b, 0x2b081908), uvec2(0x19081919, 0x2b081908), uvec2(0x19082b08, 0x2b081908), uvec2(0x19190819, 0x2b081908), + uvec2(0x19191908, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b080819, 0x2b081908), uvec2(0x2b081908, 0x2b081908), + uvec2(0x2b190808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x0808082b, 0x2b081919), uvec2(0x08081919, 0x2b081919), + uvec2(0x08082b08, 0x2b081919), uvec2(0x08190819, 0x2b081919), uvec2(0x08191908, 0x2b081919), uvec2(0x082b0808, 0x2b081919), + uvec2(0x19080819, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x19190808, 0x2b081919), uvec2(0x2b080808, 0x2b081919), + uvec2(0x2b082b2b, 0x2b081919), uvec2(0x08080819, 0x2b08192b), uvec2(0x08081908, 0x2b08192b), uvec2(0x08190808, 0x2b08192b), + uvec2(0x082b2b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08081919, 0x2b082b08), + uvec2(0x08190819, 0x2b082b08), uvec2(0x08191908, 0x2b082b08), uvec2(0x19080819, 0x2b082b08), uvec2(0x19081908, 0x2b082b08), + uvec2(0x19190808, 0x2b082b08), uvec2(0x2b2b082b, 0x2b082b08), uvec2(0x08080819, 0x2b082b19), uvec2(0x08081908, 0x2b082b19), + uvec2(0x19080808, 0x2b082b19), uvec2(0x192b1919, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x19192b08, 0x2b082b2b), + uvec2(0x19192b2b, 0x2b082b2b), uvec2(0x2b08082b, 0x2b082b2b), uvec2(0x2b2b082b, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), + uvec2(0x08081908, 0x2b190808), uvec2(0x08082b19, 0x2b190808), uvec2(0x08190808, 0x2b190808), uvec2(0x0819082b, 0x2b190808), + uvec2(0x08191919, 0x2b190808), uvec2(0x08192b08, 0x2b190808), uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), + uvec2(0x1908082b, 0x2b190808), uvec2(0x19081919, 0x2b190808), uvec2(0x19082b08, 0x2b190808), uvec2(0x19190819, 0x2b190808), + uvec2(0x19191908, 0x2b190808), uvec2(0x192b0808, 0x2b190808), uvec2(0x2b080819, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x2b190808, 0x2b190808), uvec2(0x08080808, 0x2b190819), uvec2(0x08081919, 0x2b190819), uvec2(0x08190819, 0x2b190819), + uvec2(0x08191908, 0x2b190819), uvec2(0x19080819, 0x2b190819), uvec2(0x19081908, 0x2b190819), uvec2(0x19190808, 0x2b190819), + uvec2(0x19192b2b, 0x2b190819), uvec2(0x08080819, 0x2b19082b), uvec2(0x08081908, 0x2b19082b), uvec2(0x08190808, 0x2b19082b), + uvec2(0x19080808, 0x2b19082b), uvec2(0x2b2b192b, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x0808082b, 0x2b191908), + uvec2(0x08081919, 0x2b191908), uvec2(0x08082b08, 0x2b191908), uvec2(0x08190819, 0x2b191908), uvec2(0x08191908, 0x2b191908), + uvec2(0x082b0808, 0x2b191908), uvec2(0x19080819, 0x2b191908), uvec2(0x19081908, 0x2b191908), uvec2(0x19190808, 0x2b191908), + uvec2(0x2b080808, 0x2b191908), uvec2(0x2b19192b, 0x2b191908), uvec2(0x08080819, 0x2b191919), uvec2(0x08081908, 0x2b191919), + uvec2(0x08190808, 0x2b191919), uvec2(0x19080808, 0x2b191919), uvec2(0x2b192b08, 0x2b191919), uvec2(0x2b2b0819, 0x2b191919), + uvec2(0x08080808, 0x2b19192b), uvec2(0x1908192b, 0x2b19192b), uvec2(0x192b1908, 0x2b19192b), uvec2(0x08080819, 0x2b192b08), + uvec2(0x08081908, 0x2b192b08), uvec2(0x08190808, 0x2b192b08), uvec2(0x082b192b, 0x2b192b08), uvec2(0x19080808, 0x2b192b08), + uvec2(0x2b2b2b19, 0x2b192b08), uvec2(0x08080808, 0x2b192b19), uvec2(0x19082b19, 0x2b192b19), uvec2(0x1919082b, 0x2b192b19), + uvec2(0x2b190808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), uvec2(0x08081919, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), + uvec2(0x08191908, 0x2b2b0808), uvec2(0x082b082b, 0x2b2b0808), uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x19080819, 0x2b2b0808), + uvec2(0x19081908, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b2b082b, 0x2b2b0808), uvec2(0x2b2b2b2b, 0x2b2b0808), + uvec2(0x19080808, 0x2b2b0819), uvec2(0x192b1919, 0x2b2b0819), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b2b, 0x2b2b082b), + uvec2(0x082b082b, 0x2b2b082b), uvec2(0x082b2b08, 0x2b2b082b), uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b08082b, 0x2b2b082b), + uvec2(0x2b082b08, 0x2b2b082b), uvec2(0x2b082b2b, 0x2b2b082b), uvec2(0x2b2b2b08, 0x2b2b082b), uvec2(0x08080819, 0x2b2b1908), + uvec2(0x08081908, 0x2b2b1908), uvec2(0x08190808, 0x2b2b1908), uvec2(0x19080808, 0x2b2b1908), uvec2(0x2b082b19, 0x2b2b1908), + uvec2(0x2b2b1908, 0x2b2b1908), uvec2(0x08080808, 0x2b2b1919), uvec2(0x08192b19, 0x2b2b1919), uvec2(0x19190819, 0x2b2b192b), + uvec2(0x08082b2b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b082b, 0x2b2b2b08), uvec2(0x19191908, 0x2b2b2b19), + uvec2(0x2b08192b, 0x2b2b2b19), uvec2(0x08082b08, 0x2b2b2b2b), uvec2(0x08082b2b, 0x2b2b2b2b), uvec2(0x082b0808, 0x2b2b2b2b), + uvec2(0x082b082b, 0x2b2b2b2b), uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x2b082b08, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b) +}; + +shared uvec2 iq2s_grid[1024]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq2s_grid.length(); i += wgsize.x) { + if (iq2s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2s_grid_const.length()) { + iq2s_grid[i + gl_LocalInvocationIndex.x] = iq2s_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_S +#define QUANT_R QUANT_R_IQ2_S +#define A_TYPE block_iq2_s +#define A_TYPE_PACKED16 block_iq2_s_packed16 +#endif + +#define QUANT_K_IQ3_XXS 256 +#define QUANT_R_IQ3_XXS 1 + +struct block_iq3_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8]; +}; + +struct block_iq3_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16]; +}; + +#if defined(DATA_A_IQ3_XXS) + +const uint32_t iq3xxs_grid_const[256] = { + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +}; + +shared uint32_t iq3xxs_grid[256]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq3xxs_grid.length(); i += wgsize.x) { + if (iq3xxs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3xxs_grid.length()) { + iq3xxs_grid[i + gl_LocalInvocationIndex.x] = iq3xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_XXS +#define QUANT_R QUANT_R_IQ3_XXS +#define A_TYPE block_iq3_xxs +#define A_TYPE_PACKED16 block_iq3_xxs_packed16 +#endif + +#define QUANT_K_IQ3_S 256 +#define QUANT_R_IQ3_S 1 + +struct block_iq3_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_S/4]; + uint8_t qh[QUANT_K_IQ3_S/32]; + uint8_t signs[QUANT_K_IQ3_S/8]; + uint8_t scales[QUANT_K_IQ3_S/64]; +}; + +struct block_iq3_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_S/4/2]; + uint16_t qh[QUANT_K_IQ3_S/32/2]; + uint16_t signs[QUANT_K_IQ3_S/8/2]; + uint16_t scales[QUANT_K_IQ3_S/64/2]; +}; + +#if defined(DATA_A_IQ3_S) + +const uint32_t iq3s_grid_const[512] = { + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, +}; + +shared uint32_t iq3s_grid[512]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq3s_grid.length(); i += wgsize.x) { + if (iq3s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3s_grid.length()) { + iq3s_grid[i + gl_LocalInvocationIndex.x] = iq3s_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_S +#define QUANT_R QUANT_R_IQ3_S +#define A_TYPE block_iq3_s +#define A_TYPE_PACKED16 block_iq3_s_packed16 +#endif + +#define QUANT_K_IQ4_XS 256 +#define QUANT_R_IQ4_XS 1 + +struct block_iq4_xs +{ + float16_t d; + uint16_t scales_h; + uint8_t scales_l[QUANT_K_IQ4_XS/64]; + uint8_t qs[QUANT_K_IQ4_XS/2]; +}; + +#if defined(DATA_A_IQ4_XS) +#define QUANT_K QUANT_K_IQ4_XS +#define QUANT_R QUANT_R_IQ4_XS +#define A_TYPE block_iq4_xs +#endif + +#define QUANT_K_IQ4_NL 32 +#define QUANT_R_IQ4_NL 2 + +struct block_iq4_nl +{ + float16_t d; + uint8_t qs[QUANT_K_IQ4_NL/2]; +}; + +struct block_iq4_nl_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ4_NL/2/2]; +}; + +#if defined(DATA_A_IQ4_NL) +#define QUANT_K QUANT_K_IQ4_NL +#define QUANT_R QUANT_R_IQ4_NL +#define A_TYPE block_iq4_nl +#define A_TYPE_PACKED16 block_iq4_nl_packed16 +#endif + +#define QUANT_K_MXFP4 32 +#define QUANT_R_MXFP4 2 + +struct block_mxfp4 +{ + uint8_t e; + uint8_t qs[QUANT_K_MXFP4/2]; +}; + +//struct block_mxfp4_packed16 +//{ +// uint8_t e; +// uint16_t qs[QUANT_K_MXFP4/2/2]; +//}; + +#if defined(DATA_A_MXFP4) +#define QUANT_K QUANT_K_MXFP4 +#define QUANT_R QUANT_R_MXFP4 +#define QUANT_AUXF 1 +#define A_TYPE block_mxfp4 +//#define A_TYPE_PACKED16 block_mxfp4_packed16 +#endif + +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) +const int8_t kvalues_iq4nl_const[16] = { + int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), + int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) +}; + +shared FLOAT_TYPE kvalues_iq4nl[16]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) { + kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]); + } + barrier(); +} +#endif + +#if defined(DATA_A_MXFP4) +const FLOAT_TYPE kvalues_mxfp4_const[16] = { + FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f), + FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f) +}; + +shared FLOAT_TYPE kvalues_mxfp4[16]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) { + kvalues_mxfp4[i] = kvalues_mxfp4_const[i]; + } + barrier(); +} +#endif + +// returns the bfloat value in the low 16b. +// See ggml_compute_fp32_to_bf16 +uint32_t fp32_to_bf16(float f) +{ + uint32_t u = floatBitsToUint(f); + u = (u + (0x7fff + ((u >> 16) & 1))) >> 16; + return u; +} + +float bf16_to_fp32(uint32_t u) +{ + return uintBitsToFloat(u << 16); +} + +vec4 bf16_to_fp32(uvec4 u) +{ + return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w)); +} + +float e8m0_to_fp32(uint8_t x) { + uint32_t bits; + + if (x == 0) { + bits = 0x00400000; + } else { + bits = x; + bits = bits << 23; + } + + return uintBitsToFloat(bits); +} + +#if BDA + +#extension GL_EXT_buffer_reference : enable +#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable + +#define BDA_STORAGE_T uint64_t +#define BDA_OFFSET_T uint64_t + +#else + +#define BDA_STORAGE_T uvec2 +#define BDA_OFFSET_T uint + +#endif + +#endif // !defined(GGML_TYPES_COMP) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp new file mode 100644 index 00000000..154a2172 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -0,0 +1,100 @@ +#version 450 + +layout (push_constant) uniform parameter +{ + uint ne; uint a_offset; uint d_offset; + uint ne00; uint ne01; + uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; + float sf0; float sf1; float sf2; float sf3; +} p; + +#include "types.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag +#define NEAREST 0 +#define BILINEAR 1 +#define ALIGN_CORNERS (1 << 8) + +layout (constant_id = 0) const uint scale_mode = 0; + +float fetch_nearest(uint i10, uint i11, uint i12, uint i13) { + const uint i00 = uint(i10 / p.sf0); + const uint i01 = uint(i11 / p.sf1); + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + + return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]; +} + +float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) { + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02; + + const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00]; + const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00]; + const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00]; + const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00]; + + return + v00 * (1.0-d.x) * (1.0-d.y) + + v01 * d.x * (1.0-d.y) + + v10 * (1.0-d.x) * d.y + + v11 * d.x * d.y; +} + +float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { + const ivec2 ne0 = ivec2(p.ne00, p.ne01); + + const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5; + const vec2 c0f = floor(c); + const vec2 d = c - c0f; + const ivec2 c0 = max(ivec2(c0f), 0); + const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1); + + return fetch_bilinear(c0, c1, d, i12, i13); +} + +float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) { + const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1); + const vec2 c0f = floor(c); + const vec2 d = c - c0f; + const ivec2 c0 = ivec2(c0f); + const ivec2 c1 = c0 + 1; + + return fetch_bilinear(c0, c1, d, i12, i13); +} + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i10 = idx % p.ne10; + const uint i11 = (idx / p.ne10) % p.ne11; + const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; + const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; + + float result; + switch (scale_mode) { + case NEAREST: + result = fetch_nearest(i10, i11, i12, i13); + break; + case BILINEAR: + result = interpolate_bilinear(i10, i11, i12, i13); + break; + case BILINEAR | ALIGN_CORNERS: + result = interpolate_bilinear_align_corners(i10, i11, i12, i13); + break; + } + + data_d[p.d_offset + idx] = D_TYPE(result); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl new file mode 100644 index 00000000..dc4a1e6d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl @@ -0,0 +1,25 @@ +#ifndef UTILS_COMP +#define UTILS_COMP + +// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1 +uint fastmod(uint a, uint b) { + if ((b & (b-1)) == 0) { + return a & (b-1); + } + return a % b; +} + +uint fastdiv(uint a, uint b) { + return (a < b) ? 0 : (a / b); +} + +void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03, uint ne00, uint ne01, uint ne02, uint ne03) { + i03 = fastdiv(idx, (ne02*ne01*ne00)); + const uint i03_offset = i03 * ne02*ne01*ne00; + i02 = fastdiv((idx - i03_offset), (ne01*ne00)); + const uint i02_offset = i02*ne01*ne00; + i01 = (idx - i03_offset - i02_offset) / ne00; + i00 = idx - i03_offset - i02_offset - i01*ne00; +} + +#endif // UTILS_COMP diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp new file mode 100644 index 00000000..f0cc24ff --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -0,0 +1,1097 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 + #define NOMINMAX + #include + #include // For _mkdir on Windows +#else + #include + #include + #include +#endif + +#define ASYNCIO_CONCURRENCY 64 + +std::mutex lock; +std::vector> shader_fnames; +std::locale c_locale("C"); + +std::string GLSLC = "glslc"; +std::string input_filepath = ""; +std::string output_dir = "/tmp"; +std::string target_hpp = ""; +std::string target_cpp = ""; + +const std::vector type_names = { + "f32", + "f16", + "q4_0", + "q4_1", + "q5_0", + "q5_1", + "q8_0", + "q2_k", + "q3_k", + "q4_k", + "q5_k", + "q6_k", + "iq1_s", + "iq1_m", + "iq2_xxs", + "iq2_xs", + "iq2_s", + "iq3_xxs", + "iq3_s", + "iq4_xs", + "iq4_nl", + "mxfp4", + "bf16", +}; + +enum MatMulIdType { + NONE, + DEFAULT, + SUBGROUP, +}; + +namespace { + +void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { +#ifdef _WIN32 + HANDLE stdout_read, stdout_write; + HANDLE stderr_read, stderr_write; + SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; + + if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) || + !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stdout pipe"); + } + + if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) || + !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stderr pipe"); + } + + PROCESS_INFORMATION pi; + STARTUPINFOA si = {}; + si.cb = sizeof(STARTUPINFOA); + si.dwFlags = STARTF_USESTDHANDLES; + si.hStdOutput = stdout_write; + si.hStdError = stderr_write; + + std::vector cmd(command.begin(), command.end()); + cmd.push_back('\0'); + + if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) { + throw std::runtime_error("Failed to create process"); + } + + CloseHandle(stdout_write); + CloseHandle(stderr_write); + + std::array buffer; + DWORD bytes_read; + + while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + CloseHandle(stdout_read); + CloseHandle(stderr_read); + WaitForSingleObject(pi.hProcess, INFINITE); + CloseHandle(pi.hProcess); + CloseHandle(pi.hThread); +#else + int stdout_pipe[2]; + int stderr_pipe[2]; + + if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { + throw std::runtime_error("Failed to create pipes"); + } + + pid_t pid = fork(); + if (pid < 0) { + throw std::runtime_error("Failed to fork process"); + } + + if (pid == 0) { + close(stdout_pipe[0]); + close(stderr_pipe[0]); + dup2(stdout_pipe[1], STDOUT_FILENO); + dup2(stderr_pipe[1], STDERR_FILENO); + close(stdout_pipe[1]); + close(stderr_pipe[1]); + execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr); + _exit(EXIT_FAILURE); + } else { + close(stdout_pipe[1]); + close(stderr_pipe[1]); + + std::array buffer; + ssize_t bytes_read; + + while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + close(stdout_pipe[0]); + close(stderr_pipe[0]); + waitpid(pid, nullptr, 0); + } +#endif +} + +bool directory_exists(const std::string& path) { + struct stat info; + if (stat(path.c_str(), &info) != 0) { + return false; // Path doesn't exist or can't be accessed + } + return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory +} + +bool create_directory(const std::string& path) { +#ifdef _WIN32 + return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists +#else + return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions +#endif +} + +std::string to_uppercase(const std::string& input) { + std::string result = input; + for (char& c : result) { + c = std::toupper(c); + } + return result; +} + +bool string_starts_with(const std::string& str, const std::string& prefix) { + if (prefix.size() > str.size()) { + return false; + } + return std::equal(prefix.begin(), prefix.end(), str.begin()); +} + +bool string_ends_with(const std::string& str, const std::string& suffix) { + if (suffix.size() > str.size()) { + return false; + } + return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); +} + +bool is_quantized_type(const std::string& type_name) { + return type_name != "f32" && type_name != "f16" && type_name != "bf16"; +} + +bool is_legacy_quant(const std::string& type_name) { + return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0"; +} + +bool is_k_quant(const std::string& type_name) { + return string_ends_with(type_name, "_k"); +} + +bool is_iq_quant(const std::string& type_name) { + return string_starts_with(type_name, "iq"); +} + +static const char path_separator = '/'; + +std::string join_paths(const std::string& path1, const std::string& path2) { + return path1 + path_separator + path2; +} + +std::string basename(const std::string &path) { + return path.substr(path.find_last_of("/\\") + 1); +} + +std::stringstream make_generic_stringstream() { + std::stringstream ss; + ss.imbue(c_locale); + return ss; +} + +std::string read_binary_file(const std::string& path, bool may_not_exist = false) { + FILE* f = fopen(path.c_str(), "rb"); + if (!f) { + if (!may_not_exist) { + std::cerr << "Error opening file: " << path << " (" << strerror(errno) << ")\n"; + } + return {}; + } + + fseek(f, 0, SEEK_END); + size_t size = ftell(f); + fseek(f, 0, SEEK_SET); + + std::string data(size, '\0'); + size_t read_size = fread(data.data(), 1, size, f); + fclose(f); + if (read_size != size) { + std::cerr << "Error reading file: " << path << " (" << strerror(errno) << ")\n"; + return {}; + } + + return data; +} + +void write_binary_file(const std::string& path, const std::string& content) { + FILE* f = fopen(path.c_str(), "wb"); + if (!f) { + std::cerr << "Error opening file for writing: " << path << " (" << strerror(errno) << ")\n"; + return; + } + + size_t write_size = fwrite(content.data(), 1, content.size(), f); + fclose(f); + if (write_size != content.size()) { + std::cerr << "Error writing file: " << path << " (" << strerror(errno) << ")\n"; + return; + } +} + +void write_file_if_changed(const std::string& path, const std::string& content) { + std::string existing = read_binary_file(path, true); + if (existing != content) { + write_binary_file(path, content); + } +} + + +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; +static bool generate_dep_file = true; + +void decrement_compile_count(uint32_t * count) { + if (count) { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + compile_count_cond.notify_all(); + } +} + +using compile_count_guard = std::unique_ptr; + +compile_count_guard acquire_compile_slot() { + // wait until fewer than N compiles are in progress. + // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. + uint32_t N = std::max(1u, std::min(16u, std::thread::hardware_concurrency())); + std::unique_lock guard(compile_count_mutex); + compile_count_cond.wait(guard, [N] { return compile_count < N; }); + compile_count++; + return compile_count_guard(&compile_count, &decrement_compile_count); +} + +void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map defines, bool coopmat, bool dep_file, compile_count_guard slot) { + std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; + + // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 + // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344 + std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O"; + + #ifdef _WIN32 + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""}; + #else + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_path}; + #endif + + if (dep_file) { + cmd.push_back("-MD"); + cmd.push_back("-MF"); + cmd.push_back("\"" + target_cpp + ".d\""); + } + + #ifdef GGML_VULKAN_SHADER_DEBUG_INFO + cmd.push_back("-g"); + #endif + + for (const auto& define : defines) { + cmd.push_back("-D" + define.first + "=" + define.second); + } + + std::string command; + for (const auto& part : cmd) { + command += part + " "; + } + + std::string stdout_str, stderr_str; + try { + // std::cout << "Executing command: "; + // for (const auto& part : cmd) { + // std::cout << part << " "; + // } + // std::cout << std::endl; + + execute_command(command, stdout_str, stderr_str); + if (!stderr_str.empty()) { + std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl; + return; + } + + if (dep_file) { + // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt + std::string dep = read_binary_file(target_cpp + ".d", true); + if (!dep.empty()) { + size_t pos = dep.find(out_path); + if (pos != std::string::npos) { + dep.replace(pos, out_path.length(), target_cpp); + } + write_binary_file(target_cpp + ".d", dep); + } + } + + std::lock_guard guard(lock); + shader_fnames.push_back(std::make_pair(name, out_path)); + } catch (const std::exception& e) { + std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; + } +} + +std::map merge_maps(const std::map& a, const std::map& b) { + std::map result = a; + result.insert(b.begin(), b.end()); + return result; +} + +static std::vector> compiles; +void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string out_path = join_paths(output_dir, name + ".spv"); + + if (input_filepath == "") { + // No input source to compile, only generate header for all shaders + shader_fnames.push_back(std::pair(name, out_path)); + return; + } else if (basename(input_filepath) != source) { + // Only compile shader variants matching the input filename + return; + } + + compile_count_guard slot = acquire_compile_slot(); + compiles.push_back(std::async( + string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot))); + // Don't write the same dep file from multiple processes + generate_dep_file = false; +} + +void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { + std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; + std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; + std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; + + std::map base_dict; + std::string shader_name = "matmul"; + + if (matmul_id_type == MatMulIdType::DEFAULT) { + base_dict["MUL_MAT_ID"] = "1"; + shader_name = "matmul_id"; + } else if (matmul_id_type == MatMulIdType::SUBGROUP) { + base_dict["MUL_MAT_ID"] = "1"; + base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1"; + shader_name = "matmul_id_subgroup"; + } + + if (fp16) { + base_dict["FLOAT16"] = "1"; + } + + base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2"; + if (f16acc) { + base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; + } + + if (coopmat) { + base_dict["COOPMAT"] = "1"; + } + + const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; + + auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string { + switch (vec) { + case 1: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "float"; + } + return "bfloat16_t"; + } + if (coopmat2 || fp16) { + return "float16_t"; + } + return "float"; + case 2: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "vec2"; + } + return "bf16vec2"; + } + if (coopmat2 || fp16) { + return "f16vec2"; + } + return "vec2"; + case 4: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "vec4"; + } + return "bf16vec4"; + } + if (coopmat2 || fp16) { + return "f16vec4"; + } + return "vec4"; + case 8: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "mat2x4"; + } + throw std::runtime_error("bf16 vec8 not supported"); + } + if (coopmat2 || fp16) { + return "f16mat2x4"; + } + return "mat2x4"; + default: + throw std::runtime_error("invalid vector size"); + } + }; + + const std::map float_type_dict_f16 = { + {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")}, + {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")}, + }; + + // Shaders with f16 B_TYPE + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + + string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + + // bf16 + { + // For aligned matmul loads + std::string load_vec_a = coopmat2 ? "1" : "4"; + + // scalar path promotes to float + std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; + + const std::map float_type_dict_bf16 = { + {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")}, + }; + + // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader +#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!(coopmat || coopmat2)) +#endif + { + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + } + + for (const auto& tname : type_names) { + std::string load_vec_quant = "2"; + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) + load_vec_quant = "8"; + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) + load_vec_quant = "4"; + + if (tname == "bf16") { + continue; + } + + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + // For unaligned, load one at a time for f32/f16, or two at a time for quants + std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; + // For aligned matmul loads + std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; + + const std::map float_type_dict = { + {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)}, + {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)}, + }; + + // don't generate f32 variants for coopmat2 + if (!coopmat2) { + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + + if (tname != "f16" && tname != "f32") { + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) { + string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + } +#endif + } +} + +void process_shaders() { + std::map base_dict = {{"FLOAT_TYPE", "float"}}; + + // matmul + for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) { + // No coopmats + // fp32 + matmul_shaders(false, matmul_id_type, false, false, false); + + // fp16, fp32acc and fp16acc + matmul_shaders(true, matmul_id_type, false, false, false); + matmul_shaders(true, matmul_id_type, false, false, true); + + if (matmul_id_type != MatMulIdType::DEFAULT) { +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + // Coopmat, fp32acc and fp16acc + matmul_shaders(true, matmul_id_type, true, false, false); + matmul_shaders(true, matmul_id_type, true, false, true); +#endif + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + // Coopmat2, fp32acc and fp16acc + matmul_shaders(true, matmul_id_type, false, true, false); + matmul_shaders(true, matmul_id_type, false, true, true); +#endif + } + } + + // flash attention + for (const auto& f16acc : {false, true}) { + std::map fa_base_dict = base_dict; + fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; + if (f16acc) { + fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; + } + + for (const auto& tname : type_names) { + if (tname == "f32") { + continue; + } + if (tname == "bf16") continue; + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc); + } else { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } +#endif + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } + } + } + + for (const auto& tname : type_names) { + // mul mat vec + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + + string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + + // mul mat vec with integer dot product +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (is_legacy_quant(tname)) { + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + } +#endif + + // Dequant shaders + if (tname != "f16" && tname != "bf16") { + string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); + } + + shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp"; + + if (tname == "f16") { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); + } else { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); + } + string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); + } + + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + + // Norms + string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); + + for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + } + + for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + } + + auto get_type_str = [](bool f16) { + return f16 ? "float16_t" : "float"; + }; + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) { + for (auto src0_f16 : {false, true}) { + for (auto src1_f16 : {false, true}) { + for (auto dst_f16 : {false, true}) { + for (auto rte : {false, true}) { + auto source = op == "add_rms" ? std::string("add") : op; + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); + auto add_rms = op == "add_rms" ? "1" : "0"; + string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}}); + } + } + } + } + } + + string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); + string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); + string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}}); + + string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}}); + string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}}); + + string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); + + string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + for (auto rte : {false, true}) { + std::string suffix = rte ? "_rte" : ""; + string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}}); + } + string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + for (auto rte : {false, true}) { + std::string suffix = rte ? "_rte" : ""; + string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + } + + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + + string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); + string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); + + for (std::string dim_str : {"", "_3d"}) { + for (bool bda : {false, true}) { + std::string bda_str = bda ? "_bda" : ""; + std::string bda_def = bda ? "1" : "0"; + string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}})); + string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}})); + string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}})); + } + } + + string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + for (auto transpose : {false, true}) { + for (auto unroll : {false, true}) { + for (auto a_f16 : {false, true}) { + std::map defines = { + {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, + {"USE_COLLECTIVES", "1"}, {"UNROLL", unroll ? "[[unroll]]" : ""}, + }; + if (transpose) defines["TRANSPOSE"] = "1"; + std::string name = std::string(transpose ? "conv_transpose_2d": "conv2d") + + (a_f16 ? "_f16" : "") + "_f32"; + string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines); +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (unroll) { + defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv2d_mm.comp", defines, true, false, true); + } +#endif + } + } + } + + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + + string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); + string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); + + for (auto &c : compiles) { + c.wait(); + } +} + +void write_output_files() { + std::stringstream hdr = make_generic_stringstream(); + std::stringstream src = make_generic_stringstream(); + + hdr << "#include \n\n"; + src << "#include \"" << basename(target_hpp) << "\"\n\n"; + + std::sort(shader_fnames.begin(), shader_fnames.end()); + for (const auto& pair : shader_fnames) { + const std::string& name = pair.first; + #ifdef _WIN32 + std::string path = pair.second; + std::replace(path.begin(), path.end(), '/', '\\' ); + #else + const std::string& path = pair.second; + #endif + + hdr << "extern const uint64_t " << name << "_len;\n"; + hdr << "extern const unsigned char " << name << "_data[];\n\n"; + + if (input_filepath != "") { + std::string data = read_binary_file(path); + if (data.empty()) { + continue; + } + + src << "const uint64_t " << name << "_len = " << data.size() << ";\n"; + src << "const unsigned char " << name << "_data[" << data.size() << "] = {\n" << std::hex; + auto bytes = reinterpret_cast(data.data()); + for (size_t i = 0; i < data.size(); ++i) { + src << "0x" << static_cast(bytes[i]) << ","; + if ((i + 1) % 12 == 0) src << "\n"; + } + src << std::dec << "\n};\n\n"; + } + } + + std::string suffixes[2] = {"_f32", "_f16"}; + for (auto op : {"add", "sub", "mul", "div", "add_rms"}) { + hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; + hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; + + std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp"; + if (basename(input_filepath) != op_file) { + continue; + } + std::stringstream data = make_generic_stringstream(); + std::stringstream len = make_generic_stringstream(); + data << "const void * " << op << "_data[2][2][2][2] = "; + len << "const uint64_t " << op << "_len[2][2][2][2] = "; + for (uint32_t t0 = 0; t0 < 2; ++t0) { + if (t0 == 0) { + data << "{"; + len << "{"; + } + for (uint32_t t1 = 0; t1 < 2; ++t1) { + if (t1 == 0) { + data << "{"; + len << "{"; + } + for (uint32_t t2 = 0; t2 < 2; ++t2) { + if (t2 == 0) { + data << "{"; + len << "{"; + } + for (uint32_t rte = 0; rte < 2; ++rte) { + if (rte == 0) { + data << "{"; + len << "{"; + } + data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); + len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); + data << "_data,"; + len << "_len,"; + if (rte == 1) { + data << "}, "; + len << "}, "; + } + } + if (t2 == 1) { + data << "}, "; + len << "}, "; + } + } + if (t1 == 1) { + data << "}, "; + len << "}, "; + } + } + if (t0 == 1) { + data << "};\n"; + len << "};\n"; + } + } + src << data.str(); + src << len.str(); + } + + std::vector btypes = {"f16", "f32"}; + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + btypes.push_back("q8_1"); +#endif + + for (const std::string& btype : btypes) { + for (const auto& tname : type_names) { + if (btype == "q8_1" && !is_legacy_quant(tname)) { + continue; + } + hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; + hdr << "extern const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3];\n"; + if (basename(input_filepath) == "mul_mat_vec.comp") { + src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; + src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; + } + } + } + + if (input_filepath == "") { + write_file_if_changed(target_hpp, hdr.str()); + } + if (target_cpp != "") { + write_binary_file(target_cpp, src.str()); + } +} + +} // namespace + +int main(int argc, char** argv) { + std::map args; + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg.rfind("--", 0) == 0) { + if (i + 1 < argc && argv[i + 1][0] != '-') { + args[arg] = argv[i + 1]; + ++i; + } else { + args[arg] = ""; + } + } + } + + if (args.find("--glslc") != args.end()) { + GLSLC = args["--glslc"]; // Path to glslc + } + if (args.find("--source") != args.end()) { + input_filepath = args["--source"]; // The shader source file to compile + } + if (args.find("--output-dir") != args.end()) { + output_dir = args["--output-dir"]; // Directory for containing SPIR-V output + } + if (args.find("--target-hpp") != args.end()) { + target_hpp = args["--target-hpp"]; // Path to generated header file + } + if (args.find("--target-cpp") != args.end()) { + target_cpp = args["--target-cpp"]; // Path to generated cpp file + } + + if (!directory_exists(output_dir)) { + if (!create_directory(output_dir)) { + std::cerr << "Error creating output directory: " << output_dir << "\n"; + return EXIT_FAILURE; + } + } + + process_shaders(); + + write_output_files(); + + return EXIT_SUCCESS; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp new file mode 100644 index 00000000..35cc6c45 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp @@ -0,0 +1,87 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; }; +layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; }; +layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 6) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + barrier(); + _tf[tid] = tf[head_id * head_size + tid]; + barrier(); + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + barrier(); + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + + vec4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + s_vec = s_vec * td_vec + kv; + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp new file mode 100644 index 00000000..88c1c02b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp @@ -0,0 +1,91 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; }; +layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; }; +layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; }; +layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 7) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i]; + } + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + barrier(); + + A_TYPE sa = 0.0; + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + sa += dot(s_vec, a_vec); + } + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + s_vec = s_vec * w_vec + kv + sa * b_vec; + y += dot(r_vec, s_vec); + + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i] = state[i]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml.c b/ml/backend/ggml/ggml/src/ggml.c index 55a76f82..2bce1375 100644 --- a/ml/backend/ggml/ggml/src/ggml.c +++ b/ml/backend/ggml/ggml/src/ggml.c @@ -974,7 +974,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CONV_TRANSPOSE_1D", "IM2COL", "IM2COL_BACK", + "IM2COL_3D", "CONV_2D", + "CONV_3D", "CONV_2D_DW", "CONV_TRANSPOSE_2D", "POOL_1D", @@ -1012,11 +1014,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", "OPT_STEP_ADAMW", + "OPT_STEP_SGD", "GLU", }; -static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1075,7 +1078,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "conv_transpose_1d(x)", "im2col(x)", "im2col_back(x)", + "im2col_3d(x)", "conv_2d(x)", + "conv_3d(x)", "conv_2d_dw(x)", "conv_transpose_2d(x)", "pool_1d(x)", @@ -1113,15 +1118,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", "adamw(x)", + "sgd(x)", "glu(x)", }; -static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); - static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "ABS", "SGN", @@ -1138,10 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSIGMOID", "EXP", "GELU_ERF", + "XIELU", }; -static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); - +static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2647,6 +2652,29 @@ struct ggml_tensor * ggml_silu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); } +// ggml_xielu + +struct ggml_tensor * ggml_xielu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha_n, + float alpha_p, + float beta, + float eps) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU); + ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n)); + ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p)); + ggml_set_op_params_f32(result, 3, beta); + ggml_set_op_params_f32(result, 4, eps); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + // ggml_silu_back struct ggml_tensor * ggml_silu_back( @@ -3618,6 +3646,7 @@ struct ggml_tensor * ggml_get_rows( struct ggml_tensor * a, struct ggml_tensor * b) { GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(a->ne[3] == b->ne[2]); GGML_ASSERT(b->ne[3] == 1); GGML_ASSERT(b->type == GGML_TYPE_I32); @@ -3671,7 +3700,7 @@ struct ggml_tensor * ggml_set_rows( GGML_ASSERT(b->ne[3] % c->ne[2] == 0); GGML_ASSERT(c->ne[3] == 1); GGML_ASSERT(b->type == GGML_TYPE_F32); - GGML_ASSERT(c->type == GGML_TYPE_I64); + GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous_rows(a)); GGML_ASSERT(ggml_is_contiguous_rows(b)); @@ -3681,6 +3710,7 @@ struct ggml_tensor * ggml_set_rows( result->op = GGML_OP_SET_ROWS; result->src[0] = b; result->src[1] = c; + result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931) return result; } @@ -3822,6 +3852,15 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } +struct ggml_tensor * ggml_soft_max_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias) { + return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true); +} + void ggml_soft_max_add_sinks( struct ggml_tensor * a, struct ggml_tensor * sinks) { @@ -3885,6 +3924,7 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * b, struct ggml_tensor * c, int n_dims, + int sections[GGML_MROPE_SECTIONS], int mode, int n_ctx_orig, float freq_base, @@ -3898,15 +3938,19 @@ static struct ggml_tensor * ggml_rope_impl( GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(b->type == GGML_TYPE_I32); - GGML_ASSERT(a->ne[2] == b->ne[0]); + + bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; + if (mrope_used) { + GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token + } else { + GGML_ASSERT(a->ne[2] == b->ne[0]); + } if (c) { GGML_ASSERT(c->type == GGML_TYPE_F32); GGML_ASSERT(c->ne[0] >= n_dims / 2); } - int sections[4] = {0, 0, 0, 0}; - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; @@ -3916,7 +3960,11 @@ static struct ggml_tensor * ggml_rope_impl( memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(params + 11, §ions, sizeof(int)*4); + if (mrope_used && sections) { + memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS); + } else { + memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS); + } ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE; @@ -3934,7 +3982,7 @@ struct ggml_tensor * ggml_rope( int n_dims, int mode) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false + ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false ); } @@ -3944,7 +3992,7 @@ struct ggml_tensor * ggml_rope_multi( struct ggml_tensor * b, struct ggml_tensor * c, int n_dims, - int sections[4], + int sections[GGML_MROPE_SECTIONS], int mode, int n_ctx_orig, float freq_base, @@ -3953,36 +4001,31 @@ struct ggml_tensor * ggml_rope_multi( float attn_factor, float beta_fast, float beta_slow) { - // Multimodal Rotary Position Embedding - GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); + return ggml_rope_impl( + ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false + ); +} - GGML_ASSERT(ggml_is_vector(b)); - GGML_ASSERT(b->type == GGML_TYPE_I32); - GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token - - if (c) { - GGML_ASSERT(c->type == GGML_TYPE_F32); - GGML_ASSERT(c->ne[0] >= n_dims / 2); - } - - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - - int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; - memcpy(params + 5, &freq_base, sizeof(float)); - memcpy(params + 6, &freq_scale, sizeof(float)); - memcpy(params + 7, &ext_factor, sizeof(float)); - memcpy(params + 8, &attn_factor, sizeof(float)); - memcpy(params + 9, &beta_fast, sizeof(float)); - memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(¶ms[11], sections, sizeof(int)*4); - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_ROPE; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - - return result; +struct ggml_tensor * ggml_rope_multi_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[GGML_MROPE_SECTIONS], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true + ); } struct ggml_tensor * ggml_rope_inplace( @@ -3992,7 +4035,7 @@ struct ggml_tensor * ggml_rope_inplace( int n_dims, int mode) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true + ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true ); } @@ -4011,7 +4054,7 @@ struct ggml_tensor * ggml_rope_ext( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, false ); } @@ -4031,7 +4074,7 @@ struct ggml_tensor * ggml_rope_ext_inplace( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, true ); } @@ -4050,7 +4093,7 @@ struct ggml_tensor * ggml_rope_custom( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, false ); } @@ -4069,7 +4112,7 @@ struct ggml_tensor * ggml_rope_custom_inplace( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, true ); } @@ -4267,14 +4310,13 @@ struct ggml_tensor * ggml_conv_1d_dw( int s0, int p0, int d0) { - struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]); struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]); - struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); + struct ggml_tensor * im2col = ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a); - result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1); + result = ggml_reshape_3d(ctx, result, result->ne[0], result->ne[2], 1); return result; } @@ -4355,6 +4397,91 @@ struct ggml_tensor * ggml_conv_2d( return result; } +// a: [OC*IC, KD, KH, KW] +// b: [N*IC, ID, IH, IW] +// result: [N*OD, OH, OW, IC * KD * KH * KW] +struct ggml_tensor * ggml_im2col_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2, // dilation depth + enum ggml_type dst_type) { + const int64_t N = b->ne[3] / IC; + const int64_t ID = b->ne[2]; + const int64_t IH = b->ne[1]; + const int64_t IW = b->ne[0]; + + const int64_t OC = a->ne[3] / IC; + UNUSED(OC); + const int64_t KD = a->ne[2]; + const int64_t KH = a->ne[1]; + const int64_t KW = a->ne[0]; + const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2); + const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1); + const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0); + + GGML_ASSERT((OD > 0) && "b too small compared to a"); + GGML_ASSERT((OH > 0) && "b too small compared to a"); + GGML_ASSERT((OW > 0) && "b too small compared to a"); + + + const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N}; + + struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne); + int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC}; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_IM2COL_3D; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// a: [OC*IC, KD, KH, KW] +// b: [N*IC, ID, IH, IW] +// result: [N*OC, OD, OH, OW] +struct ggml_tensor * ggml_conv_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2 // dilation depth + ) { + struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW] + + int64_t OC = a->ne[3] / IC; + int64_t N = b->ne[3] / IC; + struct ggml_tensor * result = + ggml_mul_mat(ctx, + ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW] + ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW] + + int64_t OD = im2col->ne[3] / N; + result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW] + result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW] + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW] + + return result; +} + // ggml_conv_2d_sk_p0 struct ggml_tensor * ggml_conv_2d_sk_p0( @@ -4476,6 +4603,56 @@ struct ggml_tensor * ggml_conv_2d_direct( return result; } +// ggml_conv_3d_direct + +struct ggml_tensor * ggml_conv_3d_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + int c, + int n, + int oc) { + + GGML_ASSERT(a->ne[3] == (int64_t) c * oc); + GGML_ASSERT(b->ne[3] == (int64_t) c * n); + + int64_t ne[4]; + ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); + ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2); + ne[3] = (int64_t) oc * n; + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_i32(result, 0, s0); + ggml_set_op_params_i32(result, 1, s1); + ggml_set_op_params_i32(result, 2, s2); + ggml_set_op_params_i32(result, 3, p0); + ggml_set_op_params_i32(result, 4, p1); + ggml_set_op_params_i32(result, 5, p2); + ggml_set_op_params_i32(result, 6, d0); + ggml_set_op_params_i32(result, 7, d1); + ggml_set_op_params_i32(result, 8, d2); + ggml_set_op_params_i32(result, 9, c); + ggml_set_op_params_i32(result, 10, n); + ggml_set_op_params_i32(result, 11, oc); + + result->op = GGML_OP_CONV_3D; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_conv_transpose_2d_p0 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { @@ -4654,11 +4831,36 @@ struct ggml_tensor * ggml_pad( int p1, int p2, int p3) { + return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3); +} + +struct ggml_tensor * ggml_pad_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3 + ) { struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, - a->ne[0] + p0, - a->ne[1] + p1, - a->ne[2] + p2, - a->ne[3] + p3); + a->ne[0] + lp0 + rp0, + a->ne[1] + lp1 + rp1, + a->ne[2] + lp2 + rp2, + a->ne[3] + lp3 + rp3); + + ggml_set_op_params_i32(result, 0, lp0); + ggml_set_op_params_i32(result, 1, rp0); + ggml_set_op_params_i32(result, 2, lp1); + ggml_set_op_params_i32(result, 3, rp1); + ggml_set_op_params_i32(result, 4, lp2); + ggml_set_op_params_i32(result, 5, rp2); + ggml_set_op_params_i32(result, 6, lp3); + ggml_set_op_params_i32(result, 7, rp3); + result->op = GGML_OP_PAD; result->src[0] = a; @@ -4754,12 +4956,8 @@ struct ggml_tensor * ggml_timestep_embedding( struct ggml_tensor * timesteps, int dim, int max_period) { - int actual_dim = dim; - if (dim % 2 != 0) { - actual_dim = dim + 1; - } - struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]); ggml_set_op_params_i32(result, 0, dim); ggml_set_op_params_i32(result, 1, max_period); @@ -5602,6 +5800,28 @@ struct ggml_tensor * ggml_opt_step_adamw( return result; } +// opt_step_sgd + +struct ggml_tensor * ggml_opt_step_sgd( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * grad, + struct ggml_tensor * params) { + GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM); + GGML_ASSERT(ggml_are_same_shape(a, grad)); + GGML_ASSERT(params->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nelements(params) == 2); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_OPT_STEP_SGD; + result->src[0] = a; + result->src[1] = grad; + result->src[2] = params; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { diff --git a/ml/backend/ggml/ggml/src/ggml.go b/ml/backend/ggml/ggml/src/ggml.go index 37347807..7e215916 100644 --- a/ml/backend/ggml/ggml/src/ggml.go +++ b/ml/backend/ggml/ggml/src/ggml.go @@ -75,9 +75,9 @@ var OnceLoad = sync.OnceFunc(func() { paths = value } - split := filepath.SplitList(paths) - visited := make(map[string]struct{}, len(split)) - for _, path := range split { + libPaths = filepath.SplitList(paths) + visited := make(map[string]struct{}, len(libPaths)) + for _, path := range libPaths { abspath, err := filepath.Abs(path) if err != nil { slog.Error("failed to get absolute path", "error", err) @@ -104,6 +104,12 @@ var OnceLoad = sync.OnceFunc(func() { slog.Info("system", "", system{}) }) +var libPaths []string + +func LibPaths() []string { + return libPaths +} + type system struct{} func (system) LogValue() slog.Value { diff --git a/ml/backend/ggml/ggml/src/gguf.cpp b/ml/backend/ggml/ggml/src/gguf.cpp index 0f71d5f3..d950dbdf 100644 --- a/ml/backend/ggml/ggml/src/gguf.cpp +++ b/ml/backend/ggml/ggml/src/gguf.cpp @@ -273,7 +273,7 @@ struct gguf_reader { } bool read(std::string & dst) const { - uint64_t size = -1; + uint64_t size = 0; if (!read(size)) { return false; } @@ -523,7 +523,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // tensor shape { - uint32_t n_dims = -1; + uint32_t n_dims = 0; ok = ok && gr.read(n_dims); if (n_dims > GGML_MAX_DIMS) { GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", @@ -1169,50 +1169,51 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const } -struct gguf_writer { - std::vector & buf; +struct gguf_writer_base { + size_t written_bytes {0u}; - gguf_writer(std::vector & buf) : buf(buf) {} + ~gguf_writer_base(void) {} + + // we bet on devirtualization + virtual void write(int8_t val) = 0; + virtual void write(const std::vector & val) = 0; + virtual void write_tensor_data(const struct gguf_tensor_info & info, size_t offset_data, size_t alignment) = 0; template - void write(const T & val) const { + void write(const T & val) { for (size_t i = 0; i < sizeof(val); ++i) { - buf.push_back(reinterpret_cast(&val)[i]); + write(reinterpret_cast(&val)[i]); } } - void write(const std::vector & val) const { - buf.insert(buf.end(), val.begin(), val.end()); - } - - void write(const bool & val) const { + void write(const bool & val) { const int8_t val8 = val ? 1 : 0; write(val8); } - void write(const std::string & val) const { + void write(const std::string & val) { { const uint64_t n = val.length(); write(n); } for (size_t i = 0; i < val.length(); ++i) { - buf.push_back(reinterpret_cast(val.data())[i]); + write((val.data())[i]); } } - void write(const char * val) const { + void write(const char * val) { write(std::string(val)); } - void write(const enum ggml_type & val) const { + void write(const enum ggml_type & val) { write(int32_t(val)); } - void write(const enum gguf_type & val) const { + void write(const enum gguf_type & val) { write(int32_t(val)); } - void write(const struct gguf_kv & kv) const { + void write(const struct gguf_kv & kv) { const uint64_t ne = kv.get_ne(); write(kv.get_key()); @@ -1253,7 +1254,7 @@ struct gguf_writer { } } - void write_tensor_meta(const struct gguf_tensor_info & info) const { + void write_tensor_meta(const struct gguf_tensor_info & info) { write(info.t.name); const uint32_t n_dims = ggml_n_dims(&info.t); @@ -1266,14 +1267,33 @@ struct gguf_writer { write(info.offset); } - void pad(const size_t alignment) const { - while (buf.size() % alignment != 0) { + void pad(const size_t alignment) { + while (written_bytes % alignment != 0) { const int8_t zero = 0; write(zero); } } +}; - void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const { +// vector buffer based writer +struct gguf_writer_buf final : public gguf_writer_base { + std::vector & buf; + + gguf_writer_buf(std::vector & buf) : buf(buf) {} + + using gguf_writer_base::write; + + void write(const int8_t val) override { + buf.push_back(val); + written_bytes++; + } + + void write(const std::vector & val) override { + buf.insert(buf.end(), val.begin(), val.end()); + written_bytes += val.size(); + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override { GGML_ASSERT(buf.size() - offset_data == info.offset); GGML_ASSERT(ggml_is_contiguous(&info.t)); @@ -1287,14 +1307,58 @@ struct gguf_writer { GGML_ASSERT(info.t.data); memcpy(buf.data() + offset, info.t.data, nbytes); } + written_bytes += nbytes; pad(alignment); } }; -void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta) { - const struct gguf_writer gw(buf); +// file based writer +struct gguf_writer_file final : public gguf_writer_base { + FILE * file; + gguf_writer_file(FILE* file) : file(file) {} + + using gguf_writer_base::write; + + void write(const int8_t val) override { + const auto real_val = static_cast(val); + const auto ret = fputc(real_val, file); + written_bytes++; + if (ret != real_val) { + throw std::runtime_error("unexpected fputc result '" + std::to_string(ret) + "' instead of '" + std::to_string((int)real_val) + "'"); + } + } + + void write(const std::vector & val) override { + const auto ret = fwrite(val.data(), 1, val.size(), file); + written_bytes += val.size(); + if (ret != val.size()) { + throw std::runtime_error("unexpected fwrite number of bytes written, '" + std::to_string(ret) + "' instead of '" + std::to_string(val.size()) + "'"); + } + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override { + GGML_ASSERT(written_bytes - offset_data == info.offset); + + GGML_ASSERT(ggml_is_contiguous(&info.t)); + const size_t nbytes = ggml_nbytes(&info.t); + + std::vector buf(nbytes); + if (info.t.buffer) { + ggml_backend_tensor_get(&info.t, buf.data(), 0, nbytes); + } else { + GGML_ASSERT(info.t.data); + memcpy(buf.data(), info.t.data, nbytes); + } + write(buf); + + pad(alignment); + } +}; + +template +static void gguf_write_out(const struct gguf_context * ctx, writer_t & gw, bool only_meta) { const int64_t n_kv = gguf_get_n_kv(ctx); const int64_t n_tensors = gguf_get_n_tensors(ctx); @@ -1324,7 +1388,7 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & bu return; } - const size_t offset_data = gw.buf.size(); + const size_t offset_data = gw.written_bytes; // write tensor data for (int64_t i = 0; i < n_tensors; ++i) { @@ -1332,6 +1396,11 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & bu } } +void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta) { + gguf_writer_buf gw(buf); + gguf_write_out(ctx, gw, only_meta); +} + bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { FILE * file = ggml_fopen(fname, "wb"); @@ -1340,11 +1409,17 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo return false; } - std::vector buf; - gguf_write_to_buf(ctx, buf, only_meta); - const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size(); + try { + gguf_writer_file gw(file); + gguf_write_out(ctx, gw, only_meta); + } catch (const std::runtime_error& ex) { + GGML_LOG_ERROR("%s: failed to write GGUF data into '%s': %s\n", __func__, fname, ex.what()); + fclose(file); + return false; + } + fclose(file); - return ok; + return true; } size_t gguf_get_meta_size(const struct gguf_context * ctx) { diff --git a/ml/backend/ggml/ggml/src/mem_hip.cpp b/ml/backend/ggml/ggml/src/mem_hip.cpp new file mode 100644 index 00000000..8ef19b8c --- /dev/null +++ b/ml/backend/ggml/ggml/src/mem_hip.cpp @@ -0,0 +1,449 @@ +#include "ggml.h" + +#ifdef _WIN32 +// AMD Device Library eXtra (ADLX) +// +// https://github.com/GPUOpen-LibrariesAndSDKs/ADLX +// +// This Windows-only library provides accurate VRAM reporting for AMD GPUs. +// The runtime DLL is installed with every AMD Driver on Windows, however +// the SDK isn't a part of the HIP SDK packaging. As such, we avoid including +// the headers from the SDK to simplify building from source. +// +// ADLX relies heavily on function pointer tables. +// Only the minimal set of types are defined below to facilitate +// finding the target AMD GPU(s) and querying their current VRAM usage +// Unused function parameters are commented out to avoid unnecessary type +// definitions. + +#include "ggml-impl.h" +#include +#include + +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include + +namespace fs = std::filesystem; + +#include +#include + +// Begin minimal ADLX definitions - derived from tag v1.0 (Dec 2022) +typedef uint64_t adlx_uint64; +typedef uint32_t adlx_uint32; +typedef int32_t adlx_int32; +typedef adlx_int32 adlx_int; +typedef adlx_uint32 adlx_uint; +typedef long adlx_long; +typedef uint8_t adlx_uint8; +typedef enum +{ + ADLX_OK = 0, /**< @ENG_START_DOX This result indicates success. @ENG_END_DOX */ + ADLX_ALREADY_ENABLED, /**< @ENG_START_DOX This result indicates that the asked action is already enabled. @ENG_END_DOX */ + ADLX_ALREADY_INITIALIZED, /**< @ENG_START_DOX This result indicates that ADLX has a unspecified type of initialization. @ENG_END_DOX */ + ADLX_FAIL, /**< @ENG_START_DOX This result indicates an unspecified failure. @ENG_END_DOX */ + ADLX_INVALID_ARGS, /**< @ENG_START_DOX This result indicates that the arguments are invalid. @ENG_END_DOX */ + ADLX_BAD_VER, /**< @ENG_START_DOX This result indicates that the asked version is incompatible with the current version. @ENG_END_DOX */ + ADLX_UNKNOWN_INTERFACE, /**< @ENG_START_DOX This result indicates that an unknown interface was asked. @ENG_END_DOX */ + ADLX_TERMINATED, /**< @ENG_START_DOX This result indicates that the calls were made in an interface after ADLX was terminated. @ENG_END_DOX */ + ADLX_ADL_INIT_ERROR, /**< @ENG_START_DOX This result indicates that the ADL initialization failed. @ENG_END_DOX */ + ADLX_NOT_FOUND, /**< @ENG_START_DOX This result indicates that the item is not found. @ENG_END_DOX */ + ADLX_INVALID_OBJECT, /**< @ENG_START_DOX This result indicates that the method was called into an invalid object. @ENG_END_DOX */ + ADLX_ORPHAN_OBJECTS, /**< @ENG_START_DOX This result indicates that ADLX was terminated with outstanding ADLX objects. Any interface obtained from ADLX points to invalid memory and calls in their methods will result in unexpected behavior. @ENG_END_DOX */ + ADLX_NOT_SUPPORTED, /**< @ENG_START_DOX This result indicates that the asked feature is not supported. @ENG_END_DOX */ + ADLX_PENDING_OPERATION, /**< @ENG_START_DOX This result indicates a failure due to an operation currently in progress. @ENG_END_DOX */ + ADLX_GPU_INACTIVE /**< @ENG_START_DOX This result indicates that the GPU is inactive. @ENG_END_DOX */ +} ADLX_RESULT; +#define ADLX_SUCCEEDED(x) (ADLX_OK == (x) || ADLX_ALREADY_ENABLED == (x) || ADLX_ALREADY_INITIALIZED == (x)) +#define ADLX_FAILED(x) (ADLX_OK != (x) && ADLX_ALREADY_ENABLED != (x) && ADLX_ALREADY_INITIALIZED != (x)) +#define ADLX_VER_MAJOR 1 +#define ADLX_VER_MINOR 0 +#define ADLX_VER_RELEASE 5 +#define ADLX_VER_BUILD_NUM 30 +#define ADLX_MAKE_FULL_VER(VERSION_MAJOR, VERSION_MINOR, VERSION_RELEASE, VERSION_BUILD_NUM) ( ((adlx_uint64)(VERSION_MAJOR) << 48ull) | ((adlx_uint64)(VERSION_MINOR) << 32ull) | ((adlx_uint64)(VERSION_RELEASE) << 16ull) | (adlx_uint64)(VERSION_BUILD_NUM)) +#define ADLX_FULL_VERSION ADLX_MAKE_FULL_VER(ADLX_VER_MAJOR, ADLX_VER_MINOR, ADLX_VER_RELEASE, ADLX_VER_BUILD_NUM) +#define ADLX_CORE_LINK __declspec(dllexport) +#define ADLX_STD_CALL __stdcall +#define ADLX_CDECL_CALL __cdecl +#define ADLX_FAST_CALL __fastcall +#define ADLX_INLINE __inline +#define ADLX_FORCEINLINE __forceinline +#define ADLX_NO_VTABLE __declspec(novtable) + +#if defined(__cplusplus) +typedef bool adlx_bool; +#else +typedef adlx_uint8 adlx_bool; +#define true 1 +#define false 0 +#endif + +typedef struct IADLXSystem IADLXSystem; +typedef struct IADLXGPUList IADLXGPUList; +typedef struct IADLXGPU IADLXGPU; +typedef struct IADLXInterface IADLXInterface; +typedef struct IADLXPerformanceMonitoringServices IADLXPerformanceMonitoringServices; +typedef struct IADLXGPUMetrics IADLXGPUMetrics; +typedef struct IADLXGPUMetricsSupport IADLXGPUMetricsSupport; + +typedef struct IADLXSystemVtbl +{ + // IADLXSystem interface + ADLX_RESULT (ADLX_STD_CALL *GetHybridGraphicsType)(/* IADLXSystem* pThis, ADLX_HG_TYPE* hgType */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUs)(IADLXSystem* pThis, IADLXGPUList** ppGPUs); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXSystem* pThis, const wchar_t* interfaceId, void** ppInterface */); + ADLX_RESULT (ADLX_STD_CALL *GetDisplaysServices)(/* IADLXSystem* pThis, IADLXDisplayServices** ppDispServices */); + ADLX_RESULT (ADLX_STD_CALL *GetDesktopsServices)(/* IADLXSystem* pThis, IADLXDesktopServices** ppDeskServices */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUsChangedHandling)(/* IADLXSystem* pThis, IADLXGPUsChangedHandling** ppGPUsChangedHandling */); + ADLX_RESULT (ADLX_STD_CALL *EnableLog)(/* IADLXSystem* pThis, ADLX_LOG_DESTINATION mode, ADLX_LOG_SEVERITY severity, IADLXLog* pLogger, const wchar_t* fileName */); + ADLX_RESULT (ADLX_STD_CALL *Get3DSettingsServices)(/* IADLXSystem* pThis, IADLX3DSettingsServices** pp3DSettingsServices */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUTuningServices)(/* IADLXSystem* pThis, IADLXGPUTuningServices** ppGPUTuningServices */); + ADLX_RESULT (ADLX_STD_CALL *GetPerformanceMonitoringServices)(IADLXSystem* pThis, IADLXPerformanceMonitoringServices** ppPerformanceMonitoringServices); // Used + ADLX_RESULT (ADLX_STD_CALL *TotalSystemRAM)(/* IADLXSystem* pThis, adlx_uint* ramMB */); + ADLX_RESULT (ADLX_STD_CALL *GetI2C)(/* IADLXSystem* pThis, IADLXGPU* pGPU, IADLXI2C** ppI2C */); +} IADLXSystemVtbl; +struct IADLXSystem { const IADLXSystemVtbl *pVtbl; }; + +typedef struct IADLXGPUVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPU* pThis */); + adlx_long (ADLX_STD_CALL *Release)(IADLXGPU* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPU* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXGPU + ADLX_RESULT (ADLX_STD_CALL *VendorId)(/* IADLXGPU* pThis, const char** vendorId */); + ADLX_RESULT (ADLX_STD_CALL *ASICFamilyType)(/* IADLXGPU* pThis, ADLX_ASIC_FAMILY_TYPE* asicFamilyType */); + ADLX_RESULT (ADLX_STD_CALL *Type)(/* IADLXGPU* pThis, ADLX_GPU_TYPE* gpuType */); + ADLX_RESULT (ADLX_STD_CALL *IsExternal)(/* IADLXGPU* pThis, adlx_bool* isExternal */); + ADLX_RESULT (ADLX_STD_CALL *Name)(/* IADLXGPU* pThis, const char** gpuName */); + ADLX_RESULT (ADLX_STD_CALL *DriverPath)(/* IADLXGPU* pThis, const char** driverPath */); + ADLX_RESULT (ADLX_STD_CALL *PNPString)(/* IADLXGPU* pThis, const char** pnpString */); + ADLX_RESULT (ADLX_STD_CALL *HasDesktops)(/* IADLXGPU* pThis, adlx_bool* hasDesktops */); + ADLX_RESULT (ADLX_STD_CALL *TotalVRAM)(IADLXGPU* pThis, adlx_uint* vramMB); // Used + ADLX_RESULT (ADLX_STD_CALL *VRAMType)(/* IADLXGPU* pThis, const char** type */); + ADLX_RESULT (ADLX_STD_CALL *BIOSInfo)(/* IADLXGPU* pThis, const char** partNumber, const char** version, const char** date */); + ADLX_RESULT (ADLX_STD_CALL *DeviceId)(/* IADLXGPU* pThis, const char** deviceId */); + ADLX_RESULT (ADLX_STD_CALL *RevisionId)(/* IADLXGPU* pThis, const char** revisionId */); + ADLX_RESULT (ADLX_STD_CALL *SubSystemId)(/* IADLXGPU* pThis, const char** subSystemId */); + ADLX_RESULT (ADLX_STD_CALL *SubSystemVendorId)(/* IADLXGPU* pThis, const char** subSystemVendorId */); + ADLX_RESULT (ADLX_STD_CALL *UniqueId)(IADLXGPU* pThis, adlx_int* uniqueId); // Used +} IADLXGPUVtbl; +struct IADLXGPU { const IADLXGPUVtbl *pVtbl; }; + +typedef struct IADLXGPUListVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXGPUList* pThis */); + adlx_long (ADLX_STD_CALL *Release)(IADLXGPUList* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXGPUList* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXList + adlx_uint (ADLX_STD_CALL *Size)(/* IADLXGPUList* pThis */); + adlx_uint8 (ADLX_STD_CALL *Empty)(/* IADLXGPUList* pThis */); + adlx_uint (ADLX_STD_CALL *Begin)(IADLXGPUList* pThis); // Used + adlx_uint (ADLX_STD_CALL *End)(IADLXGPUList* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *At)(/* IADLXGPUList* pThis, const adlx_uint location, IADLXInterface** ppItem */); + ADLX_RESULT (ADLX_STD_CALL *Clear)(/* IADLXGPUList* pThis */); + ADLX_RESULT (ADLX_STD_CALL *Remove_Back)(/* IADLXGPUList* pThis */); + ADLX_RESULT (ADLX_STD_CALL *Add_Back)(/* IADLXGPUList* pThis, IADLXInterface* pItem */); + + //IADLXGPUList + ADLX_RESULT (ADLX_STD_CALL *At_GPUList)(IADLXGPUList* pThis, const adlx_uint location, IADLXGPU** ppItem); // Used + ADLX_RESULT (ADLX_STD_CALL *Add_Back_GPUList)(/* IADLXGPUList* pThis, IADLXGPU* pItem */); + +} IADLXGPUListVtbl; +struct IADLXGPUList { const IADLXGPUListVtbl *pVtbl; }; + +typedef struct IADLXPerformanceMonitoringServicesVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL *Acquire)(/* IADLXPerformanceMonitoringServices* pThis */); + adlx_long (ADLX_STD_CALL *Release)(IADLXPerformanceMonitoringServices* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL *QueryInterface)(/* IADLXPerformanceMonitoringServices* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXPerformanceMonitoringServices + ADLX_RESULT (ADLX_STD_CALL *GetSamplingIntervalRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); + ADLX_RESULT (ADLX_STD_CALL *SetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int intervalMs */); + ADLX_RESULT (ADLX_STD_CALL *GetSamplingInterval)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* intervalMs */); + ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySizeRange)(/* IADLXPerformanceMonitoringServices* pThis, ADLX_IntRange* range */); + ADLX_RESULT (ADLX_STD_CALL *SetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int sizeSec */); + ADLX_RESULT (ADLX_STD_CALL *GetMaxPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); + ADLX_RESULT (ADLX_STD_CALL *ClearPerformanceMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentPerformanceMetricsHistorySize)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int* sizeSec */); + ADLX_RESULT (ADLX_STD_CALL *StartPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); + ADLX_RESULT (ADLX_STD_CALL *StopPerformanceMetricsTracking)(/* IADLXPerformanceMonitoringServices* pThis */); + ADLX_RESULT (ADLX_STD_CALL *GetAllMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXAllMetricsList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetGPUMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, adlx_int startMs, adlx_int stopMs, IADLXGPUMetricsList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetSystemMetricsHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXSystemMetricsList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetFPSHistory)(/* IADLXPerformanceMonitoringServices* pThis, adlx_int startMs, adlx_int stopMs, IADLXFPSList** ppMetricsList */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentAllMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXAllMetrics** ppMetrics */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetrics** ppMetrics); // Used + ADLX_RESULT (ADLX_STD_CALL *GetCurrentSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetrics** ppMetrics */); + ADLX_RESULT (ADLX_STD_CALL *GetCurrentFPS)(/* IADLXPerformanceMonitoringServices* pThis, IADLXFPS** ppMetrics */); + ADLX_RESULT (ADLX_STD_CALL *GetSupportedGPUMetrics)(IADLXPerformanceMonitoringServices* pThis, IADLXGPU* pGPU, IADLXGPUMetricsSupport** ppMetricsSupported); // Used + ADLX_RESULT (ADLX_STD_CALL *GetSupportedSystemMetrics)(/* IADLXPerformanceMonitoringServices* pThis, IADLXSystemMetricsSupport** ppMetricsSupported */); +}IADLXPerformanceMonitoringServicesVtbl; +struct IADLXPerformanceMonitoringServices { const IADLXPerformanceMonitoringServicesVtbl *pVtbl; }; + +typedef struct IADLXGPUMetricsSupportVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetricsSupport* pThis */); + adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetricsSupport* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetricsSupport* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXGPUMetricsSupport + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUUsage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAMClockSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUHotspotTemperature)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUTotalBoardPower)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUFanSpeed)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVRAM)(IADLXGPUMetricsSupport* pThis, adlx_bool* supported); // Used + ADLX_RESULT (ADLX_STD_CALL* IsSupportedGPUVoltage)(/* IADLXGPUMetricsSupport* pThis, adlx_bool* supported */); + + ADLX_RESULT (ADLX_STD_CALL* GetGPUUsageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMClockSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUHotspotTemperatureRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUFanSpeedRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUVRAMRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUVoltageRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); + ADLX_RESULT (ADLX_STD_CALL* GetGPUTotalBoardPowerRange)(/* IADLXGPUMetricsSupport* pThis, adlx_int* minValue, adlx_int* maxValue */); +} IADLXGPUMetricsSupportVtbl; +struct IADLXGPUMetricsSupport { const IADLXGPUMetricsSupportVtbl *pVtbl; }; + +typedef struct IADLXGPUMetricsVtbl +{ + //IADLXInterface + adlx_long (ADLX_STD_CALL* Acquire)(/* IADLXGPUMetrics* pThis */); + adlx_long (ADLX_STD_CALL* Release)(IADLXGPUMetrics* pThis); // Used + ADLX_RESULT (ADLX_STD_CALL* QueryInterface)(/* IADLXGPUMetrics* pThis, const wchar_t* interfaceId, void** ppInterface */); + + //IADLXGPUMetrics + ADLX_RESULT (ADLX_STD_CALL* TimeStamp)(/* IADLXGPUMetrics* pThis, adlx_int64* ms */); + ADLX_RESULT (ADLX_STD_CALL* GPUUsage)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUVRAMClockSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUHotspotTemperature)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUTotalBoardPower)(/* IADLXGPUMetrics* pThis, adlx_double* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUFanSpeed)(/* IADLXGPUMetrics* pThis, adlx_int* data */); + ADLX_RESULT (ADLX_STD_CALL* GPUVRAM)(IADLXGPUMetrics* pThis, adlx_int* data); // Used + ADLX_RESULT (ADLX_STD_CALL* GPUVoltage)(/* IADLXGPUMetrics* pThis, adlx_int* data */); +} IADLXGPUMetricsVtbl; +struct IADLXGPUMetrics { const IADLXGPUMetricsVtbl *pVtbl; }; + +struct { + void *handle; + ADLX_RESULT (*ADLXInitialize)(adlx_uint64 version, IADLXSystem** ppSystem); + ADLX_RESULT (*ADLXInitializeWithIncompatibleDriver)(adlx_uint64 version, IADLXSystem** ppSystem); + ADLX_RESULT (*ADLXQueryVersion)(const char** version); + ADLX_RESULT (*ADLXTerminate)(); + IADLXSystem *sys; +} adlx { NULL, NULL, NULL, NULL, NULL, NULL }; +static std::mutex ggml_adlx_lock; + +extern "C" { + +int ggml_hip_mgmt_init() { + std::lock_guard lock(ggml_adlx_lock); + if (adlx.handle != NULL) { + // Already initialized + return 0; + } + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + fs::path libPath = fs::path("\\Windows") / fs::path("System32") / fs::path("amdadlx64.dll"); + + adlx.handle = (void*)LoadLibraryW(libPath.wstring().c_str()); + if (adlx.handle == NULL) { + return ADLX_NOT_FOUND; + } + + adlx.ADLXInitialize = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitialize"); + adlx.ADLXInitializeWithIncompatibleDriver = (ADLX_RESULT (*)(adlx_uint64 version, IADLXSystem **ppSystem)) GetProcAddress((HMODULE)(adlx.handle), "ADLXInitializeWithIncompatibleDriver"); + adlx.ADLXTerminate = (ADLX_RESULT (*)()) GetProcAddress((HMODULE)(adlx.handle), "ADLXTerminate"); + adlx.ADLXQueryVersion = (ADLX_RESULT (*)(const char **version)) GetProcAddress((HMODULE)(adlx.handle), "ADLXQueryVersion"); + if (adlx.ADLXInitialize == NULL || adlx.ADLXInitializeWithIncompatibleDriver == NULL || adlx.ADLXTerminate == NULL) { + GGML_LOG_INFO("%s unable to locate required symbols in amdadlx64.dll, falling back to hip free memory reporting", __func__); + FreeLibrary((HMODULE)(adlx.handle)); + adlx.handle = NULL; + return ADLX_NOT_FOUND; + } + + SetErrorMode(old_mode); + + // Aid in troubleshooting... + if (adlx.ADLXQueryVersion != NULL) { + const char *version = NULL; + ADLX_RESULT status = adlx.ADLXQueryVersion(&version); + if (ADLX_SUCCEEDED(status)) { + GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version); + } + } + + ADLX_RESULT status = adlx.ADLXInitialize(ADLX_FULL_VERSION, &adlx.sys); + if (ADLX_FAILED(status)) { + // GGML_LOG_DEBUG("%s failed to initialize ADLX error=%d - attempting with incompatible driver...\n", __func__, status); + // Try with the incompatible driver + status = adlx.ADLXInitializeWithIncompatibleDriver(ADLX_FULL_VERSION, &adlx.sys); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s failed to initialize ADLX error=%d\n", __func__, status); + FreeLibrary((HMODULE)(adlx.handle)); + adlx.handle = NULL; + adlx.sys = NULL; + return status; + } + // GGML_LOG_DEBUG("%s initialized ADLX with incpomatible driver\n", __func__); + } + return ADLX_OK; +} + +void ggml_hip_mgmt_release() { + std::lock_guard lock(ggml_adlx_lock); + if (adlx.handle == NULL) { + // Already free + return; + } + ADLX_RESULT status = adlx.ADLXTerminate(); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s failed to terminate Adlx %d\n", __func__, status); + // Unload anyway... + } + FreeLibrary((HMODULE)(adlx.handle)); + adlx.handle = NULL; +} + +#define adlx_gdm_cleanup \ + if (gpuMetricsSupport != NULL) gpuMetricsSupport->pVtbl->Release(gpuMetricsSupport); \ + if (gpuMetrics != NULL) gpuMetrics->pVtbl->Release(gpuMetrics); \ + if (perfMonitoringServices != NULL) perfMonitoringServices->pVtbl->Release(perfMonitoringServices); \ + if (gpus != NULL) gpus->pVtbl->Release(gpus); \ + if (gpu != NULL) gpu->pVtbl->Release(gpu) + +int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { + std::lock_guard lock(ggml_adlx_lock); + if (adlx.handle == NULL) { + GGML_LOG_INFO("%s ADLX was not initialized\n", __func__); + return ADLX_ADL_INIT_ERROR; + } + IADLXGPUMetricsSupport *gpuMetricsSupport = NULL; + IADLXPerformanceMonitoringServices *perfMonitoringServices = NULL; + IADLXGPUList* gpus = NULL; + IADLXGPU* gpu = NULL; + IADLXGPUMetrics *gpuMetrics = NULL; + ADLX_RESULT status; + // The "UniqueID" exposed in ADLX is the PCI Bus and Device IDs + adlx_int target = (pci_bus_id << 8) | (pci_device_id & 0xff); + + status = adlx.sys->pVtbl->GetPerformanceMonitoringServices(adlx.sys, &perfMonitoringServices); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetPerformanceMonitoringServices failed %d\n", __func__, status); + return status; + } + + status = adlx.sys->pVtbl->GetGPUs(adlx.sys, &gpus); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetGPUs failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + // Get GPU list + for (adlx_uint crt = gpus->pVtbl->Begin(gpus); crt != gpus->pVtbl->End(gpus); ++crt) + { + status = gpus->pVtbl->At_GPUList(gpus, crt, &gpu); + if (ADLX_FAILED(status)) + { + GGML_LOG_INFO("%s %d] At_GPUList failed %d\n", __func__, crt, status); + continue; + } + adlx_int id; + status = gpu->pVtbl->UniqueId(gpu, &id); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s %d] UniqueId lookup failed %d\n", __func__, crt, status); + gpu->pVtbl->Release(gpu); + gpu = NULL; + continue; + } + if (id != target) { + GGML_LOG_DEBUG("%s %d] GPU UniqueId: %x does not match target %02x %02x\n", __func__, crt, id, pci_bus_id, pci_device_id); + gpu->pVtbl->Release(gpu); + gpu = NULL; + continue; + } + // Any failures at this point should cause a fall-back to other APIs + status = perfMonitoringServices->pVtbl->GetSupportedGPUMetrics(perfMonitoringServices, gpu, &gpuMetricsSupport); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetSupportedGPUMetrics failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + status = perfMonitoringServices->pVtbl->GetCurrentGPUMetrics(perfMonitoringServices, gpu, &gpuMetrics); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GetCurrentGPUMetrics failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + adlx_bool supported = false; + status = gpuMetricsSupport->pVtbl->IsSupportedGPUVRAM(gpuMetricsSupport, &supported); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s IsSupportedGPUVRAM failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + adlx_uint totalVRAM = 0; + status = gpu->pVtbl->TotalVRAM(gpu, &totalVRAM); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s TotalVRAM failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + + adlx_int usedVRAM = 0; + status = gpuMetrics->pVtbl->GPUVRAM(gpuMetrics, &usedVRAM); + if (ADLX_FAILED(status)) { + GGML_LOG_INFO("%s GPUVRAM failed %d\n", __func__, status); + adlx_gdm_cleanup; + return status; + } + *total = size_t(totalVRAM) * 1024 * 1024; + *free = size_t(totalVRAM-usedVRAM) * 1024 * 1024; + + adlx_gdm_cleanup; + return ADLX_OK; + } + adlx_gdm_cleanup; + return ADLX_NOT_FOUND; +} + +} // extern "C" + +#else // #ifdef _WIN32 + +extern "C" { + +// TODO Linux implementation of accurate VRAM reporting +int ggml_hip_mgmt_init() { + return -1; +} +void ggml_hip_mgmt_release() {} +int ggml_hip_get_device_memory(int pci_bus_id, int pci_device_id, size_t *free, size_t *total) { + return -1; +} + +} // extern "C" + +#endif // #ifdef _WIN32 \ No newline at end of file diff --git a/ml/backend/ggml/ggml/src/mem_nvml.cpp b/ml/backend/ggml/ggml/src/mem_nvml.cpp new file mode 100644 index 00000000..f473a2a2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/mem_nvml.cpp @@ -0,0 +1,274 @@ +// NVIDIA Management Library (NVML) +// +// https://developer.nvidia.com/management-library-nvml +// +// This library provides accurate VRAM reporting for NVIDIA GPUs, particularly +// on Windows, where the cuda library provides inaccurate VRAM usage metrics. The +// runtime DLL is installed with every driver on Windows, and most Linux +// systems, and the headers are included in the standard CUDA SDK install. As +// such, we can include the header here to simplify the code. + + +#include "ggml-impl.h" +#include +#include +#include +#include + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +#else +# include +# include +# include +# include +#endif + +namespace fs = std::filesystem; + +// Minimal definitions to avoid including the nvml.h header +typedef enum nvmlReturn_enum +{ + // cppcheck-suppress * + NVML_SUCCESS = 0, //!< The operation was successful + NVML_ERROR_UNINITIALIZED = 1, //!< NVML was not first initialized with nvmlInit() + NVML_ERROR_INVALID_ARGUMENT = 2, //!< A supplied argument is invalid + NVML_ERROR_NOT_SUPPORTED = 3, //!< The requested operation is not available on target device + NVML_ERROR_NO_PERMISSION = 4, //!< The current user does not have permission for operation + NVML_ERROR_ALREADY_INITIALIZED = 5, //!< Deprecated: Multiple initializations are now allowed through ref counting + NVML_ERROR_NOT_FOUND = 6, //!< A query to find an object was unsuccessful + NVML_ERROR_INSUFFICIENT_SIZE = 7, //!< An input argument is not large enough + NVML_ERROR_INSUFFICIENT_POWER = 8, //!< A device's external power cables are not properly attached + NVML_ERROR_DRIVER_NOT_LOADED = 9, //!< NVIDIA driver is not loaded + NVML_ERROR_TIMEOUT = 10, //!< User provided timeout passed + NVML_ERROR_IRQ_ISSUE = 11, //!< NVIDIA Kernel detected an interrupt issue with a GPU + NVML_ERROR_LIBRARY_NOT_FOUND = 12, //!< NVML Shared Library couldn't be found or loaded + NVML_ERROR_FUNCTION_NOT_FOUND = 13, //!< Local version of NVML doesn't implement this function + NVML_ERROR_CORRUPTED_INFOROM = 14, //!< infoROM is corrupted + NVML_ERROR_GPU_IS_LOST = 15, //!< The GPU has fallen off the bus or has otherwise become inaccessible + NVML_ERROR_RESET_REQUIRED = 16, //!< The GPU requires a reset before it can be used again + NVML_ERROR_OPERATING_SYSTEM = 17, //!< The GPU control device has been blocked by the operating system/cgroups + NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18, //!< RM detects a driver/library version mismatch + NVML_ERROR_IN_USE = 19, //!< An operation cannot be performed because the GPU is currently in use + NVML_ERROR_MEMORY = 20, //!< Insufficient memory + NVML_ERROR_NO_DATA = 21, //!< No data + NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, becasue ECC is enabled + NVML_ERROR_INSUFFICIENT_RESOURCES = 23, //!< Ran out of critical resources, other than memory + NVML_ERROR_FREQ_NOT_SUPPORTED = 24, //!< Ran out of critical resources, other than memory + NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25, //!< The provided version is invalid/unsupported + NVML_ERROR_DEPRECATED = 26, //!< The requested functionality has been deprecated + NVML_ERROR_NOT_READY = 27, //!< The system is not ready for the request + NVML_ERROR_GPU_NOT_FOUND = 28, //!< No GPUs were found + NVML_ERROR_INVALID_STATE = 29, //!< Resource not in correct state to perform requested operation + NVML_ERROR_UNKNOWN = 999 //!< An internal driver error occurred +} nvmlReturn_t; +typedef struct nvmlDevice_st* nvmlDevice_t; +typedef struct nvmlMemory_st +{ + unsigned long long total; //!< Total physical device memory (in bytes) + unsigned long long free; //!< Unallocated device memory (in bytes) + unsigned long long used; //!< Sum of Reserved and Allocated device memory (in bytes). + //!< Note that the driver/GPU always sets aside a small amount of memory for bookkeeping +} nvmlMemory_t; +// end nvml.h definitions + +struct { + void *handle; + nvmlReturn_t (*nvmlInit_v2)(void); + nvmlReturn_t (*nvmlShutdown)(void); + nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); + nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); + nvmlReturn_t (*nvmlDeviceGetName)(nvmlDevice_t, char *, unsigned int); + const char * (*nvmlErrorString)(nvmlReturn_t result); +} nvml { NULL, NULL, NULL, NULL, NULL, NULL, NULL }; +static std::mutex ggml_nvml_lock; + +extern "C" { + +#ifndef _WIN32 +// Helper function to get available memory from /proc/meminfo on Linux +// Returns MemAvailable as calculated by the kernel +static size_t get_mem_available() { + std::ifstream meminfo("/proc/meminfo"); + if (!meminfo.is_open()) { + return 0; + } + + std::string line; + while (std::getline(meminfo, line)) { + if (line.find("MemAvailable:") == 0) { + size_t available_kb; + sscanf(line.c_str(), "MemAvailable: %zu kB", &available_kb); + // Convert from kB to bytes + return available_kb * 1024; + } + } + + return 0; +} +#endif + +int ggml_nvml_init() { + std::lock_guard lock(ggml_nvml_lock); + if (nvml.handle != NULL) { + // Already initialized + return 0; + } +#ifdef _WIN32 + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + fs::path libPath[2]; + const char * programDir = std::getenv("ProgramW6432"); + if (programDir == NULL) { + libPath[0] = fs::path("Program Files") / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); + } else { + libPath[0] = fs::path(programDir) / fs::path("NVIDIA Corporation") / fs::path("NVSMI") / fs::path("NVML.dll"); + } + libPath[1] = fs::path("\\Windows") / fs::path("System32") / fs::path("NVML.dll"); + + for (int i = 0; i < 2; i++) { + nvml.handle = (void*)LoadLibraryW(libPath[i].wstring().c_str()); + if (nvml.handle != NULL) { + break; + } + } + if (nvml.handle == NULL) { + return NVML_ERROR_NOT_FOUND; + } + + nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlInit_v2"); + nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown"); + nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID"); + nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo"); + nvml.nvmlDeviceGetName = (nvmlReturn_t (*)(nvmlDevice_t, char *, unsigned int)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetName"); + nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString"); + if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlDeviceGetName == NULL || nvml.nvmlErrorString == NULL) { + GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__); + FreeLibrary((HMODULE)(nvml.handle)); + nvml.handle = NULL; + return NVML_ERROR_NOT_FOUND; + } + + SetErrorMode(old_mode); + + nvmlReturn_t status = nvml.nvmlInit_v2(); + if (status != NVML_SUCCESS) { + GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status)); + FreeLibrary((HMODULE)(nvml.handle)); + nvml.handle = NULL; + return status; + } +#else + constexpr std::array libPaths = { + "/usr/lib/wsl/lib/libnvidia-ml.so.1", // Favor WSL2 path if present + "libnvidia-ml.so.1" // On a non-WSL2 system, it should be in the path + }; + for (const char* path : libPaths) { + nvml.handle = dlopen(path, RTLD_LAZY); + if (nvml.handle) break; + } + if (nvml.handle == NULL) { + GGML_LOG_INFO("%s unable to load libnvidia-ml: %s\n", __func__, dlerror()); + return NVML_ERROR_NOT_FOUND; + } + nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlInit_v2"); + nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown"); + nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID"); + nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo"); + nvml.nvmlDeviceGetName = (nvmlReturn_t (*)(nvmlDevice_t, char *, unsigned int)) dlsym(nvml.handle, "nvmlDeviceGetName"); + nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString"); + if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlDeviceGetName == NULL) { + GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__); + dlclose(nvml.handle); + nvml.handle = NULL; + return NVML_ERROR_NOT_FOUND; + } + nvmlReturn_t status = nvml.nvmlInit_v2(); + if (status != NVML_SUCCESS) { + GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status)); + dlclose(nvml.handle); + nvml.handle = NULL; + return status; + } +#endif + return NVML_SUCCESS; +} + +void ggml_nvml_release() { + std::lock_guard lock(ggml_nvml_lock); + if (nvml.handle == NULL) { + // Already free + return; + } + nvmlReturn_enum status = nvml.nvmlShutdown(); + if (status != NVML_SUCCESS) { + GGML_LOG_INFO("%s failed to shutdown NVML: %s\n", __func__, nvml.nvmlErrorString(status)); + } +#ifdef _WIN32 + FreeLibrary((HMODULE)(nvml.handle)); +#else + dlclose(nvml.handle); +#endif + nvml.handle = NULL; +} + +int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) { + std::lock_guard lock(ggml_nvml_lock); + if (nvml.handle == NULL) { + return NVML_ERROR_UNINITIALIZED; + } + nvmlDevice_t device; + auto status = nvml.nvmlDeviceGetHandleByUUID(uuid, &device); + if (status != NVML_SUCCESS) { + return status; + } + nvmlMemory_t memInfo = {0}; + status = nvml.nvmlDeviceGetMemoryInfo(device, &memInfo); + + if (status == NVML_SUCCESS) { + // NVML working correctly, use its values + *free = memInfo.free; + *total = memInfo.total; + return NVML_SUCCESS; + } + +#ifndef _WIN32 + // Handle NVML_ERROR_NOT_SUPPORTED - this indicates NVML doesn't support + // reporting framebuffer memory (e.g., unified memory GPUs where FB memory is 0) + if (status == NVML_ERROR_NOT_SUPPORTED) { + // Use system memory from /proc/meminfo + size_t mem_available = get_mem_available(); + size_t mem_total = 0; + + // Read MemTotal + std::ifstream meminfo("/proc/meminfo"); + if (meminfo.is_open()) { + std::string line; + while (std::getline(meminfo, line)) { + if (line.find("MemTotal:") == 0) { + size_t total_kb; + sscanf(line.c_str(), "MemTotal: %zu kB", &total_kb); + mem_total = total_kb * 1024; + break; + } + } + } + + if (mem_total > 0) { + *total = mem_total; + *free = mem_available; + GGML_LOG_INFO("%s NVML not supported for memory query, using system memory (total=%zu, available=%zu)\n", + __func__, mem_total, mem_available); + return NVML_SUCCESS; + } + } +#endif + + return status; +} + +} \ No newline at end of file diff --git a/ml/device.go b/ml/device.go new file mode 100644 index 00000000..6569d87b --- /dev/null +++ b/ml/device.go @@ -0,0 +1,338 @@ +package ml + +import ( + "context" + "encoding/binary" + "fmt" + "hash/maphash" + "log/slog" + "slices" + "sort" + "strconv" + "strings" + + "github.com/ollama/ollama/format" +) + +// GPULayers is a set of layers to be allocated on a single GPU +type GPULayers struct { + DeviceID + + // Layers is a set of layer indicies to load + Layers []int +} + +func (g GPULayers) String() string { + if len(g.Layers) == 0 { + return "" + } + + slices.Sort(g.Layers) + + contiguous := true + base := g.Layers[0] + for i := range g.Layers { + if g.Layers[i] != base+i { + contiguous = false + break + } + } + + if contiguous { + return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1]) + } else { + return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers) + } +} + +// GPULayersList is a set of layer allocations across multiple GPUs +type GPULayersList []GPULayers + +func (l GPULayersList) String() string { + if l.Sum() > 0 { + return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l)) + } else { + return fmt.Sprintf("%v", []GPULayers(l)) + } +} + +// Sum is the total number of layers assigned across all GPUs +func (l GPULayersList) Sum() int { + var sum int + + for _, g := range l { + sum += len(g.Layers) + } + + return sum +} + +var h maphash.Hash + +// Hash is an identifier of this layer assignment +func (l GPULayersList) Hash() uint64 { + h.Reset() + for _, g := range l { + if len(g.Layers) > 0 { + h.WriteString(g.ID + g.Library) + for _, l := range g.Layers { + binary.Write(&h, binary.NativeEndian, int64(l)) + } + } + } + + return h.Sum64() +} + +// ErrNoMem is returned when panicing due to insufficient memory. It includes +// the attempted memory allocation. +type ErrNoMem struct { + BackendMemory +} + +func (e ErrNoMem) Error() string { + return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory) +} + +// Minimal unique device identification +type DeviceID struct { + // ID is an identifier for the device for matching with system + // management libraries. The ID is only unique for other devices + // using the same Library. + // This ID represents a "post filtered" view of the enumerated devices + // if the ID is numeric + ID string `json:"id"` + + // Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.) + Library string `json:"backend,omitempty"` +} + +// DeviceMemory provides a breakdown of the memory needed +// per device, such as a CPU or GPU. +type DeviceMemory struct { + DeviceID + + // Name is the name of the device as labeled by the backend. It + // may not be persistent across instances of the runner. + Name string + + // Weights is the per-layer memory needed for the model weights. + Weights []uint64 + + // Cache is the per-layer memory needed for the KV cache. + Cache []uint64 + + // Graph is the size of the compute graph. It is not per-layer. + Graph uint64 +} + +func sumMemory(mem []uint64) uint64 { + var sum uint64 + + for _, m := range mem { + sum += m + } + + return sum +} + +// Size returns the total size of the memory required by this device +func (m DeviceMemory) Size() uint64 { + return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph +} + +func memoryPresent(mem []uint64) bool { + return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 }) +} + +func (m DeviceMemory) LogValue() slog.Value { + var attrs []slog.Attr + if memoryPresent(m.Weights) { + attrs = append(attrs, slog.Any("Weights", m.Weights)) + } + + if memoryPresent(m.Cache) { + attrs = append(attrs, slog.Any("Cache", m.Cache)) + } + + if m.Graph != 0 { + attrs = append(attrs, slog.Any("Graph", m.Graph)) + } + + if len(attrs) > 0 && m.ID != "" { + attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...) + } + + return slog.GroupValue(attrs...) +} + +// BackendMemory provides the amount of memory required to load the model +// per device based on the BackendParams. In some cases, not all required +// allocations will be known at this point. However, the size of the most recent +// allocation is guaranteed to be provided so that if it failed, the caller can +// accommodate that to make forward progress. +type BackendMemory struct { + // InputWeights are always located on the CPU and cannot be moved + InputWeights uint64 + + // CPU model components are located in system memory. This does not + // include unified memory allocated through the GPU. + CPU DeviceMemory + + // GPU model components are located on one or more GPUs. + GPUs []DeviceMemory +} + +func (m BackendMemory) LogValue() slog.Value { + var attrs []slog.Attr + if m.InputWeights != 0 { + attrs = append(attrs, slog.Any("InputWeights", m.InputWeights)) + } + + attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU)) + for _, g := range m.GPUs { + attrs = append(attrs, slog.Any(g.Name, g)) + } + + return slog.GroupValue(attrs...) +} + +// Log prints a high level summary of the memory +func (m BackendMemory) Log(level slog.Level) { + var total uint64 + + for _, gpu := range m.GPUs { + if sum := sumMemory(gpu.Weights); sum > 0 { + slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 { + slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + for _, gpu := range m.GPUs { + if sum := sumMemory(gpu.Cache); sum > 0 { + slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := sumMemory(m.CPU.Cache); sum > 0 { + slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + for _, gpu := range m.GPUs { + if sum := gpu.Graph; sum > 0 { + slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := m.CPU.Graph; sum > 0 { + slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + if total > 0 { + slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total)) + } +} + +type DeviceInfo struct { + DeviceID + + // Name is the name of the device as labeled by the backend. It + // may not be persistent across instances of the runner. + Name string `json:"name"` + + // Description is the longer user-friendly identification of the device + Description string `json:"description"` + + // FilterID is populated with the unfiltered device ID if a numeric ID is used + // so the device can be included. + FilteredID string `json:"filtered_id,omitempty"` + + // Integrated is set true for integrated GPUs, false for Discrete GPUs + Integrated bool `json:"integration,omitempty"` + + // PCIID is the bus, device and domain ID of the device for deduplication + // when discovered by multiple backends + PCIID string `json:"pci_id,omitempty"` + + // TotalMemory is the total amount of memory the device can use for loading models + TotalMemory uint64 `json:"total_memory"` + + // FreeMemory is the amount of memory currently available on the device for loading models + FreeMemory uint64 `json:"free_memory,omitempty"` + + // ComputeMajor is the major version of capabilities of the device + // if unsupported by the backend, -1 will be returned + ComputeMajor int + + // ComputeMinor is the minor version of capabilities of the device + // if unsupported by the backend, -1 will be returned + ComputeMinor int + + // Driver Information + DriverMajor int `json:"driver_major,omitempty"` + DriverMinor int `json:"driver_minor,omitempty"` + + // Where backends were loaded from + LibraryPath []string +} + +func (d DeviceInfo) Compute() string { + // AMD gfx is encoded into the major minor in hex form + if strings.EqualFold(d.Library, "ROCm") { + return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor) + } + return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor) +} + +func (d DeviceInfo) Driver() string { + return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor) +} + +type DeviceComparison int + +const ( + UniqueDevice DeviceComparison = iota + SameBackendDevice // The device is the same, and the library/backend is the same + DuplicateDevice // The same physical device but different library/backend (overlapping device) +) + +func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison { + if a.PCIID != b.PCIID { + return UniqueDevice + } + if a.Library == b.Library { + return SameBackendDevice + } + return DuplicateDevice +} + +// For a SameBackendDevice, return true if b is better than a +// e.g. newer GPU library version +func (a DeviceInfo) IsBetter(b DeviceInfo) bool { + aLib := a.LibraryPath[len(a.LibraryPath)-1] + bLib := b.LibraryPath[len(b.LibraryPath)-1] + if aLib == bLib { + return false + } + aLibSplit := strings.SplitN(aLib, "_", 2) + bLibSplit := strings.SplitN(bLib, "_", 2) + if len(aLibSplit) < 2 || len(bLibSplit) < 2 { + return false + } + if aLibSplit[0] != bLibSplit[0] { + slog.Debug("unexpected libraries", "a", aLib, "b", bLib) + return false + } + if aLibSplit[1] == bLibSplit[1] { + return false + } + cmp := []string{aLibSplit[1], bLibSplit[1]} + sort.Sort(sort.Reverse(sort.StringSlice(cmp))) + return cmp[0] == bLibSplit[1] +} diff --git a/ml/nn/pooling/pooling_test.go b/ml/nn/pooling/pooling_test.go index c8001945..e2772746 100644 --- a/ml/nn/pooling/pooling_test.go +++ b/ml/nn/pooling/pooling_test.go @@ -3,11 +3,9 @@ package pooling_test import ( "bytes" "os" - "slices" "testing" "github.com/google/go-cmp/cmp" - "github.com/ollama/ollama/discover" fsggml "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" @@ -32,20 +30,7 @@ func setup(tb testing.TB, n int) ml.Backend { tb.Fatal(err) } - var gpuLayers ml.GPULayersList - if gpus := discover.GetGPUInfo(); len(gpus) > 0 { - gpuLayers = append(gpuLayers, ml.GPULayers{ - ID: gpus[0].ID, - Layers: slices.Collect(func(yield func(int) bool) { - for i := range n { - if !yield(i) { - return - } - } - }), - }) - } - b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers}) + b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true}) if err != nil { tb.Fatal(err) } diff --git a/model/bytepairencoding_test.go b/model/bytepairencoding_test.go index 39e5ab45..15cb56ca 100644 --- a/model/bytepairencoding_test.go +++ b/model/bytepairencoding_test.go @@ -251,7 +251,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { bts := bts[:n] b.Run("encode"+strconv.Itoa(n), func(b *testing.B) { b.ResetTimer() - for range b.N { + for b.Loop() { _, err := tokenizer.Encode(string(bts), true) if err != nil { b.Fatal(err) @@ -266,7 +266,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { } b.ResetTimer() - for range b.N { + for b.Loop() { _, err := tokenizer.Decode(ids) if err != nil { b.Fatal(err) @@ -276,7 +276,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { b.Run("split"+strconv.Itoa(n), func(b *testing.B) { b.ResetTimer() - for range b.N { + for b.Loop() { slices.Collect(tokenizer.split(string(bts))) } }) diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go index 7b88711b..7e57f72d 100644 --- a/model/models/deepseek2/model.go +++ b/model/models/deepseek2/model.go @@ -150,7 +150,9 @@ func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml } func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor { - scores = scores.Add(ctx, moe.ExpProbsBias) + if moe.ExpProbsBias != nil { + scores = scores.Add(ctx, moe.ExpProbsBias) + } topKIndices := scores.TopK(ctx, opts.numExpertsUsed) return topKIndices } diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index d0e9a026..1333151b 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -65,7 +65,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac cache.(*kvcache.WrapperCache).SetLayerType(layerType) // inputPerLayer = inputsPerLayer[:, i, :] - inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)) + inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)).Contiguous(ctx) hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions) } diff --git a/model/models/llama4/process_image.go b/model/models/llama4/process_image.go index 916f6f90..0b3fab53 100644 --- a/model/models/llama4/process_image.go +++ b/model/models/llama4/process_image.go @@ -73,7 +73,7 @@ func (p ImageProcessor) bestResolution(img image.Point, possibleResolutions []im for i, res := range possibleResolutions { scaleW := float64(res.X) / float64(w) scaleH := float64(res.Y) / float64(h) - scale := math.Min(scaleW, scaleH) + scale := min(scaleW, scaleH) scales[i] = scale } @@ -124,11 +124,11 @@ func (p ImageProcessor) maxResolution(imageRes, targetRes image.Point) image.Poi if scaleW < scaleH { newRes = image.Point{ targetRes.X, - int(math.Min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))), + int(min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))), } } else { newRes = image.Point{ - int(math.Min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))), + int(min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))), targetRes.Y, } } diff --git a/model/models/mllama/process_image.go b/model/models/mllama/process_image.go index 8e60508f..7ab3de9f 100644 --- a/model/models/mllama/process_image.go +++ b/model/models/mllama/process_image.go @@ -53,7 +53,7 @@ func (p ImageProcessor) fitToCanvas(imageSize, canvasSize image.Point) image.Poi tw := min(max(imageSize.X, p.imageSize), canvasSize.X) th := min(max(imageSize.Y, p.imageSize), canvasSize.Y) - r := math.Min( + r := min( float64(tw)/float64(imageSize.X), float64(th)/float64(imageSize.Y), ) @@ -89,10 +89,10 @@ func (p ImageProcessor) optimalTiledCanvas(imageSize image.Point) image.Point { if minUpscale == 0 { minUpscale = s } else { - minUpscale = math.Min(minUpscale, s) + minUpscale = min(minUpscale, s) } } else { - maxDownscale = math.Max(maxDownscale, s) + maxDownscale = max(maxDownscale, s) } } diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index cc58e4a2..9fd6e313 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -15,11 +15,17 @@ import ( ) type Options struct { - hiddenSize, numHeads, numKVHeads int - eps float32 - ropeBase, ropeScale float32 + hiddenSize, + numHeads, + numKVHeads, + keyLength, + valueLength int - keyLength, valueLength int + eps, + ropeBase, + ropeScale float32 + ropeType string + originalContextLength int numExperts, numExpertsUsed int normTopKProb bool @@ -29,6 +35,19 @@ func (o Options) headDim() int { return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) } +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + opts := []func(*rope.Options){rope.WithTypeNeoX()} + if o.ropeType == "yarn" { + attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale)))) + opts = append(opts, + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithExtrapolationFactor(1.), + rope.WithAttentionFactor(attnFactor), + ) + } + return fast.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...) +} + type Attention struct { Query *nn.Linear `gguf:"attn_q"` QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` @@ -52,8 +71,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, query = sa.QueryNorm.Forward(ctx, query, opts.eps) key = sa.KeyNorm.Forward(ctx, key, opts.eps) - query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) @@ -183,7 +202,7 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil + return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil } var _ model.Model = (*Model)(nil) @@ -216,17 +235,19 @@ func New(c fs.Config) (model.Model, error) { ), Layers: layers, Options: &Options{ - hiddenSize: int(c.Uint("embedding_length")), - numHeads: int(c.Uint("attention.head_count")), - numKVHeads: int(c.Uint("attention.head_count_kv")), - keyLength: int(c.Uint("attention.key_length")), - valueLength: int(c.Uint("attention.value_length")), - eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.scaling.factor", 1), - numExperts: int(c.Uint("expert_count")), - numExpertsUsed: int(c.Uint("expert_used_count")), - normTopKProb: c.Bool("norm_top_k_prob", true), + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeType: c.String("rope.scaling.type"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.scaling.factor", 1), + originalContextLength: int(c.Uint("rope.scaling.original_context_length")), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("norm_top_k_prob", true), }, } diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index a1d4e812..4374f3e2 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -16,11 +16,38 @@ type Parser interface { HasThinkingSupport() bool } +type ParserConstructor func() Parser + +type ParserRegistry struct { + constructors map[string]ParserConstructor +} + +func (r *ParserRegistry) Register(name string, constructor ParserConstructor) { + r.constructors[name] = constructor +} + +var registry = ParserRegistry{ + constructors: make(map[string]ParserConstructor), +} + +func Register(name string, constructor ParserConstructor) { + registry.Register(name, constructor) +} + func ParserForName(name string) Parser { + if parser, ok := registry.constructors[name]; ok { + return parser() + } switch name { case "qwen3-coder": parser := &Qwen3CoderParser{} return parser + case "qwen3-vl-instruct": + parser := &Qwen3VLParser{hasThinkingSupport: false} + return parser + case "qwen3-vl-thinking": + parser := &Qwen3VLParser{hasThinkingSupport: true} + return parser case "passthrough": return &PassthroughParser{} case "harmony": diff --git a/model/parsers/parsers_test.go b/model/parsers/parsers_test.go new file mode 100644 index 00000000..8a64a235 --- /dev/null +++ b/model/parsers/parsers_test.go @@ -0,0 +1,97 @@ +package parsers + +import ( + "testing" + + "github.com/ollama/ollama/api" +) + +type mockParser struct { + name string +} + +func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + return tools +} + +func (m *mockParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + return "mock:" + s, "", nil, nil +} + +func (m *mockParser) HasToolSupport() bool { + return false +} + +func (m *mockParser) HasThinkingSupport() bool { + return false +} + +func TestRegisterCustomParser(t *testing.T) { + // Register a custom parser + Register("custom-parser", func() Parser { + return &mockParser{name: "custom"} + }) + + // Retrieve it + parser := ParserForName("custom-parser") + if parser == nil { + t.Fatal("expected parser to be registered") + } + + // Test it works + content, _, _, err := parser.Add("test", false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "mock:test" { + t.Errorf("expected 'mock:test', got %q", content) + } +} + +func TestBuiltInParsersStillWork(t *testing.T) { + tests := []struct { + name string + }{ + {"passthrough"}, + {"qwen3-coder"}, + {"harmony"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := ParserForName(tt.name) + if parser == nil { + t.Fatalf("expected built-in parser %q to exist", tt.name) + } + }) + } +} + +func TestOverrideBuiltInParser(t *testing.T) { + // Override a built-in parser + Register("passthrough", func() Parser { + return &mockParser{name: "override"} + }) + + // Should get the override + parser := ParserForName("passthrough") + if parser == nil { + t.Fatal("expected parser to exist") + } + + // Test it's the override + content, _, _, err := parser.Add("test", false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "mock:test" { + t.Errorf("expected 'mock:test' from override, got %q", content) + } +} + +func TestUnknownParserReturnsNil(t *testing.T) { + parser := ParserForName("nonexistent-parser") + if parser != nil { + t.Error("expected nil for unknown parser") + } +} diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go index f44d7c8e..bfa9762c 100644 --- a/model/parsers/qwen3coder.go +++ b/model/parsers/qwen3coder.go @@ -150,7 +150,9 @@ func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) { ambiguous := p.acc.String()[ambiguousStart:] p.acc.Reset() p.acc.WriteString(ambiguous) - events = append(events, qwenEventContent{content: unambiguous}) + if len(unambiguous) > 0 { + events = append(events, qwenEventContent{content: unambiguous}) + } return events, false } else { // we found content that is entirely not a tool call. We should withhold @@ -274,7 +276,14 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er var paramType api.PropertyType if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil { if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok { - paramType = prop.Type + // Handle anyOf by collecting all types from the union + if len(prop.AnyOf) > 0 { + for _, anyOfProp := range prop.AnyOf { + paramType = append(paramType, anyOfProp.Type...) + } + } else { + paramType = prop.Type + } } } diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go index c77fe2d9..e4246abc 100644 --- a/model/parsers/qwen3coder_test.go +++ b/model/parsers/qwen3coder_test.go @@ -103,6 +103,21 @@ func TestQwenParserStreaming(t *testing.T) { }, }, }, + { + desc: "unambiguous empty: partial tool open at buffer start", + steps: []step{ + { + input: "abc", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + }, + }, + }, + }, { desc: "trailing whitespace between tool call and content", steps: []step{ @@ -962,6 +977,21 @@ func TestQwenToolCallValueParsing(t *testing.T) { raw: "123", want: 123, // Integer has higher precedence than string }, + { + desc: "anyOf array or string - with array of objects", + paramType: api.PropertyType{"array", "string"}, + raw: `[{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, {"content": "task 2", "status": "completed", "priority": "low", "id": "2"}]`, + want: []any{ + map[string]any{"content": "task 1", "status": "pending", "priority": "high", "id": "1"}, + map[string]any{"content": "task 2", "status": "completed", "priority": "low", "id": "2"}, + }, + }, + { + desc: "anyOf array or string - with plain string", + paramType: api.PropertyType{"array", "string"}, + raw: "Error: could not load data", + want: "Error: could not load data", + }, } for _, tc := range cases { diff --git a/model/parsers/qwen3vl.go b/model/parsers/qwen3vl.go new file mode 100644 index 00000000..75ee6abe --- /dev/null +++ b/model/parsers/qwen3vl.go @@ -0,0 +1,236 @@ +package parsers + +import ( + "context" + "encoding/json" + "log/slog" + "strings" + "unicode" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/logutil" +) + +// TODO: call the init function +const ( + CollectingThinkingContent qwenParserState = iota + CollectingContent + CollectingToolContent +) + +const ( + thinkingCloseTag = "" +) + +type Qwen3VLParser struct { + state qwenParserState + buffer strings.Builder + tools []api.Tool + hasThinkingSupport bool +} + +func (p *Qwen3VLParser) HasToolSupport() bool { + return true +} + +func (p *Qwen3VLParser) HasThinkingSupport() bool { + return p.hasThinkingSupport +} + +func (p *Qwen3VLParser) setInitialState(lastMessage *api.Message) { + prefill := lastMessage != nil && lastMessage.Role == "assistant" + if !p.HasThinkingSupport() { + p.state = CollectingContent + return + } + + if prefill && lastMessage.Content != "" { + p.state = CollectingContent + return + } + + p.state = CollectingThinkingContent +} + +func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + p.tools = tools + p.setInitialState(lastMessage) + return tools +} + +type qwenEventThinkingContent struct { + content string +} + +func (qwenEventThinkingContent) isQwenEvent() {} + +func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + events := p.parseEvents() + + var toolCalls []api.ToolCall + var contentSb strings.Builder + var thinkingSb strings.Builder + for _, event := range events { + switch event := event.(type) { + case qwenEventRawToolCall: + toolCall, err := parseJSONToolCall(event, p.tools) + if err != nil { + slog.Warn("qwen tool call parsing failed", "error", err) + return "", "", nil, err + } + toolCalls = append(toolCalls, toolCall) + case qwenEventThinkingContent: + thinkingSb.WriteString(event.content) + case qwenEventContent: + // TODO(drifkin): if the same turn contains multiple interleaved content + // events, we naively append them together here. + contentSb.WriteString(event.content) + } + } + + return contentSb.String(), thinkingSb.String(), toolCalls, nil +} + +func (p *Qwen3VLParser) parseEvents() []qwenEvent { + var all []qwenEvent + + keepLooping := true + for keepLooping { + var events []qwenEvent + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) + } + } + + if len(all) > 0 { + slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "buffer", p.buffer.String()) + } + + return all +} + +func emitContentBeforeTag(p *Qwen3VLParser, events []qwenEvent, tag string) []qwenEvent { + split := strings.SplitN(p.buffer.String(), tag, 2) + before := split[0] + before = strings.TrimRightFunc(before, unicode.IsSpace) + if len(before) > 0 { + events = append(events, qwenEventContent{content: before}) + } + after := split[1] + p.buffer.Reset() + p.buffer.WriteString(after) + return events +} + +func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) { + var events []qwenEvent + + switch p.state { + case CollectingContent: + if strings.Contains(p.buffer.String(), toolOpenTag) { + events = emitContentBeforeTag(p, events, toolOpenTag) + p.state = CollectingToolContent + return events, true + } else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 { + beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen] + trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen + + unambiguous := p.buffer.String()[:ambiguousStart] + ambiguous := p.buffer.String()[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, qwenEventContent{content: unambiguous}) + } + return events, false + } else { + whitespaceLen := trailingWhitespaceLen(p.buffer.String()) + ambiguousStart := len(p.buffer.String()) - whitespaceLen + + unambiguous := p.buffer.String()[:ambiguousStart] + ambiguous := p.buffer.String()[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, qwenEventContent{content: unambiguous}) + } + return events, false + } + case CollectingToolContent: + if strings.Contains(p.buffer.String(), toolCloseTag) { + split := strings.SplitN(p.buffer.String(), toolCloseTag, 2) + before := split[0] // do we also need to do it to tool calls? + if len(before) == 0 { + slog.Warn("qwen tool call closing tag found but no content before it") + } + + after := strings.TrimLeftFunc(split[1], unicode.IsSpace) + events = append(events, qwenEventRawToolCall{raw: before}) + p.buffer.Reset() + p.buffer.WriteString(after) + p.state = CollectingContent + return events, true + } else { + return events, false + } + case CollectingThinkingContent: + if strings.Contains(p.buffer.String(), thinkingCloseTag) { + split := strings.SplitN(p.buffer.String(), thinkingCloseTag, 2) + // before := split[0] + before := strings.TrimRightFunc(split[0], unicode.IsSpace) + if len(before) == 0 { + slog.Warn("qwen tool call closing tag found but no content before it") + } + after := strings.TrimLeftFunc(split[1], unicode.IsSpace) + if len(before) > 0 { + events = append(events, qwenEventThinkingContent{content: before}) + } + p.buffer.Reset() + p.buffer.WriteString(after) + p.state = CollectingContent + return events, true + } else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 { + beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen] + trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen + + unambiguous := p.buffer.String()[:ambiguousStart] + ambiguous := p.buffer.String()[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, qwenEventThinkingContent{content: unambiguous}) + } + return events, false + } else { + whitespaceLen := trailingWhitespaceLen(p.buffer.String()) + ambiguousStart := len(p.buffer.String()) - whitespaceLen + + unambiguous := p.buffer.String()[:ambiguousStart] + ambiguous := p.buffer.String()[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, qwenEventThinkingContent{content: unambiguous}) + } + return events, false + } + default: + panic("unreachable") + } +} + +func parseJSONToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) { + var toolCallFunction api.ToolCallFunction + if err := json.Unmarshal([]byte(raw.raw), &toolCallFunction); err != nil { + return api.ToolCall{}, err + } + + toolCall := api.ToolCall{} + toolCall.Function = toolCallFunction + + return toolCall, nil +} diff --git a/model/parsers/qwen3vl_nonthinking_test.go b/model/parsers/qwen3vl_nonthinking_test.go new file mode 100644 index 00000000..74392946 --- /dev/null +++ b/model/parsers/qwen3vl_nonthinking_test.go @@ -0,0 +1,655 @@ +package parsers + +import ( + "reflect" + "testing" + + "github.com/ollama/ollama/api" +) + +func TestQwen3VLNonThinkingParserStreaming(t *testing.T) { + type step struct { + input string + wantEvents []qwenEvent + } + + cases := []struct { + desc string + steps []step + only bool + }{ + { + desc: "simple thinking", + steps: []step{ + {input: "abc", wantEvents: []qwenEvent{qwenEventContent{content: "abc"}}}, + }, + }, + { + desc: "simple trip thinking", + steps: []step{ + {input: "abc", wantEvents: []qwenEvent{qwenEventContent{content: "abc"}}}, + }, + }, + { + desc: "thinking with split tags", + steps: []step{ + {input: "abc", wantEvents: []qwenEvent{qwenEventContent{content: "abc"}}}, + {input: "", wantEvents: []qwenEvent{qwenEventContent{content: ""}}}, + }, + }, + { + desc: "multiple think tags", + steps: []step{ + {input: "abcactually, is not thinking", wantEvents: []qwenEvent{qwenEventContent{content: "abcactually, is not thinking"}}}, + }, + }, + { + desc: "thinking and tool call", + steps: []step{ + { + input: "I'm thinkingI'm tool calling", + wantEvents: []qwenEvent{ + qwenEventContent{content: "I'm thinking"}, + qwenEventRawToolCall{raw: "I'm tool calling"}, + }, + }, + }, + }, + { + desc: "nested thinking (outside thinking, inside thinking)", + steps: []step{ + { + input: "I'm thinkingI'm nested thinking", + wantEvents: []qwenEvent{ + qwenEventContent{content: "I'm thinkingI'm nested thinking"}, + }, + }, + }, + }, + { + desc: "interleaved thinking", + steps: []step{ + { + input: "I'm thinkingI'm actually content", + wantEvents: []qwenEvent{ + qwenEventContent{content: "I'm thinkingI'm actually content"}, + }, + }, + }, + }, + { + desc: "nested thinking and tool call (outside thinking, inside tool call)", + steps: []step{ + { + input: "I'm thinkingI'm nested tool call", + wantEvents: []qwenEvent{ + qwenEventContent{content: "I'm thinking"}, + qwenEventRawToolCall{raw: "I'm nested tool call"}, + qwenEventContent{content: ""}, + }, + }, + }, + }, + { + desc: "nested thinking and tool call (outside tool call, inside thinking)", + steps: []step{ + { + input: "I'm nested tool callI'm thinking", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "I'm nested tool callI'm thinking"}, + }, + }, + }, + }, + { + desc: "interleaved thinking and tool call", + steps: []step{ + { + input: "I'm thinkingI'm NOT a nested tool callI'm nested tool call 2", + wantEvents: []qwenEvent{ + qwenEventContent{content: "I'm thinking"}, + qwenEventRawToolCall{raw: "I'm NOT a nested tool call"}, + qwenEventRawToolCall{raw: "I'm nested tool call 2"}, + qwenEventContent{content: ""}, + }, + }, + }, + }, + { + desc: "emit unambiguous before partial tool open (trailing ws)", + steps: []step{ + { + input: "abc\u00a0\nabc", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + }, + }, + }, + }, + { + desc: "partial thinking tag fakeout", + steps: []step{ + { + input: "abcunfinished<", // when something is ambiguious, we dont emit anything + wantEvents: []qwenEvent{qwenEventContent{content: "abcunfinished"}}, + }, + }, + }, + { + desc: "test with split tool and content", + steps: []step{ + { + input: "abcunfinished def", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "unfinished"}, + qwenEventContent{content: "def"}, + }, + }, + }, + }, + } + anyOnlies := false + for _, tc := range cases { + if tc.only { + anyOnlies = true + } + } + + for _, tc := range cases { + if anyOnlies && !tc.only { + continue + } + + t.Run(tc.desc, func(t *testing.T) { + parser := Qwen3VLParser{hasThinkingSupport: false} + parser.Init([]api.Tool{}, nil) + + for i, step := range tc.steps { + parser.buffer.WriteString(step.input) + gotEvents := parser.parseEvents() + + if len(gotEvents) == 0 && len(step.wantEvents) == 0 { + // avoid deep equal on empty vs. nil slices + continue + } + + if !reflect.DeepEqual(gotEvents, step.wantEvents) { + t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents) + } + } + }) + } +} + +func TestQwenOldParserStreaming(t *testing.T) { + type step struct { + input string + wantEvents []qwenEvent + } + + cases := []struct { + desc string + steps []step + only bool + }{ + { + desc: "simple message streamed word by word", + steps: []step{ + { + input: "hi", + wantEvents: []qwenEvent{qwenEventContent{content: "hi"}}, + }, + { + input: " there", + wantEvents: []qwenEvent{qwenEventContent{content: " there"}}, + }, + }, + }, + { + desc: "content before tool call", + steps: []step{ + { + input: "hi there", + wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}}, + }, + }, + }, + { + desc: "multiple tool calls in one message", + steps: []step{ + { + input: "before1in tool callafter1in tool call 2after2", + wantEvents: []qwenEvent{ + qwenEventContent{content: "before1"}, + qwenEventRawToolCall{raw: "in tool call"}, + qwenEventContent{content: "after1"}, + qwenEventRawToolCall{raw: "in tool call 2"}, + qwenEventContent{content: "after2"}, + }, + }, + }, + }, + { + desc: "tool calls with split tags", + steps: []step{ + { + input: "beforein tool callaf", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "in tool call"}, + qwenEventContent{content: "af"}, + }, + }, + { + input: "ter", + wantEvents: []qwenEvent{ + qwenEventContent{content: "ter"}, + }, + }, + }, + }, + { + desc: "trailing whitespace between content and tool call", + steps: []step{ + { + input: "abc\ndef", + wantEvents: []qwenEvent{ + qwenEventContent{content: "abc"}, + qwenEventRawToolCall{raw: "def"}, + }, + }, + }, + }, + { + desc: "trailing whitespace between tool call and content", + steps: []step{ + { + input: "abc\ndef", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + qwenEventContent{content: "def"}, + }, + }, + }, + }, + { + desc: "empty content before tool call", + steps: []step{ + { + input: "\nabc", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + }, + }, + }, + }, + { + desc: "partial tool open tag fakeout", + steps: []step{ + { + input: "abc\ntestمرحبا", + wantEvents: []qwenEvent{ + qwenEventContent{content: "你好 🌍"}, + qwenEventRawToolCall{raw: "test"}, + qwenEventContent{content: "مرحبا"}, + }, + }, + }, + }, + { + desc: "arabic text handling", + steps: []step{ + { + input: "مرحبا بالعالم", + wantEvents: []qwenEvent{qwenEventContent{content: "مرحبا بالعالم"}}, + }, + }, + }, + { + desc: "emoji passthrough", + steps: []step{ + { + input: "✅", + wantEvents: []qwenEvent{qwenEventContent{content: "✅"}}, + }, + }, + }, + { + desc: "emoji after tool call", + steps: []step{ + { + input: "test完成 ✅", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "test"}, + qwenEventContent{content: "完成 ✅"}, + }, + }, + }, + }, + { + desc: "unicode streaming with whitespace handling", + steps: []step{ + { + input: "مرحبا", + wantEvents: []qwenEvent{ + qwenEventContent{content: "مرحبا"}, + }, + }, + { + input: " \n", + wantEvents: []qwenEvent{}, + }, + { + input: "世界", + wantEvents: []qwenEvent{ + qwenEventContent{content: " \n世界"}, + }, + }, + }, + }, + { + desc: "non-breaking space withheld across chunks", + steps: []step{ + { + input: "Hello\u00a0", + wantEvents: []qwenEvent{ + qwenEventContent{content: "Hello"}, + }, + }, + { + input: "world", + wantEvents: []qwenEvent{ + qwenEventContent{content: "\u00a0world"}, + }, + }, + }, + }, + { + desc: "ideographic space before partial tool", + steps: []step{ + { + input: "Hello\u3000abc", + wantEvents: []qwenEvent{}, + }, + { + input: "def", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "abc"}, + qwenEventContent{content: "def"}, + }, + }, + }, + }, + { + desc: "ideographic space before partial tool fakeout", + steps: []step{ + { + input: "Hello\u3000abc", + wantEvents: []qwenEvent{ + qwenEventContent{content: "\u3000abc"}, + }, + }, + }, + }, + { + desc: "unicode with partial tool tag", + steps: []step{ + { + input: "测试🎯 b and a < b\""}}`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "exec", + Arguments: map[string]any{ + "command": "ls && echo \"a > b and a < b\"", + }, + }, + }, + }, + { + name: "unicode in function names and parameters", + tools: []api.Tool{}, + rawToolCall: `{"name": "获取天气", "arguments": {"城市": "北京", "message": "Hello! 你好! 🌟 مرحبا"}}`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "获取天气", + Arguments: map[string]any{ + "城市": "北京", + "message": "Hello! 你好! 🌟 مرحبا", + }, + }, + }, + }, + } + + for i, step := range steps { + gotToolCall, err := parseJSONToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools) + if err != nil { + t.Errorf("step %d (%s): %v", i, step.name, err) + } + if !reflect.DeepEqual(gotToolCall, step.wantToolCall) { + t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall) + } + } +} diff --git a/model/parsers/qwen3vl_thinking_test.go b/model/parsers/qwen3vl_thinking_test.go new file mode 100644 index 00000000..d85a60fd --- /dev/null +++ b/model/parsers/qwen3vl_thinking_test.go @@ -0,0 +1,548 @@ +package parsers + +import ( + "reflect" + "testing" + + "github.com/ollama/ollama/api" +) + +func TestQwen3VLThinkingParserStreaming(t *testing.T) { + type step struct { + input string + wantEvents []qwenEvent + } + + cases := []struct { + desc string + steps []step + only bool + }{ + { + desc: "simple thinking", + steps: []step{ + {input: "abc", wantEvents: []qwenEvent{qwenEventThinkingContent{content: "abc"}}}, + }, + }, + { + desc: "simple trip thinking", + steps: []step{ + {input: "abc", wantEvents: []qwenEvent{qwenEventThinkingContent{content: "abc"}}}, + }, + }, + { + desc: "thinking with split tags", + steps: []step{ + {input: "abc", wantEvents: []qwenEvent{qwenEventThinkingContent{content: "abc"}}}, + {input: "", wantEvents: []qwenEvent{}}, + }, + }, + { + desc: "multiple think tags", + steps: []step{ + {input: "abcactually, is not thinking", wantEvents: []qwenEvent{qwenEventThinkingContent{content: "abcactually, is not thinking"}}}, + }, + }, + { + desc: "thinking and tool call", + steps: []step{ + { + input: "I'm thinkingI'm tool calling", + wantEvents: []qwenEvent{ + qwenEventThinkingContent{content: "I'm thinking"}, + qwenEventRawToolCall{raw: "I'm tool calling"}, + }, + }, + }, + }, + { + desc: "thinking and content", + steps: []step{ + { + input: "I'm thinkingI'm content", + wantEvents: []qwenEvent{ + qwenEventThinkingContent{content: "I'm thinking"}, + qwenEventContent{content: "I'm content"}, + }, + }, + }, + }, + { + desc: "thinking and tool call and content", + }, + { + desc: "nested thinking (outside thinking, inside thinking)", + steps: []step{ + { + input: "I'm thinkingI'm nested thinking", + wantEvents: []qwenEvent{ + qwenEventThinkingContent{content: "I'm thinkingI'm nested thinking"}, + qwenEventContent{content: ""}, + }, + }, + }, + }, + { + desc: "interleaved thinking", + steps: []step{ + { + input: "I'm thinkingI'm actually content", + wantEvents: []qwenEvent{ + qwenEventThinkingContent{content: "I'm thinking"}, + qwenEventContent{content: "I'm actually content"}, + }, + }, + }, + }, + { + desc: "nested thinking and tool call (outside thinking, inside tool call)", + steps: []step{ + { + input: "I'm thinkingI'm nested tool call", + wantEvents: []qwenEvent{qwenEventThinkingContent{content: "I'm thinkingI'm nested tool call"}}, + }, + }, + }, + { + desc: "nested thinking and tool call (outside tool call, inside thinking)", + steps: []step{ + { + input: "I'm nested tool callI'm thinking", + wantEvents: []qwenEvent{ + qwenEventThinkingContent{content: "I'm nested tool callI'm thinking"}, + qwenEventContent{content: ""}, + }, + }, + }, + }, + { + desc: "interleaved thinking and tool call", + steps: []step{ + { + input: "I'm thinkingI'm NOT a nested tool callI'm nested tool call 2", + wantEvents: []qwenEvent{ + qwenEventThinkingContent{content: "I'm thinkingI'm NOT a nested tool call"}, + qwenEventContent{content: ""}, + qwenEventRawToolCall{raw: "I'm nested tool call 2"}, + qwenEventContent{content: ""}, + }, + }, + }, + }, + { + desc: "partial thinking tag fakeout", + steps: []step{ + { + input: "abcunfinishedunfinished"}}, + }, + }, + }, + { + desc: "test with split thinking and content", + steps: []step{ + { + input: "abcunfinishedunfinished"}}, + }, + { + input: "ink> def", + wantEvents: []qwenEvent{ + qwenEventContent{content: "def"}, + }, + }, + }, + }, + { + desc: "thinking with no tags", + steps: []step{ + { + input: "Hello I am thinking", + wantEvents: []qwenEvent{ + qwenEventThinkingContent{content: "Hello I am thinking"}, + }, + }, + { + input: "Hello I am thinking some more", + wantEvents: []qwenEvent{ + qwenEventThinkingContent{content: "Hello I am thinking some more"}, + }, + }, + { + input: "Hello I am think NOT", + wantEvents: []qwenEvent{ + qwenEventThinkingContent{content: "Hello I am think"}, + qwenEventContent{content: "NOT"}, + }, + }, + }, + }, + } + anyOnlies := false + for _, tc := range cases { + if tc.only { + anyOnlies = true + } + } + + for _, tc := range cases { + if anyOnlies && !tc.only { + continue + } + + t.Run(tc.desc, func(t *testing.T) { + parser := Qwen3VLParser{hasThinkingSupport: true} + parser.Init([]api.Tool{}, nil) + // parser.state = CollectingThinkingContent + + for i, step := range tc.steps { + parser.buffer.WriteString(step.input) + gotEvents := parser.parseEvents() + + if len(gotEvents) == 0 && len(step.wantEvents) == 0 { + // avoid deep equal on empty vs. nil slices + continue + } + + if !reflect.DeepEqual(gotEvents, step.wantEvents) { + t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents) + } + } + }) + } +} + +func TestQwen3VLThinkingToolParser(t *testing.T) { + type step struct { + name string + rawToolCall string + tools []api.Tool + wantToolCall api.ToolCall + } + + steps := []step{ + { + name: "simple tool call", + tools: []api.Tool{}, + rawToolCall: `{"name": "get-current-weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get-current-weather", + Arguments: map[string]any{ + "location": "San Francisco, CA", + "unit": "fahrenheit", + }, + }, + }, + }, + { + name: "names with spaces", + tools: []api.Tool{}, + rawToolCall: `{"name": "get current temperature", "arguments": {"location with spaces": "San Francisco", "unit with spaces": "celsius"}}`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get current temperature", + Arguments: map[string]any{ + "location with spaces": "San Francisco", + "unit with spaces": "celsius", + }, + }, + }, + }, + { + name: "names with quotes", + tools: []api.Tool{}, + rawToolCall: `{"name": "\"get current temperature\"", "arguments": {"\"location with spaces\"": "San Francisco", "\"unit with spaces\"": "\"celsius\""}}`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "\"get current temperature\"", + Arguments: map[string]any{ + "\"location with spaces\"": "San Francisco", + "\"unit with spaces\"": "\"celsius\"", + }, + }, + }, + }, + { + name: "tool call with typed parameters (json types)", + tools: []api.Tool{}, + rawToolCall: `{"name": "calculate", "arguments": {"x": 3.14, "y": 42, "enabled": true, "items": ["a", "b", "c"]}}`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "calculate", + Arguments: map[string]any{ + "x": 3.14, + "y": float64(42), + "enabled": true, + "items": []any{"a", "b", "c"}, + }, + }, + }, + }, + { + name: "ampersands in parameter values", + tools: []api.Tool{}, + rawToolCall: `{"name": "exec", "arguments": {"command": "ls && echo \"done\""}}`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "exec", + Arguments: map[string]any{ + "command": "ls && echo \"done\"", + }, + }, + }, + }, + { + name: "angle brackets in parameter values", + tools: []api.Tool{}, + rawToolCall: `{"name": "exec", "arguments": {"command": "ls && echo \"a > b and a < b\""}}`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "exec", + Arguments: map[string]any{ + "command": "ls && echo \"a > b and a < b\"", + }, + }, + }, + }, + { + name: "unicode in function names and parameters", + tools: []api.Tool{}, + rawToolCall: `{"name": "获取天气", "arguments": {"城市": "北京", "message": "Hello! 你好! 🌟 مرحبا"}}`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "获取天气", + Arguments: map[string]any{ + "城市": "北京", + "message": "Hello! 你好! 🌟 مرحبا", + }, + }, + }, + }, + } + + for i, step := range steps { + gotToolCall, err := parseJSONToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools) + if err != nil { + t.Errorf("step %d (%s): %v", i, step.name, err) + } + if !reflect.DeepEqual(gotToolCall, step.wantToolCall) { + t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall) + } + } +} + +func TestQwen3VLParserState(t *testing.T) { + cases := []struct { + desc string + hasThinking bool + last *api.Message + wantState qwenParserState + }{ + { + desc: "no thinking support => CollectingContent", + hasThinking: false, + last: nil, + wantState: CollectingContent, + }, + { + desc: "thinking support, no last message => CollectingThinkingContent", + hasThinking: true, + last: nil, + wantState: CollectingThinkingContent, + }, + { + desc: "thinking support, last assistant with empty content => CollectingThinkingContent", + hasThinking: true, + last: &api.Message{Role: "assistant", Content: ""}, + wantState: CollectingThinkingContent, + }, + { + desc: "thinking support, last assistant with content => CollectingContent", + hasThinking: true, + last: &api.Message{Role: "assistant", Content: "hello"}, + wantState: CollectingContent, + }, + { + desc: "thinking support, last is user => CollectingThinkingContent", + hasThinking: true, + last: &api.Message{Role: "user", Content: "hi"}, + wantState: CollectingThinkingContent, + }, + } + + for _, tc := range cases { + parser := Qwen3VLParser{hasThinkingSupport: tc.hasThinking} + parser.Init(nil, tc.last) + if parser.state != tc.wantState { + t.Errorf("%s: got state %v, want %v", tc.desc, parser.state, tc.wantState) + } + } +} + +func TestQwen3VLThinkingParserWithThinkingPrefill(t *testing.T) { + type step struct { + input string + wantEvents []qwenEvent + } + + cases := []struct { + desc string + steps []step + only bool + }{ + { + desc: "thinking prefill", + steps: []step{ + {input: "abc", wantEvents: []qwenEvent{qwenEventThinkingContent{content: "abc"}}}, + }, + }, + { + desc: "thinking prefill with content", + steps: []step{ + {input: "abc def", wantEvents: []qwenEvent{qwenEventContent{content: "def"}}}, + }, + }, + { + desc: "thinking prefill with fakeout", + steps: []step{ + {input: "abc", wantEvents: []qwenEvent{}}, + }, + }, + { + desc: "thinking prefill with spaces", + steps: []step{ + {input: " starting content", wantEvents: []qwenEvent{qwenEventContent{content: "starting content"}}}, + }, + }, + } + last := &api.Message{Role: "assistant", Thinking: "i am thinking"} // so if there is thinking the test is still thinking + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + parser := Qwen3VLParser{hasThinkingSupport: true} + parser.Init([]api.Tool{}, last) + + for i, step := range tc.steps { + parser.buffer.WriteString(step.input) + gotEvents := parser.parseEvents() + + if len(gotEvents) == 0 && len(step.wantEvents) == 0 { + // avoid deep equal on empty vs. nil slices + continue + } + + if !reflect.DeepEqual(gotEvents, step.wantEvents) { + t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents) + } + } + }) + } +} + +func TestQwen3VLThinkingParserWithNonThinkingPrefill(t *testing.T) { + type step struct { + input string + wantEvents []qwenEvent + } + + cases := []struct { + desc string + steps []step + only bool + }{ + { + desc: "thinking prefill", + steps: []step{ + {input: "abc", wantEvents: []qwenEvent{qwenEventContent{content: "abc"}}}, + }, + }, + { + desc: "thinking prefill with content", + steps: []step{ + {input: "abc def", wantEvents: []qwenEvent{qwenEventContent{content: "ink> def"}}}, + }, + }, + { + desc: "thinking prefill with fakeout", + steps: []step{ + {input: "abc", wantEvents: []qwenEvent{qwenEventContent{content: ">"}}}, + }, + }, + { + desc: "thinking prefill with spaces", + steps: []step{ + {input: " starting content", wantEvents: []qwenEvent{qwenEventContent{content: " starting content"}}}, + }, + }, + } + last := &api.Message{Role: "assistant", Thinking: "i am thinking", Content: "i am content"} // so if there is thinking the test is still thinking + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + parser := Qwen3VLParser{hasThinkingSupport: true} + parser.Init([]api.Tool{}, last) + + for i, step := range tc.steps { + parser.buffer.WriteString(step.input) + gotEvents := parser.parseEvents() + + if len(gotEvents) == 0 && len(step.wantEvents) == 0 { + // avoid deep equal on empty vs. nil slices + continue + } + + if !reflect.DeepEqual(gotEvents, step.wantEvents) { + t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents) + } + } + }) + } +} + +func TestQwen3VLThinkingParserStreamingAssistantPrefillContent(t *testing.T) { + // last message is assistant with content ⇒ start in CollectingContent + last := &api.Message{Role: "assistant", Content: "has content"} + parser := Qwen3VLParser{hasThinkingSupport: true} + parser.Init([]api.Tool{}, last) + + type step struct { + input string + wantEvents []qwenEvent + } + + steps := []step{ + {input: "abc", wantEvents: []qwenEvent{qwenEventContent{content: "abc"}}}, + {input: "{\"name\": \"x\", \"arguments\": {}}", wantEvents: []qwenEvent{qwenEventRawToolCall{raw: "{\"name\": \"x\", \"arguments\": {}}"}}}, + } + + for i, s := range steps { + parser.buffer.WriteString(s.input) + gotEvents := parser.parseEvents() + if len(gotEvents) == 0 && len(s.wantEvents) == 0 { + continue + } + if !reflect.DeepEqual(gotEvents, s.wantEvents) { + t.Fatalf("step %d: input %q: got %#v, want %#v", i, s.input, gotEvents, s.wantEvents) + } + } +} diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go index df3b3a45..18853019 100644 --- a/model/renderers/qwen3coder.go +++ b/model/renderers/qwen3coder.go @@ -55,7 +55,9 @@ func renderAdditionalKeys(obj any, handledKeys map[string]bool) string { return sb.String() } -func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) { +type Qwen3CoderRenderer struct{} + +func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) { var sb strings.Builder // filter out system messages and choose the first (if any) to win @@ -99,9 +101,7 @@ func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkVa sb.WriteString("\n" + name + "") if len(prop.Type) > 0 { - // TODO(!!!)(drifkin): we should match the reference implementation for - // more complex types here instead of using this format - sb.WriteString("\n" + prop.ToTypeScriptType() + "") + sb.WriteString("\n" + formatToolDefinitionType(prop.Type) + "") } if prop.Description != "" { @@ -215,3 +215,24 @@ func formatToolCallArgument(value any) string { return fmt.Sprintf("%v", value) } + +func formatToolDefinitionType(tp api.PropertyType) string { + if len(tp) == 0 { + return "[]" + } + + if len(tp) == 1 { + return tp[0] + } + + // TODO(drifkin): it would be nice to format the JSON here similarly to + // python's default json.dumps behavior (spaces after commas and colons). + // This would let us be byte-for-byte compatible with the reference + // implementation for most common inputs + jsonBytes, err := json.Marshal(tp) + if err != nil { + return "[]" + } + + return string(jsonBytes) +} diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go index 4aaa066d..1addee9e 100644 --- a/model/renderers/qwen3coder_test.go +++ b/model/renderers/qwen3coder_test.go @@ -288,7 +288,7 @@ call tool<|im_end|> } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil) + rendered, err := (&Qwen3CoderRenderer{}).Render(tt.msgs, tt.tools, nil) if err != nil { t.Fatal(err) } @@ -336,3 +336,35 @@ func TestFormatToolCallArgument(t *testing.T) { }) } } + +func TestQwen3ToolDefinitionTypes(t *testing.T) { + tests := []struct { + name string + propertyType api.PropertyType + expected string + }{ + { + name: "simple", + propertyType: api.PropertyType{"string"}, + expected: "string", + }, + { + name: "multiple", + propertyType: api.PropertyType{"string", "number"}, + expected: "[\"string\",\"number\"]", + }, + { + name: "empty", + propertyType: api.PropertyType{}, + expected: "[]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatToolDefinitionType(tt.propertyType) + if got != tt.expected { + t.Errorf("formatToolDefinitionType() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/model/renderers/qwen3vl.go b/model/renderers/qwen3vl.go new file mode 100644 index 00000000..8ea4abbb --- /dev/null +++ b/model/renderers/qwen3vl.go @@ -0,0 +1,175 @@ +package renderers + +import ( + "encoding/json" + "strings" + + "github.com/ollama/ollama/api" +) + +func marshalWithSpaces(v any) ([]byte, error) { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + + out := make([]byte, 0, len(b)+len(b)/8) + inStr, esc := false, false + for _, c := range b { + if inStr { + out = append(out, c) + if esc { + esc = false + continue + } + if c == '\\' { + esc = true + continue + } + if c == '"' { + inStr = false + } + continue + } + switch c { + case '"': + inStr = true + out = append(out, c) + case ':': + out = append(out, ':', ' ') + case ',': + out = append(out, ',', ' ') + default: + out = append(out, c) + } + } + return out, nil +} + +type Qwen3VLRenderer struct { + isThinking bool + + useImgTags bool +} + +func (r *Qwen3VLRenderer) renderContent(content api.Message) string { + // This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go + var subSb strings.Builder + for range content.Images { + // TODO: (jmorganca): how to render this is different for different + // model backends, and so we should eventually parameterize this or + // only output a placeholder such as [img] + if r.useImgTags { + subSb.WriteString("[img]") + } else { + subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>") + } + } + // TODO: support videos + + subSb.WriteString(content.Content) + return subSb.String() +} + +func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) { + var sb strings.Builder + + if len(tools) > 0 { + sb.WriteString(imStartTag + "system\n") + if len(messages) > 0 && messages[0].Role == "system" { + sb.WriteString(messages[0].Content + "\n\n") + } + sb.WriteString("# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n") + for _, tool := range tools { + sb.WriteString("\n") + if b, err := marshalWithSpaces(tool); err == nil { + sb.Write(b) + } + } + sb.WriteString("\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n") + } else if len(messages) > 0 && messages[0].Role == "system" { + sb.WriteString("<|im_start|>system\n" + messages[0].Content + "<|im_end|>\n") + } + multiStepTool := true + lastQueryIndex := len(messages) - 1 // so this is the last user message + + for i := len(messages) - 1; i >= 0; i-- { + message := messages[i] + if multiStepTool && message.Role == "user" { + // Check if content starts with and ends with + content := r.renderContent(message) + if !(strings.HasPrefix(content, "") && strings.HasSuffix(content, "")) { + multiStepTool = false + lastQueryIndex = i + } + } + } + + for i, message := range messages { + content := r.renderContent(message) + + lastMessage := i == len(messages)-1 + prefill := lastMessage && message.Role == "assistant" + + if message.Role == "user" || message.Role == "system" && i != 0 { + sb.WriteString("<|im_start|>" + message.Role + "\n" + content + "<|im_end|>\n") + } else if message.Role == "assistant" { + contentReasoning := "" + + if r.isThinking { + if message.Thinking != "" { + contentReasoning = message.Thinking + } + } + + if r.isThinking && i > lastQueryIndex { + if i == len(messages)-1 || contentReasoning != "" { + sb.WriteString("<|im_start|>" + message.Role + "\n\n" + strings.Trim(contentReasoning, "\n")) // do we want to add a new line here? + if content != "" { + sb.WriteString("\n\n\n" + strings.TrimLeft(content, "\n")) + } + } else { + sb.WriteString("<|im_start|>" + message.Role + "\n" + content) + } + } else { + sb.WriteString("<|im_start|>" + message.Role + "\n" + content) + } + + if len(message.ToolCalls) > 0 { + for j, toolCall := range message.ToolCalls { + if j > 0 || content != "" { + sb.WriteString("\n") + } + + sb.WriteString("\n{\"name\": \"" + toolCall.Function.Name + "\", \"arguments\": ") + if b, err := marshalWithSpaces(toolCall.Function.Arguments); err == nil { + sb.Write(b) + } + sb.WriteString("}\n") + } + } + + if !prefill { + sb.WriteString("<|im_end|>\n") + } + } else if message.Role == "tool" { + if i == 0 || messages[i-1].Role != "tool" { + sb.WriteString("<|im_start|>user") + } + sb.WriteString("\n\n" + message.Content + "\n") + if i == len(messages)-1 || messages[i+1].Role != "tool" { + sb.WriteString("<|im_end|>\n") + } + } + + // prefill at the end + if lastMessage && !prefill { + sb.WriteString("<|im_start|>assistant\n") + if r.isThinking { + sb.WriteString("\n") + } + } + } + + return sb.String(), nil +} diff --git a/model/renderers/qwen3vl_nonthinking_test.go b/model/renderers/qwen3vl_nonthinking_test.go new file mode 100644 index 00000000..d3377e39 --- /dev/null +++ b/model/renderers/qwen3vl_nonthinking_test.go @@ -0,0 +1,521 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +func TestQwen3VLNonThinkingRenderer(t *testing.T) { + tests := []struct { + name string + msgs []api.Message + images []api.ImageData + tools []api.Tool + useImgTags bool + expected string + }{ + { + name: "prefill", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Tell me something interesting."}, + {Role: "assistant", Content: "I'll tell you something interesting about cats"}, + }, + expected: `<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +Tell me something interesting.<|im_end|> +<|im_start|>assistant +I'll tell you something interesting about cats`, + }, + { + name: "basic", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello, how are you?"}, + }, + expected: `<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +Hello, how are you?<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "With thinking, end assistant.", + msgs: []api.Message{ + {Role: "user", Content: "Tell me a story in two sentences."}, + {Role: "assistant", Content: "abcTo make this story interesting, I will speak in poetry."}, // does the thinking even work? + }, + expected: `<|im_start|>user +Tell me a story in two sentences.<|im_end|> +<|im_start|>assistant +abcTo make this story interesting, I will speak in poetry.`, + }, + { + name: "Multiple thinking", + msgs: []api.Message{ + {Role: "user", Content: "Tell me a story in two sentences."}, + {Role: "assistant", Content: "abcTo make this story interesting, I will speak in poetry.And I will speak in poetry after the first sentence."}, + }, + expected: `<|im_start|>user +Tell me a story in two sentences.<|im_end|> +<|im_start|>assistant +abcTo make this story interesting, I will speak in poetry.And I will speak in poetry after the first sentence.`, // NOTE: the second thinking tag is not captured + }, + { + name: "Multiple thinking, multiple messages.", + msgs: []api.Message{ + {Role: "user", Content: "Tell me a story in two sentences."}, + {Role: "assistant", Content: "abcTo make this story interesting, I will speak in poetry.And I will speak in poetry after the first sentence."}, + {Role: "user", Content: "What is the weather like in San Francisco? I will check the weather in San Francisco for you."}, + {Role: "assistant", Content: "I'll check the weather in San Francisco for you.Speak poetry after the first sentence.Speak poetry after the second sentence."}, + }, + expected: `<|im_start|>user +Tell me a story in two sentences.<|im_end|> +<|im_start|>assistant +abcTo make this story interesting, I will speak in poetry.And I will speak in poetry after the first sentence.<|im_end|> +<|im_start|>user +What is the weather like in San Francisco? I will check the weather in San Francisco for you.<|im_end|> +<|im_start|>assistant +I'll check the weather in San Francisco for you.Speak poetry after the first sentence.Speak poetry after the second sentence.`, + }, + { + name: "Image", + msgs: []api.Message{ + {Role: "user", Content: "Describe this image.", Images: []api.ImageData{api.ImageData("img2")}}, + {Role: "assistant", Content: "Let me analyze this image."}, + }, + expected: `<|im_start|>user +<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|> +<|im_start|>assistant +Let me analyze this image.`, + }, + { + name: "Image with image tags", + msgs: []api.Message{ + {Role: "user", Content: "Describe this image.", Images: []api.ImageData{api.ImageData("img2")}}, + {Role: "assistant", Content: "Let me analyze this image."}, + }, + useImgTags: true, + expected: `<|im_start|>user +[img]Describe this image.<|im_end|> +<|im_start|>assistant +Let me analyze this image.`, + }, + { + name: "Multiple images", + msgs: []api.Message{ + {Role: "user", Content: "Describe these images.", Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")}}, + }, + expected: `<|im_start|>user +<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|>Describe these images.<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "Multiple images with image tags", + msgs: []api.Message{ + {Role: "user", Content: "Describe these images.", Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")}}, + {Role: "assistant", Content: "Let me analyze this image."}, + }, + useImgTags: true, + expected: `<|im_start|>user +[img][img]Describe these images.<|im_end|> +<|im_start|>assistant +Let me analyze this image.`, + }, + // // NOTE: solved with #12518: https://github.com/ollama/ollama/compare/main...drifkin/stable-tool-args + // { + // name: "with tools and response", + // msgs: []api.Message{ + // {Role: "system", Content: "You are a helpful assistant with access to tools."}, + // {Role: "user", Content: "What's the weather like in New York?"}, + // { + // Role: "assistant", + // Content: "I'll check the weather in New York for you.", + // ToolCalls: []api.ToolCall{ + // { + // Function: api.ToolCallFunction{ + // Name: "get-current-weather", + // Arguments: map[string]any{ + // "location": "New York", + // "unit": "fahrenheit", + // }, + // }, + // }, + // }, + // }, + // {Role: "tool", Content: "80", ToolName: "get-current-weather"}, + // {Role: "user", Content: "That sounds nice! What about San Francisco?"}, + // }, + // tools: []api.Tool{ + // { + // Type: "function", + // Function: api.ToolFunction{ + // Name: "get-current-weather", + // Description: "Get the current weather for a location", + // Parameters: api.ToolFunctionParameters{ + // Type: "object", + // Required: []string{"location"}, + // Properties: map[string]api.ToolProperty{ + // "location": { + // Type: api.PropertyType{"string"}, + // Description: "The city and state, e.g. San Francisco, CA", + // }, + // "unit": { + // Type: api.PropertyType{"string"}, + // Enum: []any{"celsius", "fahrenheit"}, + // Description: "The temperature unit", + // }, + // }, + // }, + // }, + // }, + // }, + // expected: `<|im_start|>system + // You are a helpful assistant with access to tools. + + // # Tools + + // You may call one or more functions to assist with the user query. + + // You are provided with function signatures within XML tags: + // + // {"type": "function", "function": {"name": "get-current-weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit"}}, "required": ["location"]}}} + // + + // For each function call, return a json object with function name and arguments within XML tags: + // + // {"name": , "arguments": } + // <|im_end|> + // <|im_start|>user + // What's the weather like in New York?<|im_end|> + // <|im_start|>assistant + // I'll check the weather in New York for you. + // + // {"name": "get-current-weather", "arguments": {"location": "New York", "unit": "fahrenheit"}} + // <|im_end|> + // <|im_start|>user + // + // 80 + // <|im_end|> + // <|im_start|>user + // That sounds nice! What about San Francisco?<|im_end|> + // <|im_start|>assistant + // `, + // }, + // // NOTE: solved with #12518: https://github.com/ollama/ollama/compare/main...drifkin/stable-tool-args + // { + // name: "With tools and response, multiple tool calls", + // msgs: []api.Message{ + // { + // Role: "system", + // Content: "You are a helpful assistant with access to tools.", + // }, + // { + // Role: "user", + // Content: "Call two tools for me: add and multiply.", + // }, + // { + // Role: "assistant", + // Content: "Sure, I'll call both tools for you.", + // ToolCalls: []api.ToolCall{ + // { + // Function: api.ToolCallFunction{ + // Name: "add", + // Arguments: map[string]any{ + // "a": 2, + // "b": 3, + // }, + // }, + // }, + // { + // Function: api.ToolCallFunction{ + // Name: "multiply", + // Arguments: map[string]any{ + // "x": 4, + // "y": 5, + // }, + // }, + // }, + // }, + // }, + // { + // Role: "tool", + // Content: "5", + // ToolName: "add", + // }, + // { + // Role: "tool", + // Content: "20", + // ToolName: "multiply", + // }, + // { + // Role: "user", + // Content: "Thanks! What are the results?", + // }, + // }, + // tools: []api.Tool{ + // { + // Type: "function", + // Function: api.ToolFunction{ + // Name: "add", + // Description: "Add two numbers", + // Parameters: api.ToolFunctionParameters{ + // Type: "object", + // Required: []string{"a", "b"}, + // Properties: map[string]api.ToolProperty{ + // "a": {Type: api.PropertyType{"integer"}, Description: "First number"}, + // "b": {Type: api.PropertyType{"integer"}, Description: "Second number"}, + // }, + // }, + // }, + // }, + // { + // Type: "function", + // Function: api.ToolFunction{ + // Name: "multiply", + // Description: "Multiply two numbers", + // Parameters: api.ToolFunctionParameters{ + // Type: "object", + // Required: []string{"x", "y"}, + // Properties: map[string]api.ToolProperty{ + // "x": {Type: api.PropertyType{"integer"}, Description: "First factor"}, + // "y": {Type: api.PropertyType{"integer"}, Description: "Second factor"}, + // }, + // }, + // }, + // }, + // }, + // expected: `<|im_start|>system + // You are a helpful assistant with access to tools. + + // # Tools + + // You may call one or more functions to assist with the user query. + + // You are provided with function signatures within XML tags: + // + // {"type": "function", "function": {"name": "add", "description": "Add two numbers", "parameters": {"type": "object", "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, "required": ["a", "b"]}}} + // {"type": "function", "function": {"name": "multiply", "description": "Multiply two numbers", "parameters": {"type": "object", "properties": {"x": {"description": "First factor"}, "y": {"description": "Second factor"}}, "required": ["x", "y"]}}} + // + + // For each function call, return a json object with function name and arguments within XML tags: + // + // {"name": , "arguments": } + // <|im_end|> + // <|im_start|>user + // Call two tools for me: add and multiply.<|im_end|> + // <|im_start|>assistant + // Sure, I'll call both tools for you. + // + // {"name": "add", "arguments": {"a": 2, "b": 3}} + // + // + // {"name": "multiply", "arguments": {"x": 4, "y": 5}} + // <|im_end|> + // <|im_start|>user + // + // 5 + // + // + // 20 + // <|im_end|> + // <|im_start|>user + // Thanks! What are the results?<|im_end|> + // <|im_start|>assistant + // `, + // }, + { + name: "user tool_response block preserved", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "I'll check.", + ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}}, + }, + }, + {Role: "user", Content: "\n18\n"}, + {Role: "user", Content: "Thanks!"}, + }, + expected: `<|im_start|>user +What's the weather?<|im_end|> +<|im_start|>assistant +I'll check. + +{"name": "get-current-weather", "arguments": {"location": "Paris", "unit": "celsius"}} +<|im_end|> +<|im_start|>user + +18 +<|im_end|> +<|im_start|>user +Thanks!<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "assistant with multiple tool calls and content", + msgs: []api.Message{ + {Role: "user", Content: "Hi"}, + { + Role: "assistant", + Content: "before", + ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "add", Arguments: map[string]any{"a": 2, "b": 3}}}, + {Function: api.ToolCallFunction{Name: "mul", Arguments: map[string]any{"x": 4, "y": 5}}}, + }, + }, + }, + expected: `<|im_start|>user +Hi<|im_end|> +<|im_start|>assistant +before + +{"name": "add", "arguments": {"a": 2, "b": 3}} + + +{"name": "mul", "arguments": {"x": 4, "y": 5}} +`, + }, + { + name: "consecutive tool responses grouped", + msgs: []api.Message{ + {Role: "user", Content: "Compute results"}, + {Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: map[string]any{"n": 1}}}}}, + {Role: "tool", Content: "5", ToolName: "job"}, + {Role: "tool", Content: "6", ToolName: "job"}, + }, + expected: `<|im_start|>user +Compute results<|im_end|> +<|im_start|>assistant +ok + +{"name": "job", "arguments": {"n": 1}} +<|im_end|> +<|im_start|>user + +5 + + +6 +<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "last message is tool then prefill", + msgs: []api.Message{ + {Role: "user", Content: "run"}, + {Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: map[string]any{"cmd": "ls"}}}}}, + {Role: "tool", Content: "done", ToolName: "exec"}, + }, + expected: `<|im_start|>user +run<|im_end|> +<|im_start|>assistant +ok + +{"name": "exec", "arguments": {"cmd": "ls"}} +<|im_end|> +<|im_start|>user + +done +<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "user with multiple images", + msgs: []api.Message{ + {Role: "user", Content: "Describe.", Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")}}, + }, + expected: `<|im_start|>user +<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|>Describe.<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "user tool_response, no whitespace", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "I'll check.", + ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}}, + }, + }, + {Role: "user", Content: "\n18\n"}, + {Role: "user", Content: "Thanks!"}, + }, + expected: `<|im_start|>user +What's the weather?<|im_end|> +<|im_start|>assistant +I'll check. + +{"name": "get-current-weather", "arguments": {"location": "Paris", "unit": "celsius"}} +<|im_end|> +<|im_start|>user + +18 +<|im_end|> +<|im_start|>user +Thanks!<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "user tool_response with surrounding whitespace", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "I'll check.", + ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}}, + }, + }, + {Role: "user", Content: "\n\n\n\n\n18\n extra\n\n\n\n\n\n"}, + }, + expected: `<|im_start|>user +What's the weather?<|im_end|> +<|im_start|>assistant +I'll check. + +{"name": "get-current-weather", "arguments": {"location": "Paris", "unit": "celsius"}} +<|im_end|> +<|im_start|>user + + + + + +18 + extra + + + + + +<|im_end|> +<|im_start|>assistant +`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rendered, err := (&Qwen3VLRenderer{isThinking: false, useImgTags: tt.useImgTags}).Render(tt.msgs, tt.tools, nil) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/model/renderers/qwen3vl_test.go b/model/renderers/qwen3vl_test.go new file mode 100644 index 00000000..6810a7c9 --- /dev/null +++ b/model/renderers/qwen3vl_test.go @@ -0,0 +1,346 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +// TODO(drifkin): this will be moved to utils in the near future and used by other renderers as well +func TestMarshalWithSpaces(t *testing.T) { + tests := []struct { + name string + input any + expected string + }{ + // basic formatting tests + { + name: "simple object", + input: map[string]any{"key": "value"}, + expected: `{"key": "value"}`, + }, + { + name: "simple array", + input: []any{"a", "b", "c"}, + expected: `["a", "b", "c"]`, + }, + // escaped quotes + { + name: "escaped quote in string", + input: map[string]any{"text": `quote"inside`}, + expected: `{"text": "quote\"inside"}`, + }, + { + name: "multiple escaped quotes", + input: map[string]any{"text": `say "hello" and "goodbye"`}, + expected: `{"text": "say \"hello\" and \"goodbye\""}`, + }, + // escaped backslashes + { + name: "escaped backslash", + input: map[string]any{"path": `C:\windows\system32`}, + expected: `{"path": "C:\\windows\\system32"}`, + }, + { + name: "double backslash", + input: map[string]any{"text": `test\\more`}, + expected: `{"text": "test\\\\more"}`, + }, + { + name: "backslash before quote", + input: map[string]any{"text": `end with \"`}, + expected: `{"text": "end with \\\""}`, + }, + // standard JSON escape sequences + { + name: "newline in string", + input: map[string]any{"text": "line1\nline2"}, + expected: `{"text": "line1\nline2"}`, + }, + { + name: "tab in string", + input: map[string]any{"text": "before\tafter"}, + expected: `{"text": "before\tafter"}`, + }, + { + name: "carriage return", + input: map[string]any{"text": "before\rafter"}, + expected: `{"text": "before\rafter"}`, + }, + { + name: "multiple escape sequences", + input: map[string]any{"text": "line1\nline2\ttab\rcarriage"}, + expected: `{"text": "line1\nline2\ttab\rcarriage"}`, + }, + // strings containing colons and commas (no spaces should be added inside) + { + name: "colon in string", + input: map[string]any{"url": "http://example.com"}, + expected: `{"url": "http://example.com"}`, + }, + { + name: "comma in string", + input: map[string]any{"list": "apple, banana, cherry"}, + expected: `{"list": "apple, banana, cherry"}`, + }, + { + name: "colon and comma in string", + input: map[string]any{"data": "key:value, key2:value2"}, + expected: `{"data": "key:value, key2:value2"}`, + }, + // unicode characters + { + name: "emoji", + input: map[string]any{"emoji": "😀🎉✨"}, + expected: `{"emoji": "😀🎉✨"}`, + }, + { + name: "chinese characters", + input: map[string]any{"text": "你好世界"}, + expected: `{"text": "你好世界"}`, + }, + { + name: "arabic characters", + input: map[string]any{"text": "مرحبا"}, + expected: `{"text": "مرحبا"}`, + }, + { + name: "mixed unicode and ascii", + input: map[string]any{"text": "Hello 世界! 😀"}, + expected: `{"text": "Hello 世界! 😀"}`, + }, + { + name: "unicode with special symbols", + input: map[string]any{"text": "®©™€£¥"}, + expected: `{"text": "®©™€£¥"}`, + }, + // complex combinations - strings that look like JSON + { + name: "json string inside value", + input: map[string]any{"nested": `{"key":"value"}`}, + expected: `{"nested": "{\"key\":\"value\"}"}`, + }, + { + name: "json array inside value", + input: map[string]any{"array": `["a","b","c"]`}, + expected: `{"array": "[\"a\",\"b\",\"c\"]"}`, + }, + // edge cases + { + name: "empty string", + input: map[string]any{"empty": ""}, + expected: `{"empty": ""}`, + }, + { + name: "empty object", + input: map[string]any{}, + expected: `{}`, + }, + { + name: "empty array", + input: []any{}, + expected: `[]`, + }, + { + name: "numbers", + input: map[string]any{"int": 42, "float": 3.14}, + expected: `{"float": 3.14, "int": 42}`, + }, + { + name: "boolean", + input: map[string]any{"bool": true, "other": false}, + expected: `{"bool": true, "other": false}`, + }, + { + name: "null value", + input: map[string]any{"value": nil}, + expected: `{"value": null}`, + }, + // nested structures with complex strings + { + name: "nested object with escapes", + input: map[string]any{ + "outer": map[string]any{ + "path": `C:\folder\file.txt`, + "quote": `He said "hi"`, + }, + }, + expected: `{"outer": {"path": "C:\\folder\\file.txt", "quote": "He said \"hi\""}}`, + }, + { + name: "array with unicode and escapes", + input: []any{ + "normal", + "with\nnewline", + "with\"quote", + "emoji😀", + "colon:comma,", + }, + expected: `["normal", "with\nnewline", "with\"quote", "emoji😀", "colon:comma,"]`, + }, + { + name: "backslash at positions before special chars", + input: map[string]any{"text": `a\b:c\d,e`}, + expected: `{"text": "a\\b:c\\d,e"}`, + }, + { + name: "multiple backslashes before quote", + input: map[string]any{"text": `ends\\"`}, + expected: `{"text": "ends\\\\\""}`, + }, + { + name: "unicode with escapes", + input: map[string]any{"text": "Hello\n世界\t😀"}, + expected: `{"text": "Hello\n世界\t😀"}`, + }, + + // Real-world tool call example + { + name: "tool call arguments", + input: map[string]any{ + "location": "San Francisco, CA", + "unit": "fahrenheit", + "format": "json", + }, + expected: `{"format": "json", "location": "San Francisco, CA", "unit": "fahrenheit"}`, + }, + { + name: "complex tool arguments with escapes", + input: map[string]any{ + "query": `SELECT * FROM "users" WHERE name = 'O'Brien'`, + "description": "Fetch user\ndata from DB", + "path": `C:\data\users.db`, + }, + expected: `{"description": "Fetch user\ndata from DB", "path": "C:\\data\\users.db", "query": "SELECT * FROM \"users\" WHERE name = 'O'Brien'"}`, + }, + { + name: "unicode immediately adjacent to JSON structure chars", + input: map[string]any{"😀key": "😀value", "test": "😀:😀,😀"}, + expected: `{"test": "😀:😀,😀", "😀key": "😀value"}`, + }, + { + name: "long unicode string stress test", + input: map[string]any{"text": "😀😁😂😃😄😅😆😇😈😉😊😋😌😍😎😏😐😑😒😓😔😕😖😗😘😙😚😛😜😝😞😟"}, + expected: `{"text": "😀😁😂😃😄😅😆😇😈😉😊😋😌😍😎😏😐😑😒😓😔😕😖😗😘😙😚😛😜😝😞😟"}`, + }, + { + name: "deeply nested with unicode everywhere", + input: map[string]any{ + "😀": map[string]any{ + "你好": []any{"مرحبا", "®©™", "∑∫∂√"}, + }, + }, + expected: `{"😀": {"你好": ["مرحبا", "®©™", "∑∫∂√"]}}`, + }, + { + name: "unicode with all JSON special chars interleaved", + input: map[string]any{"k😀:k": "v😀,v", "a:😀": "b,😀", "😀": ":,😀,:"}, + expected: `{"a:😀": "b,😀", "k😀:k": "v😀,v", "😀": ":,😀,:"}`, + }, + { + name: "combining diacritics and RTL text", + input: map[string]any{"hebrew": "עִבְרִית", "combined": "é̀ñ", "mixed": "test:עִבְרִית,é̀ñ"}, + expected: `{"combined": "é̀ñ", "hebrew": "עִבְרִית", "mixed": "test:עִבְרִית,é̀ñ"}`, + }, + { + name: "pathological case: unicode + escapes + special chars", + input: map[string]any{"😀": "test\n😀\"quote😀\\backslash😀:colon😀,comma😀"}, + expected: `{"😀": "test\n😀\"quote😀\\backslash😀:colon😀,comma😀"}`, + }, + + // all JSON structural characters inside strings + { + name: "braces and brackets in strings", + input: map[string]any{"text": "test{with}braces[and]brackets"}, + expected: `{"text": "test{with}braces[and]brackets"}`, + }, + { + name: "braces and brackets with colons and commas", + input: map[string]any{"code": "{key:value,[1,2,3]}"}, + expected: `{"code": "{key:value,[1,2,3]}"}`, + }, + { + name: "json-like string with all structural chars", + input: map[string]any{"schema": `{"type":"object","properties":{"name":{"type":"string"},"items":{"type":"array"}}}`}, + expected: `{"schema": "{\"type\":\"object\",\"properties\":{\"name\":{\"type\":\"string\"},\"items\":{\"type\":\"array\"}}}"}`, + }, + + // forward slash tests (JSON allows \/ as an escape sequence) + { + name: "forward slash in URL", + input: map[string]any{"url": "https://example.com/path/to/resource"}, + expected: `{"url": "https://example.com/path/to/resource"}`, + }, + { + name: "regex pattern with slashes", + input: map[string]any{"regex": "/[a-z]+/gi"}, + expected: `{"regex": "/[a-z]+/gi"}`, + }, + + // all JSON escape sequences + { + name: "backspace escape", + input: map[string]any{"text": "before\bafter"}, + expected: `{"text": "before\bafter"}`, + }, + { + name: "form feed escape", + input: map[string]any{"text": "before\fafter"}, + expected: `{"text": "before\fafter"}`, + }, + { + name: "all standard escapes combined", + input: map[string]any{"text": "\"\\\b\f\n\r\t"}, + expected: `{"text": "\"\\\b\f\n\r\t"}`, + }, + + // unicode escape sequences + { + name: "string that forces unicode escapes", + input: map[string]any{"control": "\u0000\u0001\u001f"}, + expected: `{"control": "\u0000\u0001\u001f"}`, + }, + + // empty objects and arrays nested with strings + { + name: "nested empty structures with string values", + input: map[string]any{"empty_obj": map[string]any{}, "empty_arr": []any{}, "text": "{}[]"}, + expected: `{"empty_arr": [], "empty_obj": {}, "text": "{}[]"}`, + }, + + // complex nesting with all structural characters + { + name: "deeply nested with all char types", + input: map[string]any{ + "level1": map[string]any{ + "array": []any{ + map[string]any{"nested": "value:with,special{chars}[here]"}, + []any{"a", "b", "c"}, + }, + }, + }, + expected: `{"level1": {"array": [{"nested": "value:with,special{chars}[here]"}, ["a", "b", "c"]]}}`, + }, + + // string containing escaped structural characters + { + name: "string with multiple escape sequences and structural chars", + input: map[string]any{"data": "test\"quote\"{brace}[bracket]:colon,comma\\backslash/slash"}, + expected: `{"data": "test\"quote\"{brace}[bracket]:colon,comma\\backslash/slash"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := marshalWithSpaces(tt.input) + if err != nil { + t.Fatalf("marshalWithSpaces failed: %v", err) + } + + resultStr := string(result) + if diff := cmp.Diff(resultStr, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/model/renderers/qwen3vl_thinking_test.go b/model/renderers/qwen3vl_thinking_test.go new file mode 100644 index 00000000..eb53e6a9 --- /dev/null +++ b/model/renderers/qwen3vl_thinking_test.go @@ -0,0 +1,372 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +func TestQwen3VLThinkingRenderer(t *testing.T) { + tests := []struct { + name string + msgs []api.Message + images []api.ImageData + tools []api.Tool + expected string + }{ + { + name: "basic", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello, how are you?"}, + }, + expected: `<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +Hello, how are you?<|im_end|> +<|im_start|>assistant + +`, + }, + { + name: "With thinking, end assistant.", + msgs: []api.Message{ + {Role: "user", Content: "Tell me a story in two sentences."}, + {Role: "assistant", Content: "abc", Thinking: "To make this story interesting, I will speak in poetry."}, + }, + expected: `<|im_start|>user +Tell me a story in two sentences.<|im_end|> +<|im_start|>assistant + +To make this story interesting, I will speak in poetry. + + +abc`, + }, + { + name: "With thinking, end assistant.", + msgs: []api.Message{ + {Role: "user", Content: "Tell me a story in two sentences."}, + {Role: "assistant", Thinking: "To make this story interesting, I will speak in poetry."}, + }, + expected: `<|im_start|>user +Tell me a story in two sentences.<|im_end|> +<|im_start|>assistant + +To make this story interesting, I will speak in poetry.`, + }, + { + name: "Multiple thinking", + msgs: []api.Message{ + {Role: "user", Content: "Tell me a story in two sentences."}, + {Role: "assistant", Content: "abc", Thinking: "To make this story interesting, I will speak in poetry.And I will speak in poetry after the first sentence."}, + }, + expected: `<|im_start|>user +Tell me a story in two sentences.<|im_end|> +<|im_start|>assistant + +To make this story interesting, I will speak in poetry.And I will speak in poetry after the first sentence. + + +abc`, // NOTE: the second thinking tag is not captured + }, + { + name: "Multiple thinking, multiple messages.", + msgs: []api.Message{ + {Role: "user", Content: "Tell me a story in two sentences."}, + {Role: "assistant", Thinking: "To make this story interesting, I will speak in poetry.", Content: "abc"}, + {Role: "user", Content: "What is the weather like in San Francisco?"}, + {Role: "assistant", Thinking: "Speak poetry after the first sentence.Speak poetry after the second sentence."}, + }, + expected: `<|im_start|>user +Tell me a story in two sentences.<|im_end|> +<|im_start|>assistant +abc<|im_end|> +<|im_start|>user +What is the weather like in San Francisco?<|im_end|> +<|im_start|>assistant + +Speak poetry after the first sentence.Speak poetry after the second sentence.`, + }, + // NOTE: Servers automatically prepend a [img-] tag + // { + // name: "Image", + // msgs: []api.Message{ + // {Role: "user", Content: "Describe this image.", Images: []api.ImageData{api.ImageData(IMAGE2_BASE64)}}, + // }, + // expected: `<|im_start|>user + // [img-0]Describe this image.<|im_end|> + // <|im_start|>assistant + // + // `, + // }, + + // NOTE: Servers automatically prepend a [img-] tag + // { + // name: "Multiple images", + // msgs: []api.Message{ + // {Role: "user", Content: "Describe these images.", Images: []api.ImageData{api.ImageData(IMAGE1_BASE64), api.ImageData(IMAGE2_BASE64)}}, + // }, + // expected: `<|im_start|>user + // [img-0][img-1]Describe these images.<|im_end|> + // <|im_start|>assistant + // + // `, + // }, + + // NOTE: solved with #12518: https://github.com/ollama/ollama/compare/main...drifkin/stable-tool-args + // { + // name: "with tools and response", + // msgs: []api.Message{ + // {Role: "system", Content: "You are a helpful assistant with access to tools."}, + // {Role: "user", Content: "What's the weather like in New York?"}, + // { + // Role: "assistant", + // Content: "I'll check the weather in New York for you.", + // ToolCalls: []api.ToolCall{ + // { + // Function: api.ToolCallFunction{ + // Name: "get-current-weather", + // Arguments: map[string]any{ + // "location": "New York", + // "unit": "fahrenheit", + // }, + // }, + // }, + // }, + // }, + // {Role: "tool", Content: "80", ToolName: "get-current-weather"}, + // {Role: "user", Content: "That sounds nice! What about San Francisco?"}, + // }, + // tools: []api.Tool{ + // { + // Type: "function", + // Function: api.ToolFunction{ + // Name: "get-current-weather", + // Description: "Get the current weather for a location", + // Parameters: api.ToolFunctionParameters{ + // Type: "object", + // Required: []string{"location"}, + // Properties: map[string]api.ToolProperty{ + // "location": { + // Type: api.PropertyType{"string"}, + // Description: "The city and state, e.g. San Francisco, CA", + // }, + // "unit": { + // Type: api.PropertyType{"string"}, + // Enum: []any{"celsius", "fahrenheit"}, + // Description: "The temperature unit", + // }, + // }, + // }, + // }, + // }, + // }, + // expected: `<|im_start|>system + // You are a helpful assistant with access to tools. + + // # Tools + + // You may call one or more functions to assist with the user query. + + // You are provided with function signatures within XML tags: + // + // {"type": "function", "function": {"name": "get-current-weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit"}}, "required": ["location"]}}} + // + + // For each function call, return a json object with function name and arguments within XML tags: + // + // {"name": , "arguments": } + // <|im_end|> + // <|im_start|>user + // What's the weather like in New York?<|im_end|> + // <|im_start|>assistant + // I'll check the weather in New York for you. + // + // {"name": "get-current-weather", "arguments": {"location": "New York", "unit": "fahrenheit"}} + // <|im_end|> + // <|im_start|>user + // + // 80 + // <|im_end|> + // <|im_start|>user + // That sounds nice! What about San Francisco?<|im_end|> + // <|im_start|>assistant + // + // `, + // }, + + // NOTE: solved with #12518: https://github.com/ollama/ollama/compare/main...drifkin/stable-tool-args + // { + // name: "With tools and response, multiple tool calls", + // msgs: []api.Message{ + // { + // Role: "system", + // Content: "You are a helpful assistant with access to tools.", + // }, + // { + // Role: "user", + // Content: "Call two tools for me: add and multiply.", + // }, + // { + // Role: "assistant", + // Content: "Sure, I'll call both tools for you.", + // ToolCalls: []api.ToolCall{ + // { + // Function: api.ToolCallFunction{ + // Name: "add", + // Arguments: map[string]any{ + // "a": 2, + // "b": 3, + // }, + // }, + // }, + // { + // Function: api.ToolCallFunction{ + // Name: "multiply", + // Arguments: map[string]any{ + // "x": 4, + // "y": 5, + // }, + // }, + // }, + // }, + // }, + // { + // Role: "tool", + // Content: "5", + // ToolName: "add", + // }, + // { + // Role: "tool", + // Content: "20", + // ToolName: "multiply", + // }, + // { + // Role: "user", + // Content: "Thanks! What are the results?", + // }, + // }, + // tools: []api.Tool{ + // { + // Type: "function", + // Function: api.ToolFunction{ + // Name: "add", + // Description: "Add two numbers", + // Parameters: api.ToolFunctionParameters{ + // Type: "object", + // Required: []string{"a", "b"}, + // Properties: map[string]api.ToolProperty{ + // "a": {Type: api.PropertyType{"integer"}, Description: "First number"}, + // "b": {Type: api.PropertyType{"integer"}, Description: "Second number"}, + // }, + // }, + // }, + // }, + // { + // Type: "function", + // Function: api.ToolFunction{ + // Name: "multiply", + // Description: "Multiply two numbers", + // Parameters: api.ToolFunctionParameters{ + // Type: "object", + // Required: []string{"x", "y"}, + // Properties: map[string]api.ToolProperty{ + // "x": {Type: api.PropertyType{"integer"}, Description: "First factor"}, + // "y": {Type: api.PropertyType{"integer"}, Description: "Second factor"}, + // }, + // }, + // }, + // }, + // }, + // expected: `<|im_start|>system + // You are a helpful assistant with access to tools. + + // # Tools + + // You may call one or more functions to assist with the user query. + + // You are provided with function signatures within XML tags: + // + // {"type": "function", "function": {"name": "add", "description": "Add two numbers", "parameters": {"type": "object", "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, "required": ["a", "b"]}}} + // {"type": "function", "function": {"name": "multiply", "description": "Multiply two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "integer"}, "y": {"type": "integer"}}, "required": ["x", "y"]}}} + // + + // For each function call, return a json object with function name and arguments within XML tags: + // + // {"name": , "arguments": } + // <|im_end|> + // <|im_start|>user + // Call two tools for me: add and multiply.<|im_end|> + // <|im_start|>assistant + // Sure, I'll call both tools for you. + // + // {"name": "add", "arguments": {"a": 2, "b": 3}} + // + // + // {"name": "multiply", "arguments": {"x": 4, "y": 5}} + // <|im_end|> + // <|im_start|>user + // + // 5 + // + // + // 20 + // <|im_end|> + // <|im_start|>user + // Thanks! What are the results?<|im_end|> + // <|im_start|>assistant + // + // `, + // }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rendered, err := (&Qwen3VLRenderer{isThinking: true}).Render(tt.msgs, tt.tools, nil) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestFormatToolCallArgumentThinkingVL(t *testing.T) { + tests := []struct { + name string + arg any + expected string + }{ + { + name: "string", + arg: "foo", + expected: "foo", + }, + { + name: "map", + arg: map[string]any{"foo": "bar"}, + expected: "{\"foo\":\"bar\"}", + }, + { + name: "number", + arg: 1, + expected: "1", + }, + { + name: "boolean", + arg: true, + expected: "true", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatToolCallArgument(tt.arg) + if got != tt.expected { + t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected) + } + }) + } +} diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index 2dfb51e4..d995579c 100644 --- a/model/renderers/renderer.go +++ b/model/renderers/renderer.go @@ -6,20 +6,56 @@ import ( "github.com/ollama/ollama/api" ) -type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error) +type Renderer interface { + Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) +} + +type ( + RendererConstructor func() Renderer + RendererRegistry struct { + renderers map[string]RendererConstructor + } +) + +// RenderImgTags is a global flag that tells renderers to use [img] tags +// for images. This is set by the Ollama server package on init, or left as +// false for other environments where renderers are used +var RenderImgTags bool + +func (r *RendererRegistry) Register(name string, renderer RendererConstructor) { + r.renderers[name] = renderer +} + +var registry = RendererRegistry{ + renderers: make(map[string]RendererConstructor), +} + +func Register(name string, renderer RendererConstructor) { + registry.Register(name, renderer) +} func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { renderer := rendererForName(name) if renderer == nil { return "", fmt.Errorf("unknown renderer %q", name) } - return renderer(msgs, tools, think) + return renderer.Render(msgs, tools, think) } -func rendererForName(name string) rendererFunc { +func rendererForName(name string) Renderer { + if constructor, ok := registry.renderers[name]; ok { + return constructor() + } switch name { case "qwen3-coder": - return Qwen3CoderRenderer + renderer := &Qwen3CoderRenderer{} + return renderer + case "qwen3-vl-instruct": + renderer := &Qwen3VLRenderer{isThinking: false, useImgTags: RenderImgTags} + return renderer + case "qwen3-vl-thinking": + renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags} + return renderer default: return nil } diff --git a/model/renderers/renderer_test.go b/model/renderers/renderer_test.go new file mode 100644 index 00000000..8625634c --- /dev/null +++ b/model/renderers/renderer_test.go @@ -0,0 +1,67 @@ +package renderers + +import ( + "testing" + + "github.com/ollama/ollama/api" +) + +type mockRenderer struct{} + +func (m *mockRenderer) Render(msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { + return "mock-output", nil +} + +func TestRegisterCustomRenderer(t *testing.T) { + // Register a custom renderer + Register("custom-renderer", func() Renderer { + return &mockRenderer{} + }) + + // Retrieve and use it + result, err := RenderWithRenderer("custom-renderer", nil, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "mock-output" { + t.Errorf("expected 'mock-output', got %q", result) + } +} + +func TestBuiltInRendererStillWorks(t *testing.T) { + // Test that qwen3-coder still works + messages := []api.Message{ + {Role: "user", Content: "Hello"}, + } + + result, err := RenderWithRenderer("qwen3-coder", messages, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == "" { + t.Error("expected non-empty result from qwen3-coder renderer") + } +} + +func TestOverrideBuiltInRenderer(t *testing.T) { + // Override the built-in renderer + Register("qwen3-coder", func() Renderer { + return &mockRenderer{} + }) + + // Should get the override + result, err := RenderWithRenderer("qwen3-coder", nil, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "mock-output" { + t.Errorf("expected 'mock-output' from override, got %q", result) + } +} + +func TestUnknownRendererReturnsError(t *testing.T) { + _, err := RenderWithRenderer("nonexistent-renderer", nil, nil, nil) + if err == nil { + t.Error("expected error for unknown renderer") + } +} diff --git a/openai/openai.go b/openai/openai.go index 7ef5ac6d..23e9522f 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -1,21 +1,18 @@ -// openai package provides middleware for partial compatibility with the OpenAI REST API +// openai package provides core transformation logic for partial compatibility with the OpenAI REST API package openai import ( - "bytes" "encoding/base64" "encoding/json" "errors" "fmt" - "io" "log/slog" "math/rand" "net/http" + "slices" "strings" "time" - "github.com/gin-gonic/gin" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/types/model" ) @@ -86,7 +83,7 @@ type StreamOptions struct { } type Reasoning struct { - Effort *string `json:"effort,omitempty"` + Effort string `json:"effort,omitempty"` } type ChatCompletionRequest struct { @@ -220,11 +217,12 @@ func NewError(code int, message string) ErrorResponse { return ErrorResponse{Error{Type: etype, Message: message}} } -func toUsage(r api.ChatResponse) Usage { +// ToUsage converts an api.ChatResponse to Usage +func ToUsage(r api.ChatResponse) Usage { return Usage{ - PromptTokens: r.PromptEvalCount, - CompletionTokens: r.EvalCount, - TotalTokens: r.PromptEvalCount + r.EvalCount, + PromptTokens: r.Metrics.PromptEvalCount, + CompletionTokens: r.Metrics.EvalCount, + TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount, } } @@ -237,7 +235,8 @@ func toolCallId() string { return "call_" + strings.ToLower(string(b)) } -func toToolCalls(tc []api.ToolCall) []ToolCall { +// ToToolCalls converts api.ToolCall to OpenAI ToolCall format +func ToToolCalls(tc []api.ToolCall) []ToolCall { toolCalls := make([]ToolCall, len(tc)) for i, tc := range tc { toolCalls[i].ID = toolCallId() @@ -256,8 +255,9 @@ func toToolCalls(tc []api.ToolCall) []ToolCall { return toolCalls } -func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { - toolCalls := toToolCalls(r.Message.ToolCalls) +// ToChatCompletion converts an api.ChatResponse to ChatCompletion +func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion { + toolCalls := ToToolCalls(r.Message.ToolCalls) return ChatCompletion{ Id: id, Object: "chat.completion", @@ -276,13 +276,14 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { } return nil }(r.DoneReason), - }}, Usage: toUsage(r), + }}, Usage: ToUsage(r), DebugInfo: r.DebugInfo, } } -func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk { - toolCalls := toToolCalls(r.Message.ToolCalls) +// ToChunk converts an api.ChatResponse to ChatCompletionChunk +func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk { + toolCalls := ToToolCalls(r.Message.ToolCalls) return ChatCompletionChunk{ Id: id, Object: "chat.completion.chunk", @@ -305,15 +306,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu } } -func toUsageGenerate(r api.GenerateResponse) Usage { +// ToUsageGenerate converts an api.GenerateResponse to Usage +func ToUsageGenerate(r api.GenerateResponse) Usage { return Usage{ - PromptTokens: r.PromptEvalCount, - CompletionTokens: r.EvalCount, - TotalTokens: r.PromptEvalCount + r.EvalCount, + PromptTokens: r.Metrics.PromptEvalCount, + CompletionTokens: r.Metrics.EvalCount, + TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount, } } -func toCompletion(id string, r api.GenerateResponse) Completion { +// ToCompletion converts an api.GenerateResponse to Completion +func ToCompletion(id string, r api.GenerateResponse) Completion { return Completion{ Id: id, Object: "text_completion", @@ -330,11 +333,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion { return nil }(r.DoneReason), }}, - Usage: toUsageGenerate(r), + Usage: ToUsageGenerate(r), } } -func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { +// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk +func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { return CompletionChunk{ Id: id, Object: "text_completion", @@ -354,7 +358,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { } } -func toListCompletion(r api.ListResponse) ListCompletion { +// ToListCompletion converts an api.ListResponse to ListCompletion +func ToListCompletion(r api.ListResponse) ListCompletion { var data []Model for _, m := range r.Models { data = append(data, Model{ @@ -371,7 +376,8 @@ func toListCompletion(r api.ListResponse) ListCompletion { } } -func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { +// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList +func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { if r.Embeddings != nil { var data []Embedding for i, e := range r.Embeddings { @@ -396,7 +402,8 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { return EmbeddingList{} } -func toModel(r api.ShowResponse, m string) Model { +// ToModel converts an api.ShowResponse to Model +func ToModel(r api.ShowResponse, m string) Model { return Model{ Id: m, Object: "model", @@ -405,7 +412,8 @@ func toModel(r api.ShowResponse, m string) Model { } } -func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { +// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest +func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { var messages []api.Message for _, msg := range r.Messages { toolName := "" @@ -417,7 +425,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { } switch content := msg.Content.(type) { case string: - toolCalls, err := fromCompletionToolCall(msg.ToolCalls) + toolCalls, err := FromCompletionToolCall(msg.ToolCalls) if err != nil { return nil, err } @@ -449,6 +457,11 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { types := []string{"jpeg", "jpg", "png", "webp"} valid := false + // support blank mime type to match api/chat taking just unadorned base64 + if strings.HasPrefix(url, "data:;base64,") { + url = strings.TrimPrefix(url, "data:;base64,") + valid = true + } for _, t := range types { prefix := "data:image/" + t + ";base64," if strings.HasPrefix(url, prefix) { @@ -475,7 +488,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { // since we might have added multiple messages above, if we have tools // calls we'll add them to the last message if len(messages) > 0 && len(msg.ToolCalls) > 0 { - toolCalls, err := fromCompletionToolCall(msg.ToolCalls) + toolCalls, err := FromCompletionToolCall(msg.ToolCalls) if err != nil { return nil, err } @@ -560,13 +573,23 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { } var think *api.ThinkValue + var effort string + if r.Reasoning != nil { - think = &api.ThinkValue{ - Value: *r.Reasoning.Effort, - } + effort = r.Reasoning.Effort } else if r.ReasoningEffort != nil { - think = &api.ThinkValue{ - Value: *r.ReasoningEffort, + effort = *r.ReasoningEffort + } + + if effort != "" { + if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) { + return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort) + } + + if effort == "none" { + think = &api.ThinkValue{Value: false} + } else { + think = &api.ThinkValue{Value: effort} } } @@ -596,7 +619,8 @@ func nameFromToolCallID(messages []Message, toolCallID string) string { return "" } -func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) { +// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall +func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) { apiToolCalls := make([]api.ToolCall, len(toolCalls)) for i, tc := range toolCalls { apiToolCalls[i].Function.Name = tc.Function.Name @@ -609,7 +633,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) { return apiToolCalls, nil } -func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { +// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest +func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { options := make(map[string]any) switch stop := r.Stop.(type) { @@ -660,413 +685,3 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { DebugRenderOnly: r.DebugRenderOnly, }, nil } - -type BaseWriter struct { - gin.ResponseWriter -} - -type ChatWriter struct { - stream bool - streamOptions *StreamOptions - id string - toolCallSent bool - BaseWriter -} - -type CompleteWriter struct { - stream bool - streamOptions *StreamOptions - id string - BaseWriter -} - -type ListWriter struct { - BaseWriter -} - -type RetrieveWriter struct { - BaseWriter - model string -} - -type EmbedWriter struct { - BaseWriter - model string -} - -func (w *BaseWriter) writeError(data []byte) (int, error) { - var serr api.StatusError - err := json.Unmarshal(data, &serr) - if err != nil { - return 0, err - } - - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error())) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *ChatWriter) writeResponse(data []byte) (int, error) { - var chatResponse api.ChatResponse - err := json.Unmarshal(data, &chatResponse) - if err != nil { - return 0, err - } - - // chat chunk - if w.stream { - c := toChunk(w.id, chatResponse, w.toolCallSent) - d, err := json.Marshal(c) - if err != nil { - return 0, err - } - if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 { - w.toolCallSent = true - } - - w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") - _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) - if err != nil { - return 0, err - } - - if chatResponse.Done { - if w.streamOptions != nil && w.streamOptions.IncludeUsage { - u := toUsage(chatResponse) - c.Usage = &u - c.Choices = []ChunkChoice{} - d, err := json.Marshal(c) - if err != nil { - return 0, err - } - _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) - if err != nil { - return 0, err - } - } - _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) - if err != nil { - return 0, err - } - } - - return len(data), nil - } - - // chat completion - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse)) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *ChatWriter) Write(data []byte) (int, error) { - code := w.ResponseWriter.Status() - if code != http.StatusOK { - return w.writeError(data) - } - - return w.writeResponse(data) -} - -func (w *CompleteWriter) writeResponse(data []byte) (int, error) { - var generateResponse api.GenerateResponse - err := json.Unmarshal(data, &generateResponse) - if err != nil { - return 0, err - } - - // completion chunk - if w.stream { - c := toCompleteChunk(w.id, generateResponse) - if w.streamOptions != nil && w.streamOptions.IncludeUsage { - c.Usage = &Usage{} - } - d, err := json.Marshal(c) - if err != nil { - return 0, err - } - - w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") - _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) - if err != nil { - return 0, err - } - - if generateResponse.Done { - if w.streamOptions != nil && w.streamOptions.IncludeUsage { - u := toUsageGenerate(generateResponse) - c.Usage = &u - c.Choices = []CompleteChunkChoice{} - d, err := json.Marshal(c) - if err != nil { - return 0, err - } - _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) - if err != nil { - return 0, err - } - } - _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) - if err != nil { - return 0, err - } - } - - return len(data), nil - } - - // completion - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse)) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *CompleteWriter) Write(data []byte) (int, error) { - code := w.ResponseWriter.Status() - if code != http.StatusOK { - return w.writeError(data) - } - - return w.writeResponse(data) -} - -func (w *ListWriter) writeResponse(data []byte) (int, error) { - var listResponse api.ListResponse - err := json.Unmarshal(data, &listResponse) - if err != nil { - return 0, err - } - - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse)) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *ListWriter) Write(data []byte) (int, error) { - code := w.ResponseWriter.Status() - if code != http.StatusOK { - return w.writeError(data) - } - - return w.writeResponse(data) -} - -func (w *RetrieveWriter) writeResponse(data []byte) (int, error) { - var showResponse api.ShowResponse - err := json.Unmarshal(data, &showResponse) - if err != nil { - return 0, err - } - - // retrieve completion - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model)) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *RetrieveWriter) Write(data []byte) (int, error) { - code := w.ResponseWriter.Status() - if code != http.StatusOK { - return w.writeError(data) - } - - return w.writeResponse(data) -} - -func (w *EmbedWriter) writeResponse(data []byte) (int, error) { - var embedResponse api.EmbedResponse - err := json.Unmarshal(data, &embedResponse) - if err != nil { - return 0, err - } - - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse)) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *EmbedWriter) Write(data []byte) (int, error) { - code := w.ResponseWriter.Status() - if code != http.StatusOK { - return w.writeError(data) - } - - return w.writeResponse(data) -} - -func ListMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - w := &ListWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - } - - c.Writer = w - - c.Next() - } -} - -func RetrieveMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) - return - } - - c.Request.Body = io.NopCloser(&b) - - // response writer - w := &RetrieveWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - model: c.Param("model"), - } - - c.Writer = w - - c.Next() - } -} - -func CompletionsMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var req CompletionRequest - err := c.ShouldBindJSON(&req) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) - return - } - - var b bytes.Buffer - genReq, err := fromCompleteRequest(req) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) - return - } - - if err := json.NewEncoder(&b).Encode(genReq); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) - return - } - - c.Request.Body = io.NopCloser(&b) - - w := &CompleteWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - stream: req.Stream, - id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), - streamOptions: req.StreamOptions, - } - - c.Writer = w - c.Next() - } -} - -func EmbeddingsMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var req EmbedRequest - err := c.ShouldBindJSON(&req) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) - return - } - - if req.Input == "" { - req.Input = []string{""} - } - - if req.Input == nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) - return - } - - if v, ok := req.Input.([]any); ok && len(v) == 0 { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) - return - } - - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) - return - } - - c.Request.Body = io.NopCloser(&b) - - w := &EmbedWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - model: req.Model, - } - - c.Writer = w - - c.Next() - } -} - -func ChatMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var req ChatCompletionRequest - err := c.ShouldBindJSON(&req) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) - return - } - - if len(req.Messages) == 0 { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'")) - return - } - - var b bytes.Buffer - - chatReq, err := fromChatRequest(req) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) - return - } - - if err := json.NewEncoder(&b).Encode(chatReq); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) - return - } - - c.Request.Body = io.NopCloser(&b) - - w := &ChatWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - stream: req.Stream, - id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), - streamOptions: req.StreamOptions, - } - - c.Writer = w - - c.Next() - } -} diff --git a/openai/openai_test.go b/openai/openai_test.go index 0d7f016b..0f1a877f 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -1,19 +1,8 @@ package openai import ( - "bytes" "encoding/base64" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "reflect" - "strings" "testing" - "time" - - "github.com/gin-gonic/gin" - "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" ) @@ -23,905 +12,139 @@ const ( image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` ) -var ( - False = false - True = true -) +func TestFromChatRequest_Basic(t *testing.T) { + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + } -func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { - return func(c *gin.Context) { - bodyBytes, _ := io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - err := json.Unmarshal(bodyBytes, capturedRequest) - if err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request") - } - c.Next() + result, err := FromChatRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "test-model" { + t.Errorf("expected model 'test-model', got %q", result.Model) + } + + if len(result.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(result.Messages)) + } + + if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" { + t.Errorf("unexpected message: %+v", result.Messages[0]) } } -func TestChatMiddleware(t *testing.T) { - type testCase struct { - name string - body string - req api.ChatRequest - err ErrorResponse - } +func TestFromChatRequest_WithImage(t *testing.T) { + imgData, _ := base64.StdEncoding.DecodeString(image) - var capturedRequest *api.ChatRequest - - testCases := []testCase{ - { - name: "chat handler", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "Hello"} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "Hello", + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{ + { + Role: "user", + Content: []any{ + map[string]any{"type": "text", "text": "Hello"}, + map[string]any{ + "type": "image_url", + "image_url": map[string]any{"url": prefix + image}, }, }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with options", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "Hello"} - ], - "stream": true, - "max_tokens": 999, - "seed": 123, - "stop": ["\n", "stop"], - "temperature": 3.0, - "frequency_penalty": 4.0, - "presence_penalty": 5.0, - "top_p": 6.0, - "response_format": {"type": "json_object"} - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "Hello", - }, - }, - Options: map[string]any{ - "num_predict": 999.0, // float because JSON doesn't distinguish between float and int - "seed": 123.0, - "stop": []any{"\n", "stop"}, - "temperature": 3.0, - "frequency_penalty": 4.0, - "presence_penalty": 5.0, - "top_p": 6.0, - }, - Format: json.RawMessage(`"json"`), - Stream: &True, - }, - }, - { - name: "chat handler with streaming usage", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "Hello"} - ], - "stream": true, - "stream_options": {"include_usage": true}, - "max_tokens": 999, - "seed": 123, - "stop": ["\n", "stop"], - "temperature": 3.0, - "frequency_penalty": 4.0, - "presence_penalty": 5.0, - "top_p": 6.0, - "response_format": {"type": "json_object"} - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "Hello", - }, - }, - Options: map[string]any{ - "num_predict": 999.0, // float because JSON doesn't distinguish between float and int - "seed": 123.0, - "stop": []any{"\n", "stop"}, - "temperature": 3.0, - "frequency_penalty": 4.0, - "presence_penalty": 5.0, - "top_p": 6.0, - }, - Format: json.RawMessage(`"json"`), - Stream: &True, - }, - }, - { - name: "chat handler with image content", - body: `{ - "model": "test-model", - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Hello" - }, - { - "type": "image_url", - "image_url": { - "url": "` + prefix + image + `" - } - } - ] - } - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "Hello", - }, - { - Role: "user", - Images: []api.ImageData{ - func() []byte { - img, _ := base64.StdEncoding.DecodeString(image) - return img - }(), - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with tools", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with tools and content", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - Content: "Let's see what the weather is like in Paris", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with tools and empty content", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with tools and thinking content", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - Thinking: "Let's see what the weather is like in Paris", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "tool response with call ID", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}, - {"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - { - Role: "tool", - Content: "The weather in Paris is 20 degrees Celsius", - ToolName: "get_current_weather", - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "tool response with name", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}, - {"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - { - Role: "tool", - Content: "The weather in Paris is 20 degrees Celsius", - ToolName: "get_current_weather", - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with streaming tools", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris?"} - ], - "stream": true, - "tools": [{ - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "required": ["location"], - "properties": { - "location": { - "type": "string", - "description": "The city and state" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } - } - } - } - }] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris?", - }, - }, - Tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "get_weather", - Description: "Get the current weather", - Parameters: struct { - Type string `json:"type"` - Defs any `json:"$defs,omitempty"` - Items any `json:"items,omitempty"` - Required []string `json:"required"` - Properties map[string]api.ToolProperty `json:"properties"` - }{ - Type: "object", - Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ - "location": { - Type: api.PropertyType{"string"}, - Description: "The city and state", - }, - "unit": { - Type: api.PropertyType{"string"}, - Enum: []any{"celsius", "fahrenheit"}, - }, - }, - }, - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &True, - }, - }, - { - name: "chat handler error forwarding", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": 2} - ] - }`, - err: ErrorResponse{ - Error: Error{ - Message: "invalid message content type: float64", - Type: "invalid_request_error", - }, }, }, } - endpoint := func(c *gin.Context) { - c.Status(http.StatusOK) + result, err := FromChatRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - gin.SetMode(gin.TestMode) - router := gin.New() - router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest)) - router.Handle(http.MethodPost, "/api/chat", endpoint) + if len(result.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(result.Messages)) + } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body)) - req.Header.Set("Content-Type", "application/json") + if result.Messages[0].Content != "Hello" { + t.Errorf("expected first message content 'Hello', got %q", result.Messages[0].Content) + } - defer func() { capturedRequest = nil }() + if len(result.Messages[1].Images) != 1 { + t.Fatalf("expected 1 image, got %d", len(result.Messages[1].Images)) + } - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - - var errResp ErrorResponse - if resp.Code != http.StatusOK { - if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { - t.Fatal(err) - } - return - } - if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" { - t.Fatalf("requests did not match: %+v", diff) - } - if diff := cmp.Diff(tc.err, errResp); diff != "" { - t.Fatalf("errors did not match for %s:\n%s", tc.name, diff) - } - }) + if string(result.Messages[1].Images[0]) != string(imgData) { + t.Error("image data mismatch") } } -func TestCompletionsMiddleware(t *testing.T) { - type testCase struct { - name string - body string - req api.GenerateRequest - err ErrorResponse +func TestFromCompleteRequest_Basic(t *testing.T) { + temp := float32(0.8) + req := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + Temperature: &temp, } - var capturedRequest *api.GenerateRequest - - testCases := []testCase{ - { - name: "completions handler", - body: `{ - "model": "test-model", - "prompt": "Hello", - "temperature": 0.8, - "stop": ["\n", "stop"], - "suffix": "suffix" - }`, - req: api.GenerateRequest{ - Model: "test-model", - Prompt: "Hello", - Options: map[string]any{ - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "temperature": 0.8, - "top_p": 1.0, - "stop": []any{"\n", "stop"}, - }, - Suffix: "suffix", - Stream: &False, - }, - }, - { - name: "completions handler stream", - body: `{ - "model": "test-model", - "prompt": "Hello", - "stream": true, - "temperature": 0.8, - "stop": ["\n", "stop"], - "suffix": "suffix" - }`, - req: api.GenerateRequest{ - Model: "test-model", - Prompt: "Hello", - Options: map[string]any{ - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "temperature": 0.8, - "top_p": 1.0, - "stop": []any{"\n", "stop"}, - }, - Suffix: "suffix", - Stream: &True, - }, - }, - { - name: "completions handler stream with usage", - body: `{ - "model": "test-model", - "prompt": "Hello", - "stream": true, - "stream_options": {"include_usage": true}, - "temperature": 0.8, - "stop": ["\n", "stop"], - "suffix": "suffix" - }`, - req: api.GenerateRequest{ - Model: "test-model", - Prompt: "Hello", - Options: map[string]any{ - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "temperature": 0.8, - "top_p": 1.0, - "stop": []any{"\n", "stop"}, - }, - Suffix: "suffix", - Stream: &True, - }, - }, - { - name: "completions handler error forwarding", - body: `{ - "model": "test-model", - "prompt": "Hello", - "temperature": null, - "stop": [1, 2], - "suffix": "suffix" - }`, - err: ErrorResponse{ - Error: Error{ - Message: "invalid type for 'stop' field: float64", - Type: "invalid_request_error", - }, - }, - }, + result, err := FromCompleteRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - endpoint := func(c *gin.Context) { - c.Status(http.StatusOK) + if result.Model != "test-model" { + t.Errorf("expected model 'test-model', got %q", result.Model) } - gin.SetMode(gin.TestMode) - router := gin.New() - router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest)) - router.Handle(http.MethodPost, "/api/generate", endpoint) + if result.Prompt != "Hello" { + t.Errorf("expected prompt 'Hello', got %q", result.Prompt) + } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body)) - req.Header.Set("Content-Type", "application/json") - - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - - var errResp ErrorResponse - if resp.Code != http.StatusOK { - if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { - t.Fatal(err) - } - } - - if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { - t.Fatal("requests did not match") - } - - if !reflect.DeepEqual(tc.err, errResp) { - t.Fatal("errors did not match") - } - - capturedRequest = nil - }) + if tempVal, ok := result.Options["temperature"].(float32); !ok || tempVal != 0.8 { + t.Errorf("expected temperature 0.8, got %v", result.Options["temperature"]) } } -func TestEmbeddingsMiddleware(t *testing.T) { - type testCase struct { - name string - body string - req api.EmbedRequest - err ErrorResponse - } - - var capturedRequest *api.EmbedRequest - - testCases := []testCase{ - { - name: "embed handler single input", - body: `{ - "input": "Hello", - "model": "test-model" - }`, - req: api.EmbedRequest{ - Input: "Hello", - Model: "test-model", - }, - }, - { - name: "embed handler batch input", - body: `{ - "input": ["Hello", "World"], - "model": "test-model" - }`, - req: api.EmbedRequest{ - Input: []any{"Hello", "World"}, - Model: "test-model", - }, - }, - { - name: "embed handler error forwarding", - body: `{ - "model": "test-model" - }`, - err: ErrorResponse{ - Error: Error{ - Message: "invalid input", - Type: "invalid_request_error", - }, - }, +func TestToUsage(t *testing.T) { + resp := api.ChatResponse{ + Metrics: api.Metrics{ + PromptEvalCount: 10, + EvalCount: 20, }, } - endpoint := func(c *gin.Context) { - c.Status(http.StatusOK) + usage := ToUsage(resp) + + if usage.PromptTokens != 10 { + t.Errorf("expected PromptTokens 10, got %d", usage.PromptTokens) } - gin.SetMode(gin.TestMode) - router := gin.New() - router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest)) - router.Handle(http.MethodPost, "/api/embed", endpoint) + if usage.CompletionTokens != 20 { + t.Errorf("expected CompletionTokens 20, got %d", usage.CompletionTokens) + } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body)) - req.Header.Set("Content-Type", "application/json") - - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - - var errResp ErrorResponse - if resp.Code != http.StatusOK { - if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { - t.Fatal(err) - } - } - - if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { - t.Fatal("requests did not match") - } - - if !reflect.DeepEqual(tc.err, errResp) { - t.Fatal("errors did not match") - } - - capturedRequest = nil - }) + if usage.TotalTokens != 30 { + t.Errorf("expected TotalTokens 30, got %d", usage.TotalTokens) } } -func TestListMiddleware(t *testing.T) { - type testCase struct { - name string - endpoint func(c *gin.Context) - resp string +func TestNewError(t *testing.T) { + tests := []struct { + code int + want string + }{ + {400, "invalid_request_error"}, + {404, "not_found_error"}, + {500, "api_error"}, } - testCases := []testCase{ - { - name: "list handler", - endpoint: func(c *gin.Context) { - c.JSON(http.StatusOK, api.ListResponse{ - Models: []api.ListModelResponse{ - { - Name: "test-model", - ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), - }, - }, - }) - }, - resp: `{ - "object": "list", - "data": [ - { - "id": "test-model", - "object": "model", - "created": 1686935002, - "owned_by": "library" - } - ] - }`, - }, - { - name: "list handler empty output", - endpoint: func(c *gin.Context) { - c.JSON(http.StatusOK, api.ListResponse{}) - }, - resp: `{ - "object": "list", - "data": null - }`, - }, - } - - gin.SetMode(gin.TestMode) - - for _, tc := range testCases { - router := gin.New() - router.Use(ListMiddleware()) - router.Handle(http.MethodGet, "/api/tags", tc.endpoint) - req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil) - - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - - var expected, actual map[string]any - err := json.Unmarshal([]byte(tc.resp), &expected) - if err != nil { - t.Fatalf("failed to unmarshal expected response: %v", err) + for _, tt := range tests { + result := NewError(tt.code, "test message") + if result.Error.Type != tt.want { + t.Errorf("NewError(%d) type = %q, want %q", tt.code, result.Error.Type, tt.want) } - - err = json.Unmarshal(resp.Body.Bytes(), &actual) - if err != nil { - t.Fatalf("failed to unmarshal actual response: %v", err) - } - - if !reflect.DeepEqual(expected, actual) { - t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) - } - } -} - -func TestRetrieveMiddleware(t *testing.T) { - type testCase struct { - name string - endpoint func(c *gin.Context) - resp string - } - - testCases := []testCase{ - { - name: "retrieve handler", - endpoint: func(c *gin.Context) { - c.JSON(http.StatusOK, api.ShowResponse{ - ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), - }) - }, - resp: `{ - "id":"test-model", - "object":"model", - "created":1686935002, - "owned_by":"library"} - `, - }, - { - name: "retrieve handler error forwarding", - endpoint: func(c *gin.Context) { - c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"}) - }, - resp: `{ - "error": { - "code": null, - "message": "model not found", - "param": null, - "type": "api_error" - } - }`, - }, - } - - gin.SetMode(gin.TestMode) - - for _, tc := range testCases { - router := gin.New() - router.Use(RetrieveMiddleware()) - router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint) - req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil) - - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - - var expected, actual map[string]any - err := json.Unmarshal([]byte(tc.resp), &expected) - if err != nil { - t.Fatalf("failed to unmarshal expected response: %v", err) - } - - err = json.Unmarshal(resp.Body.Bytes(), &actual) - if err != nil { - t.Fatalf("failed to unmarshal actual response: %v", err) - } - - if !reflect.DeepEqual(expected, actual) { - t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + if result.Error.Message != "test message" { + t.Errorf("NewError(%d) message = %q, want %q", tt.code, result.Error.Message, "test message") } } } diff --git a/runner/llamarunner/image.go b/runner/llamarunner/image.go index cc0153ae..9fc97081 100644 --- a/runner/llamarunner/image.go +++ b/runner/llamarunner/image.go @@ -56,7 +56,7 @@ func (c *ImageContext) Free(modelPath string) { } } -func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte) ([][]float32, error) { +func (c *ImageContext) MultimodalTokenize(llamaContext *llama.Context, data []byte) ([]llama.MtmdChunk, error) { if c == nil { return nil, nil } @@ -70,10 +70,10 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte) ([][]f c.mu.Lock() defer c.mu.Unlock() - embed, err := c.findImage(hash) + chunks, err := c.findImage(hash) if err != nil { if c.mtmd != nil { - embed, err = c.mtmd.NewEmbed(llamaContext, data) + chunks, err = c.mtmd.MultimodalTokenize(llamaContext, data) if err != nil { return nil, err } @@ -81,10 +81,10 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte) ([][]f return nil, errors.New("received image but vision model not loaded") } - c.addImage(hash, embed) + c.addImage(hash, chunks) } - return embed, nil + return chunks, nil } func (c *ImageContext) BatchSize(configuredBatchSize int) int { @@ -102,7 +102,7 @@ func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int { type imageCache struct { key uint64 - val [][]float32 + val []llama.MtmdChunk lastUsed time.Time } @@ -114,7 +114,7 @@ func (c *ImageContext) hashImage(image []byte) uint64 { var errImageNotFound = errors.New("image not found in cache") -func (c *ImageContext) findImage(hash uint64) ([][]float32, error) { +func (c *ImageContext) findImage(hash uint64) ([]llama.MtmdChunk, error) { for i := range c.images { if c.images[i].key == hash { slog.Debug("loading image embeddings from cache", "entry", i) @@ -126,7 +126,7 @@ func (c *ImageContext) findImage(hash uint64) ([][]float32, error) { return nil, errImageNotFound } -func (c *ImageContext) addImage(hash uint64, embed [][]float32) { +func (c *ImageContext) addImage(hash uint64, embed []llama.MtmdChunk) { best := time.Now() var bestImage int diff --git a/runner/llamarunner/image_test.go b/runner/llamarunner/image_test.go index 2e1efaec..f7d98a47 100644 --- a/runner/llamarunner/image_test.go +++ b/runner/llamarunner/image_test.go @@ -3,16 +3,18 @@ package llamarunner import ( "reflect" "testing" + + "github.com/ollama/ollama/llama" ) func TestImageCache(t *testing.T) { cache := ImageContext{images: make([]imageCache, 4)} - valA := [][]float32{{0.1, 0.2}, {0.3}} - valB := [][]float32{{0.4}, {0.5}, {0.6}} - valC := [][]float32{{0.7}} - valD := [][]float32{{0.8}} - valE := [][]float32{{0.9}} + valA := []llama.MtmdChunk{{Embed: []float32{0.1, 0.2}}, {Embed: []float32{0.3}}} + valB := []llama.MtmdChunk{{Embed: []float32{0.4}}, {Embed: []float32{0.5}}, {Embed: []float32{0.6}}} + valC := []llama.MtmdChunk{{Embed: []float32{0.7}}} + valD := []llama.MtmdChunk{{Embed: []float32{0.8}}} + valE := []llama.MtmdChunk{{Embed: []float32{0.9}}} // Empty cache result, err := cache.findImage(0x5adb61d31933a946) diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 791492bb..163aaa62 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -79,13 +79,16 @@ type Sequence struct { // true if an embedding are to be returned instead of text generation embeddingOnly bool + // shift if context window is exceeded + shift bool + doneReason llm.DoneReason // Metrics - startProcessingTime time.Time - startGenerationTime time.Time - numDecoded int - numPromptInputs int + processingDuration time.Duration + generationDuration time.Duration + numDecoded int + numPromptInputs int } type NewSequenceParams struct { @@ -94,13 +97,15 @@ type NewSequenceParams struct { numKeep int samplingParams *llama.SamplingParams embedding bool + shift bool + truncate bool } +var errorInputTooLong = errors.New("the input length exceeds the context length") + func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() - startTime := time.Now() - inputs, err := s.inputs(prompt, images) if err != nil { return nil, fmt.Errorf("failed to process inputs: %w", err) @@ -121,6 +126,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe if len(inputs) > s.cache.numCtx { discard := len(inputs) - s.cache.numCtx + if !params.truncate { + return nil, errorInputTooLong + } + newInputs := inputs[:params.numKeep] newInputs = append(newInputs, inputs[params.numKeep+discard:]...) @@ -142,18 +151,18 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe } return &Sequence{ - inputs: inputs, - numPromptInputs: len(inputs), - startProcessingTime: startTime, - numPredict: params.numPredict, - pendingResponses: make([]string, 0), - responses: make(chan string, 100), - quit: make(chan bool, 1), - embedding: make(chan []float32, 1), - samplingCtx: sc, - embeddingOnly: params.embedding, - stop: params.stop, - numKeep: params.numKeep, + inputs: inputs, + numPromptInputs: len(inputs), + numPredict: params.numPredict, + pendingResponses: make([]string, 0), + responses: make(chan string, 100), + quit: make(chan bool, 1), + embedding: make(chan []float32, 1), + samplingCtx: sc, + embeddingOnly: params.embedding, + stop: params.stop, + numKeep: params.numKeep, + shift: params.shift, }, nil } @@ -200,13 +209,19 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) return nil, fmt.Errorf("invalid image index: %d", n) } - embed, err := s.image.NewEmbed(s.lc, images[imageIndex].Data) + chunks, err := s.image.MultimodalTokenize(s.lc, images[imageIndex].Data) if err != nil { return nil, err } - for _, e := range embed { - inputs = append(inputs, input{embed: e}) + for _, c := range chunks { + if len(c.Embed) != 0 { + inputs = append(inputs, input{embed: c.Embed}) + } else { + for _, t := range c.Tokens { + inputs = append(inputs, input{token: t}) + } + } } } } @@ -388,6 +403,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) for i, input := range seq.inputs { if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx { if len(seq.pendingInputs) == 0 { + if !seq.shift { + s.removeSequence(seqIdx, llm.DoneReasonLength) + break + } + err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) if err != nil { var reprocess *ErrReprocessInputs @@ -438,8 +458,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) return nil } - err := s.lc.Decode(batch) - if err != nil { + t := time.Now() + if err := s.lc.Decode(batch); err != nil { return fmt.Errorf("failed to decode batch: %w", err) } @@ -459,9 +479,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) continue } - seq.numDecoded += 1 - if seq.numDecoded == 1 { - seq.startGenerationTime = time.Now() + s.lc.Synchronize() + seq.numDecoded++ + if seq.numDecoded > 1 { + seq.generationDuration += time.Since(t) + } else { + seq.processingDuration += time.Since(t) } // if done processing the prompt, generate an embedding and return @@ -583,8 +606,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { numKeep: req.Options.NumKeep, samplingParams: &samplingParams, embedding: false, + shift: req.Shift, + truncate: req.Truncate, }) if err != nil { + if errors.Is(err, errorInputTooLong) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) return } @@ -646,9 +675,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { Done: true, DoneReason: seq.doneReason, PromptEvalCount: seq.numPromptInputs, - PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + PromptEvalDuration: seq.processingDuration, EvalCount: seq.numDecoded, - EvalDuration: time.Since(seq.startGenerationTime), + EvalDuration: seq.generationDuration, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) } @@ -812,7 +841,7 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) { numGPU := 0 for i := range gpuIDs { for _, layers := range req.GPULayers { - if gpuIDs[i] == layers.ID { + if gpuIDs[i] == layers.DeviceID { tensorSplit[i] = float32(len(layers.Layers)) numGPU += len(layers.Layers) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 480cfc19..af212ece 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -28,6 +28,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" @@ -87,13 +88,17 @@ type Sequence struct { // true if an embedding are to be returned instead of text generation embeddingOnly bool + // shift if context window is exceeded + shift bool + doneReason llm.DoneReason // Metrics - startProcessingTime time.Time - startGenerationTime time.Time - numPredicted int - numPromptInputs int + startedAt, lastUpdatedAt time.Time + processingDuration time.Duration + samplingDuration time.Duration + numPredicted int + numPromptInputs int } type NewSequenceParams struct { @@ -102,13 +107,15 @@ type NewSequenceParams struct { numKeep int32 sampler sample.Sampler embedding bool + shift bool + truncate bool } +var errorInputTooLong = errors.New("the input length exceeds the context length") + func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() - startTime := time.Now() - inputs, ctxs, mmStore, err := s.inputs(prompt, images) if err != nil { return nil, fmt.Errorf("failed to process inputs: %w", err) @@ -125,6 +132,11 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe if int32(len(inputs)) > s.cache.numCtx { discard := int32(len(inputs)) - s.cache.numCtx + + if !params.truncate { + return nil, errorInputTooLong + } + promptStart := params.numKeep + discard // If we need to truncate in the middle of a unbreakable batch, remove the entire batch @@ -163,20 +175,20 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe // TODO(jessegross): Ingest cached history for grammar return &Sequence{ - ctxs: ctxs, - mmStore: mmStore, - inputs: inputs, - numPromptInputs: len(inputs), - startProcessingTime: startTime, - numPredict: params.numPredict, - pendingResponses: make([]string, 0), - responses: make(chan string, 100), - quit: make(chan bool, 1), - embedding: make(chan []float32, 1), - sampler: params.sampler, - embeddingOnly: params.embedding, - stop: params.stop, - numKeep: params.numKeep, + ctxs: ctxs, + mmStore: mmStore, + inputs: inputs, + numPromptInputs: len(inputs), + numPredict: params.numPredict, + pendingResponses: make([]string, 0), + responses: make(chan string, 100), + quit: make(chan bool, 1), + embedding: make(chan []float32, 1), + sampler: params.sampler, + embeddingOnly: params.embedding, + stop: params.stop, + numKeep: params.numKeep, + shift: params.shift, }, nil } @@ -322,9 +334,6 @@ type Server struct { // TODO (jmorganca): make this n_batch batchSize int - // Used to signal a hard failure during async processing which will panic the runner - hardErrCh chan error - // Simple counter used only for trace logging batches batchID int @@ -407,25 +416,25 @@ func (s *Server) run(ctx context.Context) { supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone - var activeBatch batchState + var previousBatch batchState for { select { case <-ctx.Done(): return - case err := <-s.hardErrCh: - panic(err) default: var err error - activeBatch, err = s.forwardBatch(activeBatch) + nextBatch, err := s.forwardBatch(previousBatch) if err != nil { panic(err) } if supportsAsync { - go s.computeBatch(activeBatch) + go s.computeBatch(nextBatch) } else { - s.computeBatch(activeBatch) + s.computeBatch(nextBatch) } + + previousBatch = nextBatch } } } @@ -521,6 +530,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er break } + if !seq.shift { + s.removeSequence(seqIdx, llm.DoneReasonLength) + nextBatch.seqs[seqIdx] = nil + break + } + err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) if err != nil { var reprocess *ErrReprocessInputs @@ -561,6 +576,13 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er seq.inputs = seq.inputs[len(seq.pendingInputs):] } + startedAt := time.Now() + for i := range nextBatch.seqs { + if nextBatch.seqs[i] != nil && nextBatch.seqs[i].startedAt.IsZero() { + nextBatch.seqs[i].startedAt = startedAt + } + } + if resumeSeq != -1 { s.nextSeq = resumeSeq } else { @@ -655,9 +677,7 @@ func (s *Server) computeBatch(activeBatch batchState) { // don't sample prompt processing if len(seq.inputs) != 0 { if !s.cache.enabled { - s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch") - s.mu.Unlock() - return + panic("caching disabled but unable to fit entire input in a batch") } continue } @@ -681,6 +701,7 @@ func (s *Server) computeBatch(activeBatch batchState) { activeBatch.modelOutput) outputs := activeBatch.modelOutput.Floats() + t := time.Now() logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id) @@ -693,8 +714,10 @@ func (s *Server) computeBatch(activeBatch batchState) { continue } + seq.lastUpdatedAt = t if seq.numPredicted == 1 { - seq.startGenerationTime = time.Now() + seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt) + seq.startedAt = seq.lastUpdatedAt } // if done processing the prompt, generate an embedding and return @@ -709,8 +732,7 @@ func (s *Server) computeBatch(activeBatch batchState) { logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches) token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) if err != nil { - s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err) - return + panic("failed to sample token") } nextBatchTokens[i].Token = token @@ -727,8 +749,7 @@ func (s *Server) computeBatch(activeBatch batchState) { piece, err := s.model.(model.TextProcessor).Decode([]int32{token}) if err != nil { - s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err) - return + panic("failed to decode token") } seq.pendingResponses = append(seq.pendingResponses, piece) @@ -773,6 +794,13 @@ func (s *Server) computeBatch(activeBatch batchState) { s.removeSequence(i, llm.DoneReasonConnectionClosed) } } + + samplingDuration := time.Since(t) + for i, seq := range s.seqs { + if seq != nil && nextBatchTokens[i] != nil { + s.seqs[i].samplingDuration += samplingDuration + } + } } func (s *Server) completion(w http.ResponseWriter, r *http.Request) { @@ -823,8 +851,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { numKeep: int32(req.Options.NumKeep), sampler: sampler, embedding: false, + shift: req.Shift, + truncate: req.Truncate, }) if err != nil { + if errors.Is(err, errorInputTooLong) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) return } @@ -886,9 +920,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { Done: true, DoneReason: seq.doneReason, PromptEvalCount: seq.numPromptInputs, - PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + PromptEvalDuration: seq.processingDuration, EvalCount: seq.numPredicted, - EvalDuration: time.Since(seq.startGenerationTime), + EvalDuration: seq.lastUpdatedAt.Sub(seq.startedAt) - seq.samplingDuration, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) } @@ -1235,6 +1269,52 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) { } } +// info is the handler called by the Ollama server to report information +// about the GPU devices in use by this runner +func (s *Server) info(w http.ResponseWriter, r *http.Request) { + s.loadMu.Lock() + defer s.loadMu.Unlock() + + w.Header().Set("Content-Type", "application/json") + + m := s.model + + if m == nil { + startLoad := time.Now() + + // Dummy load to get the backend wired up + f, err := os.CreateTemp("", "*.bin") + if err != nil { + http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError) + return + } + defer f.Close() + defer os.Remove(f.Name()) + + if err := ggml.WriteGGUF(f, ggml.KV{ + "general.architecture": "llama", + "tokenizer.ggml.model": "gpt2", + }, nil); err != nil { + http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError) + return + } + + m, err = model.New(f.Name(), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}}) + if err != nil { + http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError) + return + } + slog.Debug("dummy model load took", "duration", time.Since(startLoad)) + } + + startDevices := time.Now() + infos := m.Backend().BackendDevices() + slog.Debug("gathering device infos took", "duration", time.Since(startDevices)) + if err := json.NewEncoder(w).Encode(&infos); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + func Execute(args []string) error { fs := flag.NewFlagSet("runner", flag.ExitOnError) mpath := fs.String("model", "", "Path to model binary file") @@ -1257,7 +1337,6 @@ func Execute(args []string) error { server := &Server{ modelPath: *mpath, status: llm.ServerStatusLaunched, - hardErrCh: make(chan error, 1), } server.cond = sync.NewCond(&server.mu) @@ -1275,6 +1354,7 @@ func Execute(args []string) error { mux := http.NewServeMux() // TODO: support embeddings + mux.HandleFunc("GET /info", server.info) mux.HandleFunc("POST /load", server.load) mux.HandleFunc("POST /embedding", server.embeddings) mux.HandleFunc("POST /completion", server.completion) diff --git a/scripts/build_linux.sh b/scripts/build_linux.sh index 618722d1..8287c11c 100755 --- a/scripts/build_linux.sh +++ b/scripts/build_linux.sh @@ -13,12 +13,13 @@ set -eu . $(dirname $0)/env.sh mkdir -p dist +NOVULKAN=${NOVULKAN:-""} docker buildx build \ --output type=local,dest=./dist/ \ --platform=${PLATFORM} \ ${OLLAMA_COMMON_BUILD_ARGS} \ - --target archive \ + --target archive${NOVULKAN} \ -f Dockerfile \ . diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 37fe8796..f1cd3fea 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -165,7 +165,7 @@ function buildROCm() { $env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe" $env:HIP_PLATFORM="amd" $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" - & cmake --fresh --preset "ROCm 6" -G Ninja ` + & cmake --fresh --preset "ROCm 6" -G Ninja -DOLLAMA_RUNNER_DIR="rocm" ` -DCMAKE_C_COMPILER=clang ` -DCMAKE_CXX_COMPILER=clang++ ` -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` @@ -179,10 +179,23 @@ function buildROCm() { if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build --component "HIP" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue } } } +function buildVulkan(){ + if ($env:VULKAN_SDK) { + write-host "Building Vulkan backend libraries" + & cmake --fresh --preset Vulkan --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="vulkan" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset Vulkan --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component Vulkan --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } +} + function buildOllama() { mkdir -Force -path "${script:DIST_DIR}\" write-host "Building ollama CLI" @@ -279,7 +292,7 @@ function distZip() { write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip" Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") { - Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" + Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama\rocm" } } @@ -296,6 +309,7 @@ try { buildCUDA12 buildCUDA13 buildROCm + buildVulkan buildOllama buildApp gatherDependencies @@ -314,4 +328,4 @@ try { } finally { set-location $script:SRC_DIR $env:PKG_VERSION="" -} +} \ No newline at end of file diff --git a/server/images.go b/server/images.go index 9466b7fb..d3bd9ffa 100644 --- a/server/images.go +++ b/server/images.go @@ -105,12 +105,16 @@ func (m *Model) Capabilities() []model.Capability { builtinParser := parsers.ParserForName(m.Config.Parser) // Check for tools capability - if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) { + v, err := m.Template.Vars() + if err != nil { + slog.Warn("model template contains errors", "error", err) + } + if slices.Contains(v, "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) { capabilities = append(capabilities, model.CapabilityTools) } // Check for insert capability - if slices.Contains(m.Template.Vars(), "suffix") { + if slices.Contains(v, "suffix") { capabilities = append(capabilities, model.CapabilityInsert) } diff --git a/server/prompt.go b/server/prompt.go index 56bc6303..21759198 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -20,7 +20,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error) // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // latest message and 2) system messages -func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) { +func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) { var system []api.Message // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent @@ -59,7 +59,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. } } - if ctxLen > opts.NumCtx { + if truncate && ctxLen > opts.NumCtx { slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:])) break } else { diff --git a/server/prompt_test.go b/server/prompt_test.go index 659e6408..3bd62115 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -27,16 +27,18 @@ func TestChatPrompt(t *testing.T) { visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} cases := []struct { - name string - model Model - limit int - msgs []api.Message + name string + model Model + limit int + truncate bool + msgs []api.Message expect }{ { - name: "messages", - model: visionModel, - limit: 64, + name: "messages", + model: visionModel, + limit: 64, + truncate: true, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, {Role: "assistant", Content: "I-I'm a what?"}, @@ -47,9 +49,10 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate messages", - model: visionModel, - limit: 1, + name: "truncate messages", + model: visionModel, + limit: 1, + truncate: true, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, {Role: "assistant", Content: "I-I'm a what?"}, @@ -60,9 +63,10 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate messages with image", - model: visionModel, - limit: 64, + name: "truncate messages with image", + model: visionModel, + limit: 64, + truncate: true, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, {Role: "assistant", Content: "I-I'm a what?"}, @@ -76,9 +80,10 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate messages with images", - model: visionModel, - limit: 64, + name: "truncate messages with images", + model: visionModel, + limit: 64, + truncate: true, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}}, {Role: "assistant", Content: "I-I'm a what?"}, @@ -92,9 +97,10 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "messages with images", - model: visionModel, - limit: 2048, + name: "messages with images", + model: visionModel, + limit: 2048, + truncate: true, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}}, {Role: "assistant", Content: "I-I'm a what?"}, @@ -109,9 +115,10 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "message with image tag", - model: visionModel, - limit: 2048, + name: "message with image tag", + model: visionModel, + limit: 2048, + truncate: true, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}}, {Role: "assistant", Content: "I-I'm a what?"}, @@ -126,9 +133,10 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "messages with interleaved images", - model: visionModel, - limit: 2048, + name: "messages with interleaved images", + model: visionModel, + limit: 2048, + truncate: true, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, {Role: "user", Images: []api.ImageData{[]byte("something")}}, @@ -145,9 +153,10 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate message with interleaved images", - model: visionModel, - limit: 1024, + name: "truncate message with interleaved images", + model: visionModel, + limit: 1024, + truncate: true, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, {Role: "user", Images: []api.ImageData{[]byte("something")}}, @@ -163,9 +172,10 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "message with system prompt", - model: visionModel, - limit: 2048, + name: "message with system prompt", + model: visionModel, + limit: 2048, + truncate: true, msgs: []api.Message{ {Role: "system", Content: "You are the Test Who Lived."}, {Role: "user", Content: "You're a test, Harry!"}, @@ -177,9 +187,10 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "out of order system", - model: visionModel, - limit: 2048, + name: "out of order system", + model: visionModel, + limit: 2048, + truncate: true, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, {Role: "assistant", Content: "I-I'm a what?"}, @@ -191,9 +202,10 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "multiple images same prompt", - model: visionModel, - limit: 2048, + name: "multiple images same prompt", + model: visionModel, + limit: 2048, + truncate: true, msgs: []api.Message{ {Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}}, }, @@ -202,6 +214,20 @@ func TestChatPrompt(t *testing.T) { images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")}, }, }, + { + name: "no truncate with limit exceeded", + model: visionModel, + limit: 10, + truncate: false, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, + }, + expect: expect{ + prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ", + }, + }, } for _, tt := range cases { @@ -209,7 +235,7 @@ func TestChatPrompt(t *testing.T) { model := tt.model opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} think := false - prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}) + prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate) if tt.error == nil && err != nil { t.Fatal(err) } else if tt.error != nil && err != tt.error { diff --git a/server/routes.go b/server/routes.go index 21a1b2b3..80c00cb6 100644 --- a/server/routes.go +++ b/server/routes.go @@ -37,8 +37,9 @@ import ( "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/middleware" "github.com/ollama/ollama/model/parsers" - "github.com/ollama/ollama/openai" + "github.com/ollama/ollama/model/renderers" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" @@ -91,6 +92,9 @@ func init() { } gin.SetMode(mode) + + // Tell renderers to use [img] tags + renderers.RenderImgTags = true } var ( @@ -330,12 +334,18 @@ func (s *Server) GenerateHandler(c *gin.Context) { if req.Suffix != "" { caps = append(caps, model.CapabilityInsert) } - if req.Think != nil && req.Think.Bool() { + + modelCaps := m.Capabilities() + if slices.Contains(modelCaps, model.CapabilityThinking) { caps = append(caps, model.CapabilityThinking) - // TODO(drifkin): consider adding a warning if it's false and the model - // doesn't support thinking. It's not strictly required, but it can be a - // hint that the user is on an older qwen3/r1 model that doesn't have an - // updated template supporting thinking + if req.Think == nil { + req.Think = &api.ThinkValue{Value: true} + } + } else { + if req.Think != nil && req.Think.Bool() { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)}) + return + } } r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) @@ -397,12 +407,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { msgs = append(msgs, m.Messages...) } + userMsg := api.Message{Role: "user", Content: req.Prompt} for _, i := range images { - imgPrompt := "" - msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)}) + userMsg.Images = append(userMsg.Images, i.Data) } - - values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) + values.Messages = append(msgs, userMsg) } values.Think = req.Think != nil && req.Think.Bool() @@ -423,12 +432,31 @@ func (s *Server) GenerateHandler(c *gin.Context) { b.WriteString(s) } - if err := tmpl.Execute(&b, values); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + // check that we're in the `api/chat`-like flow, and if so, generate the + // prompt the same way + // TEMP(drifkin): we should really just detect the chat-like flow and call + // the real chat handler, but doing this as a stopgap to get renderer + // support for generate + if values.Messages != nil && values.Suffix == "" && req.Template == "" { + prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + // TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here + if req.Context != nil { + b.WriteString(prompt) + prompt = b.String() + } + } else { + // legacy flow + if err := tmpl.Execute(&b, values); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } - prompt = b.String() + prompt = b.String() + } } // If debug mode is enabled, return the rendered template instead of calling the model @@ -464,10 +492,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { var sb strings.Builder defer close(ch) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: req.Format, - Options: opts, + Prompt: prompt, + Images: images, + Format: req.Format, + Options: opts, + Shift: req.Shift == nil || *req.Shift, + Truncate: req.Truncate == nil || *req.Truncate, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ Model: req.Model, @@ -529,7 +559,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { ch <- res }); err != nil { - ch <- gin.H{"error": err.Error()} + var serr api.StatusError + if errors.As(err, &serr) { + ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode} + } else { + ch <- gin.H{"error": err.Error()} + } } }() @@ -549,7 +584,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { msg = "unexpected error format in response" } - c.JSON(http.StatusInternalServerError, gin.H{"error": msg}) + status, ok := t["status"].(int) + if !ok { + status = http.StatusInternalServerError + } + + c.JSON(status, gin.H{"error": msg}) return default: c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"}) @@ -1449,11 +1489,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/embeddings", s.EmbeddingsHandler) // Inference (OpenAI compatibility) - r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) - r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) - r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) - r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler) - r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler) + r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler) + r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler) + r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler) + r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler) + r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler) if rc != nil { // wrap old with new @@ -1557,8 +1597,8 @@ func Serve(ln net.Listener) error { // At startup we retrieve GPU information so we can get log messages before loading a model // This will log warnings to the log in case we have problems with detected GPUs - gpus := discover.GetGPUInfo() - gpus.LogDetails() + gpus := discover.GPUDevices(ctx, nil) + discover.LogDetails(gpus) var totalVRAM uint64 for _, gpu := range gpus { @@ -1614,6 +1654,30 @@ func streamResponse(c *gin.Context, ch chan any) { return false } + // errors are provided as a gin.H with an "error" field and + // an optional "status" field. For errors that are streamed + // before any content, we need to set the status code and + // content type for the error. + if h, ok := val.(gin.H); ok { + if e, ok := h["error"].(string); ok { + status, ok := h["status"].(int) + if !ok { + status = http.StatusInternalServerError + } + + if !c.Writer.Written() { + c.Header("Content-Type", "application/json") + c.JSON(status, gin.H{"error": e}) + } else { + if err := json.NewEncoder(c.Writer).Encode(gin.H{"error": e}); err != nil { + slog.Error("streamResponse failed to encode json error", "error", err) + } + } + + return false + } + } + bts, err := json.Marshal(val) if err != nil { slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err)) @@ -1777,7 +1841,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } // expire the runner - if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { + if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { s.sched.expireRunner(m) c.JSON(http.StatusOK, api.ChatResponse{ @@ -1871,8 +1935,18 @@ func (s *Server) ChatHandler(c *gin.Context) { if len(req.Tools) > 0 { caps = append(caps, model.CapabilityTools) } - if req.Think != nil && req.Think.Bool() { + + modelCaps := m.Capabilities() + if slices.Contains(modelCaps, model.CapabilityThinking) { caps = append(caps, model.CapabilityThinking) + if req.Think == nil { + req.Think = &api.ThinkValue{Value: true} + } + } else { + if req.Think != nil && req.Think.Bool() { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)}) + return + } } r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) @@ -1923,7 +1997,8 @@ func (s *Server) ChatHandler(c *gin.Context) { } } - prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think) + truncate := req.Truncate == nil || *req.Truncate + prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate) if err != nil { slog.Error("chat prompt error", "error", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -1967,88 +2042,174 @@ func (s *Server) ChatHandler(c *gin.Context) { toolParser = tools.NewParser(m.Template.Template, req.Tools) } + type structuredOutputsState int + const ( + structuredOutputsState_None structuredOutputsState = iota + structuredOutputsState_ReadyToApply + structuredOutputsState_Applying + ) + ch := make(chan any) go func() { defer close(ch) - if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: req.Format, - Options: opts, - }, func(r llm.CompletionResponse) { - res := api.ChatResponse{ - Model: req.Model, - CreatedAt: time.Now().UTC(), - Message: api.Message{Role: "assistant", Content: r.Content}, - Done: r.Done, - Metrics: api.Metrics{ - PromptEvalCount: r.PromptEvalCount, - PromptEvalDuration: r.PromptEvalDuration, - EvalCount: r.EvalCount, - EvalDuration: r.EvalDuration, - }, - } - if r.Done { - res.DoneReason = r.DoneReason.String() - res.TotalDuration = time.Since(checkpointStart) - res.LoadDuration = checkpointLoaded.Sub(checkpointStart) + structuredOutputsState := structuredOutputsState_None + + for { + var tb strings.Builder + + currentFormat := req.Format + // structured outputs via double request is enabled when: + // 1. the model supports the thinking capability and + // 2. it uses a built-in parser or our generic thinking parser + + // Note that the current approach does not work for (potential future) + // non-thinking models that emit anything before actual content. This + // current approach uses the transition from parsed thinking content to + // parsed non-thinking content as the signal to turn constraining on + + if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) { + currentFormat = nil } - if builtinParser != nil { - slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content) - - content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done) - if err != nil { - ch <- gin.H{"error": err.Error()} - return + // sets up new context given parent context per request + ctx, cancel := context.WithCancel(c.Request.Context()) + err := r.Completion(ctx, llm.CompletionRequest{ + Prompt: prompt, + Images: images, + Format: currentFormat, + Options: opts, + Shift: req.Shift == nil || *req.Shift, + Truncate: truncate, + }, func(r llm.CompletionResponse) { + res := api.ChatResponse{ + Model: req.Model, + CreatedAt: time.Now().UTC(), + Message: api.Message{Role: "assistant", Content: r.Content}, + Done: r.Done, + Metrics: api.Metrics{ + PromptEvalCount: r.PromptEvalCount, + PromptEvalDuration: r.PromptEvalDuration, + EvalCount: r.EvalCount, + EvalDuration: r.EvalDuration, + }, + } + if r.Done { + res.DoneReason = r.DoneReason.String() + res.TotalDuration = time.Since(checkpointStart) + res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - res.Message.Content = content - res.Message.Thinking = thinking - res.Message.ToolCalls = toolCalls + if builtinParser != nil { + slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content) - if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done { - slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done) - ch <- res - } else { - slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser) - } + content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done) + if err != nil { + ch <- gin.H{"error": err.Error()} + return + } - return - } - - if thinkingState != nil { - thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content) - if thinkingContent == "" && remainingContent == "" && !r.Done { - // need to accumulate more to decide what to send - return - } - res.Message.Content = remainingContent - res.Message.Thinking = thinkingContent - } - - if len(req.Tools) > 0 { - toolCalls, content := toolParser.Add(res.Message.Content) - if len(content) > 0 { res.Message.Content = content - } else if len(toolCalls) > 0 { + res.Message.Thinking = thinking res.Message.ToolCalls = toolCalls - res.Message.Content = "" - } else if res.Message.Thinking != "" { - // don't return - } else { - if r.Done { - res.Message.Content = toolParser.Content() + + tb.WriteString(thinking) + // we are now receiving content from the model - we should start applying structured outputs + if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" { + structuredOutputsState = structuredOutputsState_ReadyToApply + cancel() + return + } + + if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done { + slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done) ch <- res + } else { + slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser) + } + return + } + + if thinkingState != nil { + thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content) + if thinkingContent == "" && remainingContent == "" && !r.Done { + // need to accumulate more to decide what to send + return + } + res.Message.Thinking = thinkingContent + tb.WriteString(thinkingContent) + // emit the collected thinking text before restarting with structured outputs and clear unstructured content + // to avoid leaking mixed tokens like "Hello" + if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" { + structuredOutputsState = structuredOutputsState_ReadyToApply + res.Message.Content = "" + ch <- res + cancel() + return + } + res.Message.Content = remainingContent + } + + if len(req.Tools) > 0 { + toolCalls, content := toolParser.Add(res.Message.Content) + if len(content) > 0 { + res.Message.Content = content + } else if len(toolCalls) > 0 { + res.Message.ToolCalls = toolCalls + res.Message.Content = "" + } else if res.Message.Thinking != "" { + // don't return + } else { + if r.Done { + res.Message.Content = toolParser.Content() + ch <- res + } + return + } + } + + ch <- res + }) + if err != nil { + if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil { + // only ignores error if it's a context cancellation due to setting structured outputs + } else { + var serr api.StatusError + if errors.As(err, &serr) { + ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode} + } else { + ch <- gin.H{"error": err.Error()} } return } } - ch <- res - }); err != nil { - ch <- gin.H{"error": err.Error()} + // ignored structured outputs cancellation falls through to here, start a new request with the structured outputs and updated prompt. use the + if structuredOutputsState == structuredOutputsState_ReadyToApply { + structuredOutputsState = structuredOutputsState_Applying + msg := api.Message{ + Role: "assistant", + Thinking: tb.String(), + } + + msgs = append(msgs, msg) + prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate) + if err != nil { + slog.Error("chat prompt error applying structured outputs", "error", err) + ch <- gin.H{"error": err.Error()} + return + } + // force constraining by terminating thinking header, the parser is already at this state + // when the last message is thinking, the rendered for gpt-oss cannot disambiguate between having the + // model continue thinking or ending thinking and outputting the final message. + // TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs. + if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") { + prompt += "<|end|><|start|>assistant<|channel|>final<|message|>" + } + continue + } + + break } }() @@ -2072,7 +2233,12 @@ func (s *Server) ChatHandler(c *gin.Context) { msg = "unexpected error format in response" } - c.JSON(http.StatusInternalServerError, gin.H{"error": msg}) + status, ok := t["status"].(int) + if !ok { + status = http.StatusInternalServerError + } + + c.JSON(status, gin.H{"error": msg}) return default: c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"}) diff --git a/server/routes_debug_test.go b/server/routes_debug_test.go index 6507284e..466951a1 100644 --- a/server/routes_debug_test.go +++ b/server/routes_debug_test.go @@ -30,15 +30,15 @@ func TestGenerateDebugRenderOnly(t *testing.T) { s := Server{ sched: &Scheduler{ - pendingReqCh: make(chan *LlmRequest, 1), - finishedReqCh: make(chan *LlmRequest, 1), - expiredCh: make(chan *runnerRef, 1), - unloadedCh: make(chan any, 1), - loaded: make(map[string]*runnerRef), - newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, - reschedDelay: 250 * time.Millisecond, + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, + waitForRecovery: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading time.Sleep(time.Millisecond) @@ -146,7 +146,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) { DebugRenderOnly: true, }, expectDebug: true, - expectTemplate: "[img-0]\n\nDescribe this image", + expectTemplate: "[img-0]Describe this image", expectNumImages: 1, }, { @@ -223,15 +223,15 @@ func TestChatDebugRenderOnly(t *testing.T) { s := Server{ sched: &Scheduler{ - pendingReqCh: make(chan *LlmRequest, 1), - finishedReqCh: make(chan *LlmRequest, 1), - expiredCh: make(chan *runnerRef, 1), - unloadedCh: make(chan any, 1), - loaded: make(map[string]*runnerRef), - newServerFn: newMockServer(&mock), - getGpuFn: discover.GetGPUInfo, - getCpuFn: discover.GetCPUInfo, - reschedDelay: 250 * time.Millisecond, + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, + waitForRecovery: 250 * time.Millisecond, loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading time.Sleep(time.Millisecond) diff --git a/server/routes_generate_renderer_test.go b/server/routes_generate_renderer_test.go new file mode 100644 index 00000000..ea18b1e5 --- /dev/null +++ b/server/routes_generate_renderer_test.go @@ -0,0 +1,313 @@ +package server + +import ( + "bytes" + "encoding/json" + "net/http" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/discover" + "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/llm" +) + +// TestGenerateWithBuiltinRenderer tests that api/generate uses built-in renderers +// when in chat-like flow (messages present, no suffix, no template) +func TestGenerateWithBuiltinRenderer(t *testing.T) { + gin.SetMode(gin.TestMode) + + mock := mockRunner{ + CompletionResponse: llm.CompletionResponse{ + Done: true, + DoneReason: llm.DoneReasonStop, + PromptEvalCount: 1, + PromptEvalDuration: 1, + EvalCount: 1, + EvalDuration: 1, + }, + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: getGpuFn, + getCpuFn: getCpuFn, + waitForRecovery: 250 * time.Millisecond, + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { + time.Sleep(time.Millisecond) + req.successCh <- &runnerRef{ + llama: &mock, + } + return false + }, + }, + } + + go s.sched.Run(t.Context()) + + // Create a model with a built-in renderer (qwen3-coder) + _, digest := createBinFile(t, ggml.KV{ + "general.architecture": "qwen3", + "qwen3.block_count": uint32(1), + "qwen3.context_length": uint32(8192), + "qwen3.embedding_length": uint32(4096), + "qwen3.attention.head_count": uint32(32), + "qwen3.attention.head_count_kv": uint32(8), + "tokenizer.ggml.tokens": []string{""}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []*ggml.Tensor{ + {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + }) + + // Create a model with the qwen3-coder renderer + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "test-renderer", + Files: map[string]string{"file.gguf": digest}, + Renderer: "qwen3-coder", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + mock.CompletionResponse.Content = "Hi!" + + t.Run("chat-like flow uses renderer", func(t *testing.T) { + // Test that when using messages (chat-like flow), the built-in renderer is used + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-renderer", + Prompt: "Write a hello world function", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + // The qwen3-coder renderer produces output with <|im_start|> and <|im_end|> tags + // When messages are built internally from prompt, it should use the renderer + if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") { + t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt) + } + + if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_end|>") { + t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt) + } + }) + + t.Run("chat-like flow with system message uses renderer", func(t *testing.T) { + // Test that system messages work with the renderer + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-renderer", + Prompt: "Write a hello world function", + System: "You are a helpful coding assistant.", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + // Should contain the system message and use renderer format + if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>system") { + t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.CompletionRequest.Prompt) + } + + if !strings.Contains(mock.CompletionRequest.Prompt, "You are a helpful coding assistant.") { + t.Errorf("expected prompt to contain system message content, got: %s", mock.CompletionRequest.Prompt) + } + }) + + t.Run("custom template bypasses renderer", func(t *testing.T) { + // Test that providing a custom template uses the legacy flow + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-renderer", + Prompt: "Write a hello world function", + Template: "{{ .Prompt }}", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + // Should NOT use the renderer format when custom template is provided + if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") { + t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.CompletionRequest.Prompt) + } + + // Should just be the raw prompt from the template + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Write a hello world function"); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + // Create a model with suffix support for the next test + w = createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "test-suffix-renderer", + From: "test-renderer", + Template: `{{- if .Suffix }}
 {{ .Prompt }} {{ .Suffix }} 
+{{- else }}{{ .Prompt }}
+{{- end }}`,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("suffix bypasses renderer", func(t *testing.T) {
+		// Test that providing a suffix uses the legacy flow
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix-renderer",
+			Prompt: "def add(",
+			Suffix: "    return c",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		// Should NOT use the renderer format when suffix is provided
+		if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
+			t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.CompletionRequest.Prompt)
+		}
+
+		// Should use the suffix template format
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "
 def add(     return c "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+}
+
+// TestGenerateWithDebugRenderOnly tests that debug_render_only works with built-in renderers
+func TestGenerateWithDebugRenderOnly(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mock := mockRunner{
+		CompletionResponse: llm.CompletionResponse{
+			Done:               true,
+			DoneReason:         llm.DoneReasonStop,
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:    make(chan *LlmRequest, 1),
+			finishedReqCh:   make(chan *LlmRequest, 1),
+			expiredCh:       make(chan *runnerRef, 1),
+			unloadedCh:      make(chan any, 1),
+			loaded:          make(map[string]*runnerRef),
+			newServerFn:     newMockServer(&mock),
+			getGpuFn:        getGpuFn,
+			getCpuFn:        getCpuFn,
+			waitForRecovery: 250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
+				time.Sleep(time.Millisecond)
+				req.successCh <- &runnerRef{
+					llama: &mock,
+				}
+				return false
+			},
+		},
+	}
+
+	go s.sched.Run(t.Context())
+
+	// Create a model with a built-in renderer
+	_, digest := createBinFile(t, ggml.KV{
+		"general.architecture":          "qwen3",
+		"qwen3.block_count":             uint32(1),
+		"qwen3.context_length":          uint32(8192),
+		"qwen3.embedding_length":        uint32(4096),
+		"qwen3.attention.head_count":    uint32(32),
+		"qwen3.attention.head_count_kv": uint32(8),
+		"tokenizer.ggml.tokens":         []string{""},
+		"tokenizer.ggml.scores":         []float32{0},
+		"tokenizer.ggml.token_type":     []int32{0},
+	}, []*ggml.Tensor{
+		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+	})
+
+	w := createRequest(t, s.CreateHandler, api.CreateRequest{
+		Model:    "test-debug-renderer",
+		Files:    map[string]string{"file.gguf": digest},
+		Renderer: "qwen3-coder",
+		Stream:   &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("debug_render_only with renderer", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:           "test-debug-renderer",
+			Prompt:          "Write a hello world function",
+			System:          "You are a coding assistant",
+			DebugRenderOnly: true,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		var resp api.GenerateResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		if resp.DebugInfo == nil {
+			t.Fatalf("expected debug info, got nil")
+		}
+
+		// Verify that the rendered template uses the built-in renderer
+		if !strings.Contains(resp.DebugInfo.RenderedTemplate, "<|im_start|>") {
+			t.Errorf("expected rendered template to use qwen3-coder renderer format, got: %s", resp.DebugInfo.RenderedTemplate)
+		}
+
+		if !strings.Contains(resp.DebugInfo.RenderedTemplate, "You are a coding assistant") {
+			t.Errorf("expected rendered template to contain system message, got: %s", resp.DebugInfo.RenderedTemplate)
+		}
+
+		if !strings.Contains(resp.DebugInfo.RenderedTemplate, "Write a hello world function") {
+			t.Errorf("expected rendered template to contain prompt, got: %s", resp.DebugInfo.RenderedTemplate)
+		}
+	})
+}
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index a3b83fc1..75d4f012 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -68,15 +68,15 @@ func TestGenerateChat(t *testing.T) {
 
 	s := Server{
 		sched: &Scheduler{
-			pendingReqCh:  make(chan *LlmRequest, 1),
-			finishedReqCh: make(chan *LlmRequest, 1),
-			expiredCh:     make(chan *runnerRef, 1),
-			unloadedCh:    make(chan any, 1),
-			loaded:        make(map[string]*runnerRef),
-			newServerFn:   newMockServer(&mock),
-			getGpuFn:      discover.GetGPUInfo,
-			getCpuFn:      discover.GetCPUInfo,
-			reschedDelay:  250 * time.Millisecond,
+			pendingReqCh:    make(chan *LlmRequest, 1),
+			finishedReqCh:   make(chan *LlmRequest, 1),
+			expiredCh:       make(chan *runnerRef, 1),
+			unloadedCh:      make(chan any, 1),
+			loaded:          make(map[string]*runnerRef),
+			newServerFn:     newMockServer(&mock),
+			getGpuFn:        getGpuFn,
+			getCpuFn:        getCpuFn,
+			waitForRecovery: 250 * time.Millisecond,
 			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
 				// add small delay to simulate loading
 				time.Sleep(time.Millisecond)
@@ -158,11 +158,26 @@ func TestGenerateChat(t *testing.T) {
 			t.Errorf("expected status 400, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
+		if diff := cmp.Diff(w.Body.String(), `{"error":"\"test\" does not support thinking"}`); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
 
+	t.Run("model can't think but think set false", func(t *testing.T) {
+		think := false
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Think: &api.ThinkValue{Value: think},
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+	})
+
 	t.Run("missing model", func(t *testing.T) {
 		w := createRequest(t, s.ChatHandler, api.ChatRequest{})
 		if w.Code != http.StatusBadRequest {
@@ -594,6 +609,58 @@ func TestGenerateChat(t *testing.T) {
 			t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
 		}
 	})
+
+	t.Run("status error non-streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusServiceUnavailable,
+				Status:       "Service Unavailable",
+				ErrorMessage: "model is overloaded",
+			}
+		}
+
+		stream := false
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusServiceUnavailable {
+			t.Errorf("expected status 503, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("status error streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusTooManyRequests,
+				Status:       "Too Many Requests",
+				ErrorMessage: "rate limit exceeded",
+			}
+		}
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+		})
+
+		if w.Code != http.StatusTooManyRequests {
+			t.Errorf("expected status 429, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
 }
 
 func TestGenerate(t *testing.T) {
@@ -612,15 +679,15 @@ func TestGenerate(t *testing.T) {
 
 	s := Server{
 		sched: &Scheduler{
-			pendingReqCh:  make(chan *LlmRequest, 1),
-			finishedReqCh: make(chan *LlmRequest, 1),
-			expiredCh:     make(chan *runnerRef, 1),
-			unloadedCh:    make(chan any, 1),
-			loaded:        make(map[string]*runnerRef),
-			newServerFn:   newMockServer(&mock),
-			getGpuFn:      discover.GetGPUInfo,
-			getCpuFn:      discover.GetCPUInfo,
-			reschedDelay:  250 * time.Millisecond,
+			pendingReqCh:    make(chan *LlmRequest, 1),
+			finishedReqCh:   make(chan *LlmRequest, 1),
+			expiredCh:       make(chan *runnerRef, 1),
+			unloadedCh:      make(chan any, 1),
+			loaded:          make(map[string]*runnerRef),
+			newServerFn:     newMockServer(&mock),
+			getGpuFn:        getGpuFn,
+			getCpuFn:        getCpuFn,
+			waitForRecovery: 250 * time.Millisecond,
 			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
 				// add small delay to simulate loading
 				time.Sleep(time.Millisecond)
@@ -968,6 +1035,55 @@ func TestGenerate(t *testing.T) {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
+
+	t.Run("status error non-streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusServiceUnavailable,
+				Status:       "Service Unavailable",
+				ErrorMessage: "model is overloaded",
+			}
+		}
+
+		streamRequest := false
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Hello!",
+			Stream: &streamRequest,
+		})
+
+		if w.Code != http.StatusServiceUnavailable {
+			t.Errorf("expected status 503, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("status error streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusTooManyRequests,
+				Status:       "Too Many Requests",
+				ErrorMessage: "rate limit exceeded",
+			}
+		}
+
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Hello!",
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusTooManyRequests {
+			t.Errorf("expected status 429, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
 }
 
 func TestChatWithPromptEndingInThinkTag(t *testing.T) {
@@ -988,15 +1104,15 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
 
 		s := &Server{
 			sched: &Scheduler{
-				pendingReqCh:  make(chan *LlmRequest, 1),
-				finishedReqCh: make(chan *LlmRequest, 1),
-				expiredCh:     make(chan *runnerRef, 1),
-				unloadedCh:    make(chan any, 1),
-				loaded:        make(map[string]*runnerRef),
-				newServerFn:   newMockServer(mock),
-				getGpuFn:      discover.GetGPUInfo,
-				getCpuFn:      discover.GetCPUInfo,
-				reschedDelay:  250 * time.Millisecond,
+				pendingReqCh:    make(chan *LlmRequest, 1),
+				finishedReqCh:   make(chan *LlmRequest, 1),
+				expiredCh:       make(chan *runnerRef, 1),
+				unloadedCh:      make(chan any, 1),
+				loaded:          make(map[string]*runnerRef),
+				newServerFn:     newMockServer(mock),
+				getGpuFn:        getGpuFn,
+				getCpuFn:        getCpuFn,
+				waitForRecovery: 250 * time.Millisecond,
 				loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
 					time.Sleep(time.Millisecond)
 					req.successCh <- &runnerRef{llama: mock}
@@ -1120,13 +1236,6 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
 		"The answer is 4.",
 		true)
 
-	testChatRequest(t, "thinking disabled but template still adds think tag",
-		"Simple question",
-		" My thoughts  The answer.",
-		"",
-		" My thoughts  The answer.",
-		false)
-
 	// Test streaming response with template-added 
 	t.Run("streaming with thinking", func(t *testing.T) {
 		var wg sync.WaitGroup
@@ -1198,4 +1307,238 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
 			t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
 		}
 	})
+
+	t.Run("structured outputs restart non-stream", func(t *testing.T) {
+		var (
+			requestsMu sync.Mutex
+			requests   []llm.CompletionRequest
+			wg         sync.WaitGroup
+		)
+
+		wg.Add(2)
+
+		format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
+
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			defer wg.Done()
+
+			requestsMu.Lock()
+			requests = append(requests, r)
+			callNum := len(requests)
+			requestsMu.Unlock()
+
+			switch callNum {
+			case 1:
+				fn(llm.CompletionResponse{
+					Content:            " I am thinking through this problem.  {\"answer\":\"42\"}",
+					Done:               false,
+					PromptEvalCount:    1,
+					PromptEvalDuration: 1,
+				})
+
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				case <-time.After(time.Second):
+					t.Fatalf("timeout waiting for structured outputs cancellation")
+					return nil
+				}
+			case 2:
+				fn(llm.CompletionResponse{
+					Content:            `{"answer":"42"}`,
+					Done:               true,
+					DoneReason:         llm.DoneReasonStop,
+					PromptEvalCount:    1,
+					PromptEvalDuration: 1,
+					EvalCount:          1,
+					EvalDuration:       1,
+				})
+				return nil
+			default:
+				t.Fatalf("unexpected number of completion calls: %d", callNum)
+				return nil
+			}
+		}
+
+		think := true
+		streamRequest := false
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model:    "test-thinking",
+			Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
+			Think:    &api.ThinkValue{Value: think},
+			Stream:   &streamRequest,
+			Format:   format,
+		})
+
+		wg.Wait()
+		mock.CompletionFn = nil
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		if len(requests) != 2 {
+			t.Fatalf("expected two completion calls, got %d", len(requests))
+		}
+
+		if requests[0].Format != nil {
+			t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
+		}
+
+		if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
+			t.Errorf("expected second completion format to match original format")
+		}
+
+		var resp api.ChatResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		if resp.Message.Thinking != "I am thinking through this problem. " {
+			t.Errorf("expected thinking %q, got %q", "I am thinking through this problem. ", resp.Message.Thinking)
+		}
+
+		if resp.Message.Content != `{"answer":"42"}` {
+			t.Errorf("expected content %q, got %q", `{"answer":"42"}`, resp.Message.Content)
+		}
+
+		if !resp.Done {
+			t.Errorf("expected response to be done")
+		}
+
+		if resp.DoneReason != "stop" {
+			t.Errorf("expected done reason stop, got %s", resp.DoneReason)
+		}
+	})
+
+	t.Run("structured outputs restart streaming", func(t *testing.T) {
+		var (
+			requestsMu sync.Mutex
+			requests   []llm.CompletionRequest
+			wg         sync.WaitGroup
+		)
+
+		wg.Add(2)
+
+		format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
+
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			defer wg.Done()
+
+			requestsMu.Lock()
+			requests = append(requests, r)
+			callNum := len(requests)
+			requestsMu.Unlock()
+
+			switch callNum {
+			case 1:
+				fn(llm.CompletionResponse{
+					Content:            " I am thinking through this problem.  {\"answer\":\"42\"}",
+					Done:               false,
+					PromptEvalCount:    1,
+					PromptEvalDuration: 1,
+				})
+
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				case <-time.After(time.Second):
+					t.Fatalf("timeout waiting for structured outputs cancellation")
+					return nil
+				}
+			case 2:
+				fn(llm.CompletionResponse{
+					Content:            `{"answer":"42"}`,
+					Done:               true,
+					DoneReason:         llm.DoneReasonStop,
+					PromptEvalCount:    1,
+					PromptEvalDuration: 1,
+					EvalCount:          1,
+					EvalDuration:       1,
+				})
+				return nil
+			default:
+				t.Fatalf("unexpected number of completion calls: %d", callNum)
+				return nil
+			}
+		}
+
+		think := true
+		streamRequest := true
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model:    "test-thinking",
+			Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
+			Think:    &api.ThinkValue{Value: think},
+			Stream:   &streamRequest,
+			Format:   format,
+		})
+
+		wg.Wait()
+		mock.CompletionFn = nil
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		if len(requests) != 2 {
+			t.Fatalf("expected two completion calls, got %d", len(requests))
+		}
+
+		if requests[0].Format != nil {
+			t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
+		}
+
+		if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
+			t.Errorf("expected second completion format to match original format")
+		}
+
+		decoder := json.NewDecoder(w.Body)
+		var events []api.ChatResponse
+		for {
+			var event api.ChatResponse
+			if err := decoder.Decode(&event); err == io.EOF {
+				break
+			} else if err != nil {
+				t.Fatal(err)
+			}
+			events = append(events, event)
+			if event.Done {
+				break
+			}
+		}
+
+		if len(events) < 2 {
+			t.Fatalf("expected at least two streaming events, got %d", len(events))
+		}
+
+		first := events[0]
+		if first.Message.Thinking != "I am thinking through this problem. " {
+			t.Errorf("expected first event thinking %q, got %q", "I am thinking through this problem. ", first.Message.Thinking)
+		}
+
+		if first.Message.Content != "" {
+			t.Errorf("expected first event content to be empty, got %q", first.Message.Content)
+		}
+
+		if first.Done {
+			t.Error("expected first event to be non-terminal")
+		}
+
+		last := events[len(events)-1]
+		if last.Message.Thinking != "" {
+			t.Errorf("expected final event thinking to be empty, got %q", last.Message.Thinking)
+		}
+
+		if last.Message.Content != `{"answer":"42"}` {
+			t.Errorf("expected final event content %q, got %q", `{"answer":"42"}`, last.Message.Content)
+		}
+
+		if !last.Done {
+			t.Error("expected final event to be done")
+		}
+
+		if last.DoneReason != "stop" {
+			t.Errorf("expected final done reason stop, got %s", last.DoneReason)
+		}
+	})
 }
diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go
index b1ede4e3..caf2cf6d 100644
--- a/server/routes_harmony_streaming_test.go
+++ b/server/routes_harmony_streaming_test.go
@@ -268,15 +268,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 
 			s := Server{
 				sched: &Scheduler{
-					pendingReqCh:  make(chan *LlmRequest, 1),
-					finishedReqCh: make(chan *LlmRequest, 1),
-					expiredCh:     make(chan *runnerRef, 1),
-					unloadedCh:    make(chan any, 1),
-					loaded:        make(map[string]*runnerRef),
-					newServerFn:   newMockServer(&mock),
-					getGpuFn:      discover.GetGPUInfo,
-					getCpuFn:      discover.GetCPUInfo,
-					reschedDelay:  100 * time.Millisecond,
+					pendingReqCh:    make(chan *LlmRequest, 1),
+					finishedReqCh:   make(chan *LlmRequest, 1),
+					expiredCh:       make(chan *runnerRef, 1),
+					unloadedCh:      make(chan any, 1),
+					loaded:          make(map[string]*runnerRef),
+					newServerFn:     newMockServer(&mock),
+					getGpuFn:        getGpuFn,
+					getCpuFn:        getCpuFn,
+					waitForRecovery: 100 * time.Millisecond,
 					loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
 						req.successCh <- &runnerRef{
 							llama: &mock,
@@ -419,15 +419,15 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
 
 	s := Server{
 		sched: &Scheduler{
-			pendingReqCh:  make(chan *LlmRequest, 1),
-			finishedReqCh: make(chan *LlmRequest, 1),
-			expiredCh:     make(chan *runnerRef, 1),
-			unloadedCh:    make(chan any, 1),
-			loaded:        make(map[string]*runnerRef),
-			newServerFn:   newMockServer(&mock),
-			getGpuFn:      discover.GetGPUInfo,
-			getCpuFn:      discover.GetCPUInfo,
-			reschedDelay:  100 * time.Millisecond,
+			pendingReqCh:    make(chan *LlmRequest, 1),
+			finishedReqCh:   make(chan *LlmRequest, 1),
+			expiredCh:       make(chan *runnerRef, 1),
+			unloadedCh:      make(chan any, 1),
+			loaded:          make(map[string]*runnerRef),
+			newServerFn:     newMockServer(&mock),
+			getGpuFn:        getGpuFn,
+			getCpuFn:        getCpuFn,
+			waitForRecovery: 100 * time.Millisecond,
 			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
 				req.successCh <- &runnerRef{
 					llama: &mock,
@@ -601,15 +601,15 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
 
 			s := Server{
 				sched: &Scheduler{
-					pendingReqCh:  make(chan *LlmRequest, 1),
-					finishedReqCh: make(chan *LlmRequest, 1),
-					expiredCh:     make(chan *runnerRef, 1),
-					unloadedCh:    make(chan any, 1),
-					loaded:        make(map[string]*runnerRef),
-					newServerFn:   newMockServer(&mock),
-					getGpuFn:      discover.GetGPUInfo,
-					getCpuFn:      discover.GetCPUInfo,
-					reschedDelay:  250 * time.Millisecond,
+					pendingReqCh:    make(chan *LlmRequest, 1),
+					finishedReqCh:   make(chan *LlmRequest, 1),
+					expiredCh:       make(chan *runnerRef, 1),
+					unloadedCh:      make(chan any, 1),
+					loaded:          make(map[string]*runnerRef),
+					newServerFn:     newMockServer(&mock),
+					getGpuFn:        getGpuFn,
+					getCpuFn:        getCpuFn,
+					waitForRecovery: 250 * time.Millisecond,
 					loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
 						req.successCh <- &runnerRef{
 							llama: &mock,
diff --git a/server/sched.go b/server/sched.go
index 74aa406a..7c639953 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -21,6 +21,8 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/logutil"
+	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/types/model"
 )
 
@@ -50,11 +52,13 @@ type Scheduler struct {
 	activeLoading llm.LlamaServer
 	loaded        map[string]*runnerRef
 
-	loadFn       func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool
-	newServerFn  func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
-	getGpuFn     func() discover.GpuInfoList
-	getCpuFn     func() discover.GpuInfoList
-	reschedDelay time.Duration
+	loadFn      func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool
+	newServerFn func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
+	getGpuFn    func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList
+	getCpuFn    func() discover.GpuInfo
+
+	// waitForRecovery sets the limit for how long to wait for memory usage to recover after unload before scheduling the next model
+	waitForRecovery time.Duration
 }
 
 // Default automatic value for number of models we allow per GPU
@@ -67,15 +71,15 @@ var ErrMaxQueue = errors.New("server busy, please try again.  maximum pending re
 func InitScheduler(ctx context.Context) *Scheduler {
 	maxQueue := envconfig.MaxQueue()
 	sched := &Scheduler{
-		pendingReqCh:  make(chan *LlmRequest, maxQueue),
-		finishedReqCh: make(chan *LlmRequest, maxQueue),
-		expiredCh:     make(chan *runnerRef, maxQueue),
-		unloadedCh:    make(chan any, maxQueue),
-		loaded:        make(map[string]*runnerRef),
-		newServerFn:   llm.NewLlamaServer,
-		getGpuFn:      discover.GetGPUInfo,
-		getCpuFn:      discover.GetCPUInfo,
-		reschedDelay:  250 * time.Millisecond,
+		pendingReqCh:    make(chan *LlmRequest, maxQueue),
+		finishedReqCh:   make(chan *LlmRequest, maxQueue),
+		expiredCh:       make(chan *runnerRef, maxQueue),
+		unloadedCh:      make(chan any, maxQueue),
+		loaded:          make(map[string]*runnerRef),
+		newServerFn:     llm.NewLlamaServer,
+		getGpuFn:        discover.GetGPUInfo,
+		getCpuFn:        discover.GetCPUInfo,
+		waitForRecovery: 5 * time.Second,
 	}
 	sched.loadFn = sched.load
 	return sched
@@ -148,7 +152,12 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				s.loadedMu.Lock()
 				runner := s.loaded[pending.model.ModelPath]
 				loadedCount := len(s.loaded)
+				runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded))
+				for _, r := range s.loaded {
+					runnersSnapshot = append(runnersSnapshot, r)
+				}
 				s.loadedMu.Unlock()
+
 				if runner != nil {
 					if runner.needsReload(ctx, pending) {
 						slog.Debug("reloading", "runner", runner)
@@ -166,9 +175,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
 					// Get a refreshed GPU list
 					var gpus discover.GpuInfoList
 					if pending.opts.NumGPU == 0 {
-						gpus = s.getCpuFn()
+						gpus = discover.GpuInfoList{s.getCpuFn()}
 					} else {
-						gpus = s.getGpuFn()
+						gpus = s.getGpuFn(ctx, runnersSnapshot)
 					}
 
 					if envconfig.MaxRunners() <= 0 {
@@ -223,8 +232,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				}
 
 				if runnerToExpire == nil {
-					// Shouildn't happen
-					slog.Error("runner to expire was nil!")
+					// While we were performing load calculations, the loaded runner(s) unloaded in parallel
+					// so findRunnerToUnload returned no runners.  We'll try again and the loadedCount should be zero
+					slog.Debug("runner to expire was nil, retrying")
 					continue
 				}
 				// Trigger an expiration to unload once it's done
@@ -343,7 +353,11 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
 				runner.refMu.Unlock()
 			} else {
 				slog.Debug("starting background wait for VRAM recovery", "runner", runner)
-				finished := runner.waitForVRAMRecovery()
+				runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded))
+				for _, r := range s.loaded {
+					runnersSnapshot = append(runnersSnapshot, r)
+				}
+				finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
 				runner.unload()
 				delete(s.loaded, runner.modelPath)
 				s.loadedMu.Unlock()
@@ -429,7 +443,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
 
 	s.loadedMu.Unlock()
 
-	err := llama.Load(req.ctx, gpus, requireFull)
+	gpuIDs, err := llama.Load(req.ctx, gpus, requireFull)
 	if err != nil {
 		if errors.Is(err, llm.ErrLoadRequiredFull) {
 			return true
@@ -448,7 +462,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
 		llama:           llama,
 		Options:         &req.opts,
 		sessionDuration: sessionDuration,
-		gpus:            gpus,
+		gpus:            gpuIDs,
 		vramSize:        llama.VRAMSize(),
 		totalSize:       llama.TotalSize(),
 		loading:         true,
@@ -497,11 +511,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
 }
 
 func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
-	type predKey struct {
-		Library string
-		ID      string
-	}
-	predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
+	predMap := map[ml.DeviceID]uint64{} // Sum up the total predicted usage per GPU for all runners
 	s.loadedMu.Lock()
 	runners := make([]*runnerRef, 0, len(s.loaded))
 	for _, r := range s.loaded {
@@ -512,7 +522,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
 		r.refMu.Lock()
 		if r.llama != nil {
 			for _, gpu := range allGpus {
-				predMap[predKey{gpu.Library, gpu.ID}] += r.llama.VRAMByGPU(gpu.ID)
+				predMap[gpu.DeviceID] += r.llama.VRAMByGPU(gpu.DeviceID)
 			}
 		} else {
 			slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
@@ -522,7 +532,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
 
 	// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
 	for i := range allGpus {
-		if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
+		if p, ok := predMap[allGpus[i].DeviceID]; ok {
 			slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
 			if p > allGpus[i].TotalMemory {
 				// Shouldn't happen
@@ -546,8 +556,8 @@ type runnerRef struct {
 
 	llama     llm.LlamaServer
 	pid       int
-	loading   bool                 // True only during initial load, then false forever
-	gpus      discover.GpuInfoList // Recorded at time of provisioning
+	loading   bool          // True only during initial load, then false forever
+	gpus      []ml.DeviceID // Recorded at time of provisioning
 	vramSize  uint64
 	totalSize uint64
 
@@ -571,7 +581,6 @@ func (runner *runnerRef) unload() {
 		runner.llama.Close()
 	}
 	runner.model = nil
-	runner.llama = nil
 	runner.Options = nil
 	runner.gpus = nil
 }
@@ -618,14 +627,14 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 // a before and after GPU memory allocation.  The returned channel
 // will be notified when we're done waiting, or have timed out and should
 // proceed anyway
-func (runner *runnerRef) waitForVRAMRecovery() chan any {
+func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.FilteredRunnerDiscovery) chan any {
 	finished := make(chan any, 1)
 
 	// CPU or Metal don't need checking, so no waiting required
 	// windows can page VRAM, only cuda currently can report accurate used vram usage
 	if len(runner.gpus) == 0 ||
-		(len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "metal")) ||
-		(runtime.GOOS == "windows" && runner.gpus[0].Library != "cuda") {
+		(len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "Metal")) ||
+		(runtime.GOOS == "windows" && runner.gpus[0].Library != "CUDA") {
 		finished <- struct{}{}
 		slog.Debug("no need to wait for VRAM recovery", "runner", runner)
 		return finished
@@ -633,33 +642,41 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
 	start := time.Now()
 
 	// Establish a baseline before we unload
-	gpusBefore := discover.GetGPUInfo()
+	gpusBefore := s.getGpuFn(context.Background(), runners)
 	var totalMemoryBefore, freeMemoryBefore uint64
 	for _, gpu := range gpusBefore {
 		totalMemoryBefore += gpu.TotalMemory
 		freeMemoryBefore += gpu.FreeMemory
 	}
+	totalMemoryNow := totalMemoryBefore
+	freeMemoryNow := freeMemoryBefore
+
 	go func() {
-		expiresAt := start.Add(5 * time.Second) // typical convergence is 0.5-1.5s
+		// typical convergence is 0.5-1.5s - If it takes too long to discover and converge, let the scheduler estimate VRAM usage
+		ctx, cancel := context.WithTimeout(context.Background(), s.waitForRecovery)
+		defer cancel()
 		ticker := time.NewTicker(250 * time.Millisecond)
 		defer ticker.Stop()
 		for {
-			<-ticker.C
-			if time.Now().After(expiresAt) {
-				slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "runner", runner)
-				finished <- struct{}{}
-			}
-
-			// Query GPUs, look for free to go back up
-			gpusNow := discover.GetGPUInfo()
-			var totalMemoryNow, freeMemoryNow uint64
-			for _, gpu := range gpusNow {
-				totalMemoryNow += gpu.TotalMemory
-				freeMemoryNow += gpu.FreeMemory
-			}
-			// If we're within ~80% of the estimated memory usage recovered, bail out
-			if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.8 {
-				slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "runner", runner)
+			select {
+			case <-ticker.C:
+				// Query GPUs, look for free to go back up
+				gpusNow := s.getGpuFn(ctx, runners)
+				totalMemoryNow = 0
+				freeMemoryNow = 0
+				for _, gpu := range gpusNow {
+					totalMemoryNow += gpu.TotalMemory
+					freeMemoryNow += gpu.FreeMemory
+				}
+				logutil.Trace("gpu VRAM convergence", "percent", int(max(float32(freeMemoryNow-freeMemoryBefore), 0.0)/float32(runner.vramSize)*100))
+				// If we're within ~75% of the estimated memory usage recovered, bail out
+				if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.75 {
+					slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
+					finished <- struct{}{}
+					return
+				}
+			case <-ctx.Done():
+				slog.Debug("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
 				finished <- struct{}{}
 				return
 			}
@@ -678,8 +695,7 @@ func (runner *runnerRef) LogValue() slog.Value {
 	}
 	if len(runner.gpus) > 0 {
 		attrs = append(attrs,
-			slog.String("inference", runner.gpus[0].Library),
-			slog.Int("devices", len(runner.gpus)),
+			slog.Any("inference", runner.gpus),
 		)
 	}
 	attrs = append(attrs,
@@ -695,6 +711,32 @@ func (runner *runnerRef) LogValue() slog.Value {
 	return slog.GroupValue(attrs...)
 }
 
+// Implements discover.RunnerDiscovery
+func (runner *runnerRef) GetPort() int {
+	if runner.llama != nil {
+		return runner.llama.GetPort()
+	}
+	return -1
+}
+
+func (runner *runnerRef) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
+	if runner.llama != nil {
+		return runner.llama.GetDeviceInfos(ctx)
+	}
+	return nil
+}
+
+func (runner *runnerRef) GetActiveDeviceIDs() []ml.DeviceID {
+	return runner.gpus
+}
+
+func (runner *runnerRef) HasExited() bool {
+	if runner.llama != nil {
+		return runner.llama.HasExited()
+	}
+	return true
+}
+
 type ByDurationAndName []*runnerRef
 
 func (a ByDurationAndName) Len() int      { return len(a) }
diff --git a/server/sched_test.go b/server/sched_test.go
index 0acd5911..66d43338 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -17,6 +17,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/ml"
 )
 
 func TestMain(m *testing.M) {
@@ -25,7 +26,7 @@ func TestMain(m *testing.M) {
 	os.Exit(m.Run())
 }
 
-func TestInitScheduler(t *testing.T) {
+func TestSchedInit(t *testing.T) {
 	ctx, done := context.WithCancel(t.Context())
 	defer done()
 	s := InitScheduler(ctx)
@@ -34,10 +35,11 @@ func TestInitScheduler(t *testing.T) {
 	s.loadedMu.Unlock()
 }
 
-func TestLoad(t *testing.T) {
+func TestSchedLoad(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
 	defer done()
 	s := InitScheduler(ctx)
+	s.waitForRecovery = 10 * time.Millisecond
 	var f *ggml.GGML // value not used in tests
 	req := &LlmRequest{
 		ctx:             ctx,
@@ -61,7 +63,7 @@ func TestLoad(t *testing.T) {
 	err := <-req.errCh
 	require.Contains(t, err.Error(), "this model may be incompatible")
 
-	server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}}
+	server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}}
 	s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		server.modelPath = model
 		return server, nil
@@ -109,7 +111,7 @@ func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, f
 	return scenario.srv, nil
 }
 
-func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vramSize uint64, duration *api.Duration) *reqBundle {
+func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vramSize uint64, duration *api.Duration, vramByGPU map[ml.DeviceID]uint64) *reqBundle {
 	b := &reqBundle{}
 	b.ctx, b.ctxDone = context.WithCancel(ctx)
 	t.Helper()
@@ -146,32 +148,35 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vra
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
 	}
-	b.srv = &mockLlm{vramSize: vramSize, vramByGPU: map[string]uint64{"": vramSize}}
+	b.srv = &mockLlm{vramSize: vramSize, vramByGPU: vramByGPU}
 	return b
 }
 
-func getGpuFn() discover.GpuInfoList {
-	g := discover.GpuInfo{Library: "metal"}
+func getGpuFn(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
+	slog.Info("test getGpuFn called", "runners", runners)
+	g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
 	g.TotalMemory = 24 * format.GigaByte
 	g.FreeMemory = 12 * format.GigaByte
 	return []discover.GpuInfo{g}
 }
 
-func getCpuFn() discover.GpuInfoList {
-	g := discover.GpuInfo{Library: "cpu"}
+func getCpuFn() discover.GpuInfo {
+	slog.Info("test getCpuFn called")
+	g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "cpu"}}
 	g.TotalMemory = 32 * format.GigaByte
 	g.FreeMemory = 26 * format.GigaByte
-	return []discover.GpuInfo{g}
+	return g
 }
 
-func TestRequestsSameModelSameRequest(t *testing.T) {
+func TestSchedRequestsSameModelSameRequest(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
 	defer done()
 	s := InitScheduler(ctx)
+	s.waitForRecovery = 10 * time.Millisecond
 	s.getGpuFn = getGpuFn
 	s.getCpuFn = getCpuFn
-	a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
-	b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
+	a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil)
+	b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}, nil)
 	b.req.model = a.req.model
 	b.f = a.f
 
@@ -207,14 +212,15 @@ func TestRequestsSameModelSameRequest(t *testing.T) {
 	}
 }
 
-func TestRequestsSimpleReloadSameModel(t *testing.T) {
-	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
+func TestSchedRequestsSimpleReloadSameModel(t *testing.T) {
+	ctx, done := context.WithTimeout(t.Context(), 5000*time.Millisecond)
 	defer done()
 	s := InitScheduler(ctx)
+	s.waitForRecovery = 10 * time.Millisecond
 	s.getGpuFn = getGpuFn
 	s.getCpuFn = getCpuFn
-	a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
-	b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
+	a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil)
+	b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}, nil)
 	tmpModel := *a.req.model
 	b.req.model = &tmpModel
 	b.f = a.f
@@ -243,6 +249,15 @@ func TestRequestsSimpleReloadSameModel(t *testing.T) {
 	// finish first two requests, so model can reload
 	time.Sleep(1 * time.Millisecond)
 	a.ctxDone()
+	// Report recovered VRAM usage
+	time.Sleep(1 * time.Millisecond)
+	s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
+		slog.Info("XXX altered getGpuFn called")
+		g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
+		g.TotalMemory = 24 * format.GigaByte
+		g.FreeMemory = 24 * format.GigaByte
+		return []discover.GpuInfo{g}
+	}
 	select {
 	case resp := <-b.req.successCh:
 		require.Equal(t, resp.llama, b.srv)
@@ -255,19 +270,23 @@ func TestRequestsSimpleReloadSameModel(t *testing.T) {
 	}
 }
 
-func TestRequestsMultipleLoadedModels(t *testing.T) {
+func TestSchedRequestsMultipleLoadedModels(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
 	defer done()
 	s := InitScheduler(ctx)
-	s.getGpuFn = getGpuFn
-	s.getCpuFn = getCpuFn
+	s.waitForRecovery = 10 * time.Millisecond
+	s.getGpuFn = getGpuFn // 1 metal GPU
+	s.getCpuFn = getCpuFn // 1 CPU
 
 	// Multiple loaded models
-	a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
-	b := newScenarioRequest(t, ctx, "ollama-model-3b", 10*format.GigaByte, nil)
-	c := newScenarioRequest(t, ctx, "ollama-model-4a", 10*format.GigaByte, nil)
-	c.req.opts.NumGPU = 0                                                       // CPU load, will be allowed
-	d := newScenarioRequest(t, ctx, "ollama-model-3c", 10*format.GigaByte, nil) // Needs prior unloaded
+	a := newScenarioRequest(t, ctx, "model-a-1g-gpu", 1*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 1 * format.GigaByte})
+	a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
+	b := newScenarioRequest(t, ctx, "model-b-10g-gpu", 10*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 10 * format.GigaByte})
+	b.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
+	c := newScenarioRequest(t, ctx, "model-c-10g-cpu", 10*format.GigaByte, nil, nil /* No GPU load */)
+	c.req.opts.NumGPU = 0                                                                                                                         // CPU load, will be allowed
+	b.req.sessionDuration = &api.Duration{Duration: 10 * time.Millisecond}                                                                        // longer than b to cause the scheduler to favor unloading b over c
+	d := newScenarioRequest(t, ctx, "model-d-10g-gpu", 13*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 13 * format.GigaByte}) // Needs prior unloaded
 
 	t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1")
 	s.newServerFn = a.newServer
@@ -338,7 +357,16 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 2)
 	s.loadedMu.Unlock()
+	// Mark b done so it can unload
 	b.ctxDone()
+	// Report recovered VRAM usage so scheduler will finish waiting and unload
+	time.Sleep(1 * time.Millisecond)
+	s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
+		g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
+		g.TotalMemory = 24 * format.GigaByte
+		g.FreeMemory = 24 * format.GigaByte
+		return []discover.GpuInfo{g}
+	}
 	select {
 	case resp := <-d.req.successCh:
 		require.Equal(t, resp.llama, d.srv)
@@ -347,20 +375,34 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
+	// Wait for b to close
+closeWait:
+	for {
+		select {
+		case <-ctx.Done():
+			t.Fatal("timeout")
+		default:
+			if b.srv.closeCalled {
+				break closeWait
+			}
+			time.Sleep(1 * time.Millisecond)
+		}
+	}
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 2)
 	s.loadedMu.Unlock()
 }
 
-func TestGetRunner(t *testing.T) {
+func TestSchedGetRunner(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 3*time.Second)
 	defer done()
 
-	a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
-	b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
-	c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
+	a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil)
+	b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil)
+	c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil)
 	t.Setenv("OLLAMA_MAX_QUEUE", "1")
 	s := InitScheduler(ctx)
+	s.waitForRecovery = 10 * time.Millisecond
 	s.getGpuFn = getGpuFn
 	s.getCpuFn = getCpuFn
 	s.newServerFn = a.newServer
@@ -405,10 +447,11 @@ func TestGetRunner(t *testing.T) {
 	b.ctxDone()
 }
 
-func TestExpireRunner(t *testing.T) {
+func TestSchedExpireRunner(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
 	defer done()
 	s := InitScheduler(ctx)
+	s.waitForRecovery = 10 * time.Millisecond
 	req := &LlmRequest{
 		ctx:             ctx,
 		model:           &Model{ModelPath: "foo"},
@@ -420,7 +463,7 @@ func TestExpireRunner(t *testing.T) {
 
 	var f *ggml.GGML
 	gpus := discover.GpuInfoList{}
-	server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}}
+	server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}}
 	s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		server.modelPath = model
 		return server, nil
@@ -453,15 +496,16 @@ func TestExpireRunner(t *testing.T) {
 }
 
 // TODO - add one scenario that triggers the bogus finished event with positive ref count
-func TestPrematureExpired(t *testing.T) {
+func TestSchedPrematureExpired(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
 	defer done()
 
 	// Same model, same request
-	scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
+	scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil, nil)
 	s := InitScheduler(ctx)
-	s.getGpuFn = func() discover.GpuInfoList {
-		g := discover.GpuInfo{Library: "metal"}
+	s.waitForRecovery = 10 * time.Millisecond
+	s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
+		g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
 		g.TotalMemory = 24 * format.GigaByte
 		g.FreeMemory = 12 * format.GigaByte
 		return []discover.GpuInfo{g}
@@ -500,7 +544,7 @@ func TestPrematureExpired(t *testing.T) {
 	time.Sleep(5 * time.Millisecond)
 }
 
-func TestUseLoadedRunner(t *testing.T) {
+func TestSchedUseLoadedRunner(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
 	req := &LlmRequest{
 		ctx:             ctx,
@@ -509,7 +553,7 @@ func TestUseLoadedRunner(t *testing.T) {
 		sessionDuration: &api.Duration{Duration: 2},
 	}
 	finished := make(chan *LlmRequest)
-	llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
+	llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
 	r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1}
 	req.useLoadedRunner(r1, finished)
 	require.Equal(t, uint(1), r1.refCount)
@@ -527,29 +571,40 @@ func TestUseLoadedRunner(t *testing.T) {
 	require.Equal(t, req, fin)
 }
 
-func TestUpdateFreeSpace(t *testing.T) {
+func TestSchedUpdateFreeSpace(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
 	defer done()
 	gpus := discover.GpuInfoList{
 		{
-			Library: "a",
-			ID:      "1",
+			DeviceID: ml.DeviceID{
+				ID: "1",
+			},
 		},
 		{
-			Library: "a",
-			ID:      "2",
+			DeviceID: ml.DeviceID{
+				ID: "2",
+			},
 		},
 	}
 	gpus[0].TotalMemory = 1000
 	gpus[0].FreeMemory = 900
 	gpus[1].TotalMemory = 2000
 	gpus[1].FreeMemory = 1900
-	llm1 := &mockLlm{vramByGPU: map[string]uint64{"1": 50, "2": 50}}
-	llm2 := &mockLlm{vramByGPU: map[string]uint64{"1": 125, "2": 75}}
-	r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1}
-	r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1}
+	gpuIDs := []ml.DeviceID{
+		{
+			ID: "1",
+		},
+		{
+			ID: "2",
+		},
+	}
+	llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{{ID: "1"}: 50, {ID: "2"}: 50}}
+	llm2 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{{ID: "1"}: 125, {ID: "2"}: 75}}
+	r1 := &runnerRef{llama: llm1, gpus: gpuIDs, numParallel: 1}
+	r2 := &runnerRef{llama: llm2, gpus: gpuIDs, numParallel: 1}
 
 	s := InitScheduler(ctx)
+	s.waitForRecovery = 10 * time.Millisecond
 	s.loadedMu.Lock()
 	s.loaded["a"] = r1
 	s.loaded["b"] = r2
@@ -560,7 +615,7 @@ func TestUpdateFreeSpace(t *testing.T) {
 	require.Equal(t, uint64(2000-50-75), gpus[1].FreeMemory)
 }
 
-func TestFindRunnerToUnload(t *testing.T) {
+func TestSchedFindRunnerToUnload(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
 	defer done()
 
@@ -568,6 +623,7 @@ func TestFindRunnerToUnload(t *testing.T) {
 	r2 := &runnerRef{sessionDuration: 2, numParallel: 1}
 
 	s := InitScheduler(ctx)
+	s.waitForRecovery = 10 * time.Millisecond
 	s.loadedMu.Lock()
 	s.loaded["a"] = r1
 	s.loaded["b"] = r2
@@ -580,11 +636,11 @@ func TestFindRunnerToUnload(t *testing.T) {
 	require.Equal(t, r1, resp)
 }
 
-func TestNeedsReload(t *testing.T) {
+func TestSchedNeedsReload(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
 	defer done()
 
-	llm := &mockLlm{vramByGPU: map[string]uint64{}}
+	llm := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
 	do := api.DefaultOptions()
 	runner := &runnerRef{
 		model: &Model{
@@ -627,13 +683,14 @@ func TestNeedsReload(t *testing.T) {
 	require.False(t, resp)
 }
 
-func TestUnloadAllRunners(t *testing.T) {
+func TestSchedUnloadAllRunners(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
 	defer done()
 
-	llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
-	llm2 := &mockLlm{vramByGPU: map[string]uint64{}}
+	llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
+	llm2 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
 	s := InitScheduler(ctx)
+	s.waitForRecovery = 10 * time.Millisecond
 	s.unloadAllRunners()
 
 	r1 := &runnerRef{llama: llm1, numParallel: 1}
@@ -649,8 +706,8 @@ func TestUnloadAllRunners(t *testing.T) {
 	require.True(t, llm2.closeCalled)
 }
 
-func TestUnload(t *testing.T) {
-	llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
+func TestSchedUnload(t *testing.T) {
+	llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
 	r1 := &runnerRef{llama: llm1, numParallel: 1}
 	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1}
 	r1.unload()
@@ -659,13 +716,14 @@ func TestUnload(t *testing.T) {
 	require.Nil(t, r2.model)
 }
 
-func TestAlreadyCanceled(t *testing.T) {
+func TestSchedAlreadyCanceled(t *testing.T) {
 	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
 	defer done()
 	dctx, done2 := context.WithCancel(ctx)
 	done2()
-	scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
+	scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0}, nil)
 	s := InitScheduler(ctx)
+	s.waitForRecovery = 10 * time.Millisecond
 	slog.Info("scenario1a")
 	s.pendingReqCh <- scenario1a.req
 	require.Len(t, s.pendingReqCh, 1)
@@ -691,24 +749,28 @@ type mockLlm struct {
 	closeCalled       bool
 	vramSize          uint64
 	totalSize         uint64
-	vramByGPU         map[string]uint64
+	vramByGPU         map[ml.DeviceID]uint64
 }
 
 func (s *mockLlm) ModelPath() string {
 	return s.modelPath
 }
 
-func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error {
+func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) {
 	if requireFull {
 		for _, g := range gpus {
 			if g.FreeMemory >= s.vramSize {
-				return nil
+				return []ml.DeviceID{g.DeviceID}, nil
 			}
 		}
 
-		return llm.ErrLoadRequiredFull
+		return nil, llm.ErrLoadRequiredFull
 	}
-	return nil
+	gpuIDs := make([]ml.DeviceID, len(gpus))
+	for i := range gpus {
+		gpuIDs[i] = gpus[i].DeviceID
+	}
+	return gpuIDs, nil
 }
 func (s *mockLlm) Ping(ctx context.Context) error             { return s.pingResp }
 func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
@@ -732,7 +794,11 @@ func (s *mockLlm) Close() error {
 	s.closeCalled = true
 	return s.closeResp
 }
-func (s *mockLlm) VRAMSize() uint64              { return s.vramSize }
-func (s *mockLlm) TotalSize() uint64             { return s.totalSize }
-func (s *mockLlm) VRAMByGPU(gpuid string) uint64 { return s.vramByGPU[gpuid] }
-func (s *mockLlm) Pid() int                      { return -1 }
+func (s *mockLlm) VRAMSize() uint64                                   { return s.vramSize }
+func (s *mockLlm) TotalSize() uint64                                  { return s.totalSize }
+func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64                    { return s.vramByGPU[id] }
+func (s *mockLlm) Pid() int                                           { return -1 }
+func (s *mockLlm) GetPort() int                                       { return -1 }
+func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
+func (s *mockLlm) HasExited() bool                                    { return false }
+func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID                  { return nil }
diff --git a/template/template.go b/template/template.go
index f2775b91..c90190d7 100644
--- a/template/template.go
+++ b/template/template.go
@@ -148,7 +148,12 @@ func Parse(s string) (*Template, error) {
 	}
 
 	t := Template{Template: tmpl, raw: s}
-	if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
+	vars, err := t.Vars()
+	if err != nil {
+		return nil, err
+	}
+
+	if !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
 		// touch up the template and append {{ .Response }}
 		tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
 	}
@@ -160,11 +165,15 @@ func (t *Template) String() string {
 	return t.raw
 }
 
-func (t *Template) Vars() []string {
+func (t *Template) Vars() ([]string, error) {
 	var vars []string
 	for _, tt := range t.Templates() {
 		for _, n := range tt.Root.Nodes {
-			vars = append(vars, Identifiers(n)...)
+			v, err := Identifiers(n)
+			if err != nil {
+				return vars, err
+			}
+			vars = append(vars, v...)
 		}
 	}
 
@@ -173,7 +182,7 @@ func (t *Template) Vars() []string {
 		set[strings.ToLower(n)] = struct{}{}
 	}
 
-	return slices.Sorted(maps.Keys(set))
+	return slices.Sorted(maps.Keys(set)), nil
 }
 
 func (t *Template) Contains(s string) bool {
@@ -244,6 +253,10 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
 
 func (t *Template) Execute(w io.Writer, v Values) error {
 	system, messages := collate(v.Messages)
+	vars, err := t.Vars()
+	if err != nil {
+		return err
+	}
 	if v.Prompt != "" && v.Suffix != "" {
 		return t.Template.Execute(w, map[string]any{
 			"Prompt":     v.Prompt,
@@ -253,7 +266,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 			"ThinkLevel": v.ThinkLevel,
 			"IsThinkSet": v.IsThinkSet,
 		})
-	} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
+	} else if !v.forceLegacy && slices.Contains(vars, "messages") {
 		return t.Template.Execute(w, map[string]any{
 			"System":     system,
 			"Messages":   messages,
@@ -329,7 +342,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 		return err
 	}
 
-	_, err := io.Copy(w, &b)
+	_, err = io.Copy(w, &b)
 	return err
 }
 
@@ -358,27 +371,47 @@ func collate(msgs []api.Message) (string, []*api.Message) {
 }
 
 // Identifiers walks the node tree returning any identifiers it finds along the way
-func Identifiers(n parse.Node) []string {
+func Identifiers(n parse.Node) ([]string, error) {
 	switch n := n.(type) {
 	case *parse.ListNode:
 		var names []string
 		for _, n := range n.Nodes {
-			names = append(names, Identifiers(n)...)
+			i, err := Identifiers(n)
+			if err != nil {
+				return names, err
+			}
+			names = append(names, i...)
 		}
 
-		return names
+		return names, nil
 	case *parse.TemplateNode:
+		if n.Pipe == nil {
+			return nil, errors.New("undefined template specified")
+		}
 		return Identifiers(n.Pipe)
 	case *parse.ActionNode:
+		if n.Pipe == nil {
+			return nil, errors.New("undefined action in template")
+		}
 		return Identifiers(n.Pipe)
 	case *parse.BranchNode:
-		names := Identifiers(n.Pipe)
+		if n.Pipe == nil {
+			return nil, errors.New("undefined branch")
+		}
+		names, err := Identifiers(n.Pipe)
+		if err != nil {
+			return names, err
+		}
 		for _, n := range []*parse.ListNode{n.List, n.ElseList} {
 			if n != nil {
-				names = append(names, Identifiers(n)...)
+				i, err := Identifiers(n)
+				if err != nil {
+					return names, err
+				}
+				names = append(names, i...)
 			}
 		}
-		return names
+		return names, nil
 	case *parse.IfNode:
 		return Identifiers(&n.BranchNode)
 	case *parse.RangeNode:
@@ -389,17 +422,21 @@ func Identifiers(n parse.Node) []string {
 		var names []string
 		for _, c := range n.Cmds {
 			for _, a := range c.Args {
-				names = append(names, Identifiers(a)...)
+				i, err := Identifiers(a)
+				if err != nil {
+					return names, err
+				}
+				names = append(names, i...)
 			}
 		}
-		return names
+		return names, nil
 	case *parse.FieldNode:
-		return n.Ident
+		return n.Ident, nil
 	case *parse.VariableNode:
-		return n.Ident
+		return n.Ident, nil
 	}
 
-	return nil
+	return nil, nil
 }
 
 // deleteNode walks the node list and deletes nodes that match the predicate
diff --git a/template/template_test.go b/template/template_test.go
index 3d4eb991..45101e5a 100644
--- a/template/template_test.go
+++ b/template/template_test.go
@@ -154,24 +154,55 @@ func TestTemplate(t *testing.T) {
 }
 
 func TestParse(t *testing.T) {
-	cases := []struct {
+	validCases := []struct {
+		name     string
 		template string
 		vars     []string
 	}{
-		{"{{ .Prompt }}", []string{"prompt", "response"}},
-		{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
-		{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
-		{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
-		{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
-		{"{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role", "toolname"}},
-		{`{{- range .Messages }}
+		{
+			name:     "PromptOnly",
+			template: "{{ .Prompt }}",
+			vars:     []string{"prompt", "response"},
+		},
+		{
+			name:     "SystemAndPrompt",
+			template: "{{ .System }} {{ .Prompt }}",
+			vars:     []string{"prompt", "response", "system"},
+		},
+		{
+			name:     "PromptResponseSystem",
+			template: "{{ .System }} {{ .Prompt }} {{ .Response }}",
+			vars:     []string{"prompt", "response", "system"},
+		},
+		{
+			name:     "ToolsBlock",
+			template: "{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}",
+			vars:     []string{"prompt", "response", "system", "tools"},
+		},
+		{
+			name:     "MessagesRange",
+			template: "{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}",
+			vars:     []string{"content", "messages", "role"},
+		},
+		{
+			name:     "ToolResultConditional",
+			template: "{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}",
+			vars:     []string{"content", "messages", "role", "toolname"},
+		},
+		{
+			name: "MultilineSystemUserAssistant",
+			template: `{{- range .Messages }}
 {{- if eq .Role "system" }}SYSTEM:
 {{- else if eq .Role "user" }}USER:
 {{- else if eq .Role "assistant" }}ASSISTANT:
-{{- else if eq .Role "tool" }}TOOL: 
+{{- else if eq .Role "tool" }}TOOL:
 {{- end }} {{ .Content }}
-{{- end }}`, []string{"content", "messages", "role"}},
-		{`{{- if .Messages }}
+{{- end }}`,
+			vars: []string{"content", "messages", "role"},
+		},
+		{
+			name: "ChatMLLike",
+			template: `{{- if .Messages }}
 {{- range .Messages }}<|im_start|>{{ .Role }}
 {{ .Content }}<|im_end|>
 {{ end }}<|im_start|>assistant
@@ -182,18 +213,60 @@ func TestParse(t *testing.T) {
 {{ .Prompt }}<|im_end|>
 {{ end }}<|im_start|>assistant
 {{ .Response }}<|im_end|>
-{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
+{{- end -}}`,
+			vars: []string{"content", "messages", "prompt", "response", "role", "system"},
+		},
 	}
 
-	for _, tt := range cases {
-		t.Run("", func(t *testing.T) {
+	for _, tt := range validCases {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
 			tmpl, err := Parse(tt.template)
 			if err != nil {
-				t.Fatal(err)
+				t.Fatalf("Parse returned unexpected error: %v", err)
 			}
 
-			if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
-				t.Errorf("mismatch (-got +want):\n%s", diff)
+			gotVars, err := tmpl.Vars()
+			if err != nil {
+				t.Fatalf("Vars returned unexpected error: %v", err)
+			}
+
+			if diff := cmp.Diff(gotVars, tt.vars); diff != "" {
+				t.Errorf("Vars mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestParseError(t *testing.T) {
+	invalidCases := []struct {
+		name     string
+		template string
+		errorStr string
+	}{
+		{
+			"TemplateNotClosed",
+			"{{ .Prompt ",
+			"unclosed action",
+		},
+		{
+			"Template",
+			`{{define "x"}}{{template "x"}}{{end}}{{template "x"}}`,
+			"undefined template specified",
+		},
+	}
+
+	for _, tt := range invalidCases {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := Parse(tt.template)
+			if err == nil {
+				t.Fatalf("expected Parse to return an error for an invalid template, got nil")
+			}
+
+			if !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.errorStr)) {
+				t.Errorf("unexpected error message.\n got: %q\n want substring (case‑insensitive): %q", err.Error(), tt.errorStr)
 			}
 		})
 	}