Merge branch 'ollama:main' into main
.github/workflows/release.yaml (vendored, 40 changes)

@@ -104,6 +104,13 @@ jobs:
           install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
           rocm-version: '6.2'
           flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
+          runner_dir: 'rocm'
+        - os: windows
+          arch: amd64
+          preset: Vulkan
+          install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
+          flags: ''
+          runner_dir: 'vulkan'
     runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
     environment: release
     env:
@@ -113,13 +120,14 @@ jobs:
         run: |
          choco install -y --no-progress ccache ninja
          ccache -o cache_dir=${{ github.workspace }}\.ccache
-      - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ')
+      - if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
         id: cache-install
         uses: actions/cache/restore@v4
         with:
           path: |
             C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
             C:\Program Files\AMD\ROCm
+            C:\VulkanSDK
           key: ${{ matrix.install }}
       - if: startsWith(matrix.preset, 'CUDA ')
         name: Install CUDA ${{ matrix.cuda-version }}
@@ -149,6 +157,18 @@ jobs:
          echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
          echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
          echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
+      - if: matrix.preset == 'Vulkan'
+        name: Install Vulkan ${{ matrix.rocm-version }}
+        run: |
+          $ErrorActionPreference = "Stop"
+          if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
+            Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
+            Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
+          }
+
+          $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
+          echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
+          echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
       - if: matrix.preset == 'CPU'
         run: |
          echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -159,6 +179,7 @@ jobs:
           path: |
             C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
             C:\Program Files\AMD\ROCm
+            C:\VulkanSDK
           key: ${{ matrix.install }}
       - uses: actions/checkout@v4
       - uses: actions/cache@v4
@@ -171,7 +192,7 @@ jobs:
          Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
          cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
          cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
-          cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip
+          cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
          Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
         env:
           CMAKE_GENERATOR: Ninja
@@ -312,13 +333,13 @@ jobs:
         include:
           - os: linux
             arch: amd64
-            target: archive_novulkan
+            target: archive
           - os: linux
             arch: amd64
             target: rocm
           - os: linux
             arch: arm64
-            target: archive_novulkan
+            target: archive
     runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
     environment: release
     needs: setup-environment
@@ -374,14 +395,12 @@ jobs:
         include:
           - os: linux
             arch: arm64
-            target: novulkan
             build-args: |
              CGO_CFLAGS
              CGO_CXXFLAGS
              GOFLAGS
           - os: linux
             arch: amd64
-            target: novulkan
             build-args: |
              CGO_CFLAGS
              CGO_CXXFLAGS
@@ -394,14 +413,6 @@ jobs:
              CGO_CXXFLAGS
              GOFLAGS
              FLAVOR=rocm
-          - os: linux
-            arch: amd64
-            suffix: '-vulkan'
-            target: default
-            build-args: |
-              CGO_CFLAGS
-              CGO_CXXFLAGS
-              GOFLAGS
     runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
     environment: release
     needs: setup-environment
@@ -419,7 +430,6 @@ jobs:
         with:
           context: .
           platforms: ${{ matrix.os }}/${{ matrix.arch }}
-          target: ${{ matrix.preset }}
           build-args: ${{ matrix.build-args }}
           outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
           cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
.github/workflows/test.yaml (vendored, 1 change)

@@ -172,6 +172,7 @@ jobs:
           path: |
             C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
             C:\Program Files\AMD\ROCm
+            C:\VulkanSDK
           key: ${{ matrix.install }}
       - uses: actions/checkout@v4
       - uses: actions/cache@v4
Dockerfile (27 changes)

@@ -159,32 +159,7 @@ ARG VULKANVERSION
 COPY --from=cpu dist/lib/ollama /lib/ollama
 COPY --from=build /bin/ollama /bin/ollama
 
-# Temporary opt-out stages for Vulkan
-FROM --platform=linux/amd64 scratch AS amd64_novulkan
-# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
-COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
-COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
-FROM arm64 AS arm64_novulkan
-FROM ${FLAVOR}_novulkan AS archive_novulkan
-COPY --from=cpu dist/lib/ollama /lib/ollama
-COPY --from=build /bin/ollama /bin/ollama
-FROM ubuntu:24.04 AS novulkan
-RUN apt-get update \
-    && apt-get install -y ca-certificates \
-    && apt-get clean \
-    && rm -rf /var/lib/apt/lists/*
-COPY --from=archive_novulkan /bin /usr/bin
-ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
-COPY --from=archive_novulkan /lib/ollama /usr/lib/ollama
-ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
-ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
-ENV NVIDIA_VISIBLE_DEVICES=all
-ENV OLLAMA_HOST=0.0.0.0:11434
-EXPOSE 11434
-ENTRYPOINT ["/bin/ollama"]
-CMD ["serve"]
-
-FROM ubuntu:24.04 AS default
+FROM ubuntu:24.04
 RUN apt-get update \
     && apt-get install -y ca-certificates libvulkan1 \
     && apt-get clean \
README.md

@@ -321,6 +321,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [LibreChat](https://github.com/danny-avila/LibreChat)
 - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
 - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
+- [AI-UI](https://github.com/bajahaw/ai-ui)
 - [Saddle](https://github.com/jikkuatwork/saddle)
 - [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
 - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
@@ -387,7 +388,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
 - [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
 - [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
-- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
+- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VS Code extension for multi-file/whole-repo coding
 - [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
 - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
 - [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -419,7 +420,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
 - [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
 - [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
-- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VSCode extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
+- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VS Code extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
 - [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
 - [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
 - [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
@@ -662,5 +663,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
 - [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
 
-## Security
+### Security
 - [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
api/types.go (45 changes)

@@ -117,6 +117,14 @@ type GenerateRequest struct {
 	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
 	// template instead of calling the model.
 	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
+
+	// Logprobs specifies whether to return log probabilities of the output tokens.
+	Logprobs bool `json:"logprobs,omitempty"`
+
+	// TopLogprobs is the number of most likely tokens to return at each token position,
+	// each with an associated log probability. Only applies when Logprobs is true.
+	// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
+	TopLogprobs int `json:"top_logprobs,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -159,6 +167,14 @@ type ChatRequest struct {
 	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
 	// template instead of calling the model.
 	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
+
+	// Logprobs specifies whether to return log probabilities of the output tokens.
+	Logprobs bool `json:"logprobs,omitempty"`
+
+	// TopLogprobs is the number of most likely tokens to return at each token position,
+	// each with an associated log probability. Only applies when Logprobs is true.
+	// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
+	TopLogprobs int `json:"top_logprobs,omitempty"`
 }
 
 type Tools []Tool
@@ -343,6 +359,27 @@ func (t *ToolFunction) String() string {
 	return string(bts)
 }
 
+// TokenLogprob represents log probability information for a single token alternative.
+type TokenLogprob struct {
+	// Token is the text representation of the token.
+	Token string `json:"token"`
+
+	// Logprob is the log probability of this token.
+	Logprob float64 `json:"logprob"`
+
+	// Bytes contains the raw byte representation of the token
+	Bytes []int `json:"bytes,omitempty"`
+}
+
+// Logprob contains log probability information for a generated token.
+type Logprob struct {
+	TokenLogprob
+
+	// TopLogprobs contains the most likely tokens and their log probabilities
+	// at this position, if requested via TopLogprobs parameter.
+	TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
+}
+
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
@@ -369,6 +406,10 @@ type ChatResponse struct {
 
 	DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
 
+	// Logprobs contains log probability information for the generated tokens,
+	// if requested via the Logprobs parameter.
+	Logprobs []Logprob `json:"logprobs,omitempty"`
+
 	Metrics
 }
 
@@ -677,6 +718,10 @@ type GenerateResponse struct {
 	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 
 	DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
+
+	// Logprobs contains log probability information for the generated tokens,
+	// if requested via the Logprobs parameter.
+	Logprobs []Logprob `json:"logprobs,omitempty"`
 }
 
 // ModelDetails provides details about a model.
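Taken together, these api/types.go changes let a Go client opt in to per-token log probabilities on both the chat and generate endpoints. A minimal sketch of how a client might use the new fields, assuming a local Ollama server; the model name is a placeholder and this example is illustrative, not part of the diff:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model:    "llama3.2", // placeholder model name
		Messages: []api.Message{{Role: "user", Content: "Say hello"}},
		// New fields from this change: request per-token logprobs plus
		// the top 3 alternatives at each position.
		Logprobs:    true,
		TopLogprobs: 3,
	}

	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		for _, lp := range resp.Logprobs {
			// Logprob embeds TokenLogprob, so Token and Logprob are promoted fields.
			fmt.Printf("%q\t%.4f\t(%d alternatives)\n", lp.Token, lp.Logprob, len(lp.TopLogprobs))
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```

Because both fields carry `omitempty` JSON tags, existing clients that never set them produce byte-identical request bodies, so the addition is backward compatible.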
(UI development docs)

@@ -48,16 +48,6 @@ The `-dev` flag enables:
 - CORS headers for cross-origin requests
 - Hot-reload support for UI development
 
-#### Run Storybook
-
-Inside the `ui/app` directory, run:
-
-```bash
-npm run storybook
-```
-
-For now we're writing stories as siblings of the component they're testing. So for example, `src/components/Message.stories.tsx` is the story for `src/components/Message.tsx`.
-
 ## Build
 
 
app/ui/app/package-lock.json (generated, 1,512 changes)

@@ -34,6 +34,7 @@
       "rehype-raw": "^7.0.0",
       "rehype-sanitize": "^6.0.0",
       "remark-math": "^6.0.0",
+      "streamdown": "^1.4.0",
       "unist-builder": "^4.0.0",
       "unist-util-parents": "^3.0.0"
     },
(app UI: chat API client)

@@ -205,6 +205,13 @@ export async function* sendMessage(
       data: uint8ArrayToBase64(att.data),
     }));
 
+  // Only send think parameter when actually requesting thinking
+  // Don't send false as it causes issues with some providers
+  const shouldSendThink =
+    think !== undefined &&
+    ((typeof think === "boolean" && think) ||
+      (typeof think === "string" && think !== ""));
+
   const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
     method: "POST",
     headers: {
@@ -222,7 +229,7 @@ export async function* sendMessage(
         web_search: webSearch ?? false,
         file_tools: fileTools ?? false,
         ...(forceUpdate !== undefined ? { forceUpdate } : {}),
-        ...(think !== undefined ? { think } : {}),
+        ...(shouldSendThink ? { think } : {}),
       }),
     ),
     signal,
(app UI: processStreamingMarkdown tests, file deleted)

@@ -1,522 +0,0 @@
-import { expect, test, suite } from "vitest";
-import { processStreamingMarkdown } from "@/utils/processStreamingMarkdown";
-
-suite("common llm outputs that cause issues", () => {
-  test("prefix of bolded list item shouldn't make a horizontal line", () => {
-    // we're going to go in order of incrementally adding characters. This
-    // happens really commonly with LLMs that like to make lists like so:
-    //
-    // * **point 1**: explanatory text
-    // * **point 2**: more explanatory text
-    //
-    // Partial rendering of `*` (A), followed by `* *` (B), followed by `* **`
-    // (C) is a total mess. (A) renders as a single bullet point in an
-    // otherwise empty list, (B) renders as two nested lists (and therefore
-    // two bullet points, styled differently by default in html), and (C)
-    // renders as a horizontal line because in markdown apparently `***` or `*
-    // * *` horizontal rules don't have as strict whitespace rules as I
-    // expected them to
-
-    // these are alone (i.e., they would be the first list item)
-    expect(processStreamingMarkdown("*")).toBe("");
-    expect(processStreamingMarkdown("* *")).toBe("");
-    expect(processStreamingMarkdown("* **")).toBe("");
-    // expect(processStreamingMarkdown("* **b")).toBe("* **b**");
-
-    // with a list item before them
-    expect(
-      processStreamingMarkdown(
-        // prettier-ignore
-        [
-          "* abc",
-          "*"
-        ].join("\n"),
-      ),
-    ).toBe("* abc");
-
-    expect(
-      processStreamingMarkdown(
-        // prettier-ignore
-        [
-          "* abc",
-          "* *"
-        ].join("\n"),
-      ),
-    ).toBe("* abc");
-
-    expect(
-      processStreamingMarkdown(
-        // prettier-ignore
-        [
-          "* abc",
-          "* **"
-        ].join("\n"),
-      ),
-    ).toBe("* abc");
-  });
-
-  test("bolded list items with text should be rendered properly", () => {
-    expect(processStreamingMarkdown("* **abc**")).toBe("* **abc**");
-  });
-
-  test("partially bolded list items should be autoclosed", () => {
-    expect(processStreamingMarkdown("* **abc")).toBe("* **abc**");
-  });
-
-  suite(
-    "partially bolded list items should be autoclosed, even if the last node isn't a text node",
-    () => {
-      test("inline code", () => {
-        expect(
-          processStreamingMarkdown("* **Asynchronous Function `async`*"),
-        ).toBe("* **Asynchronous Function `async`**");
-      });
-    },
-  );
-});
-
-suite("autoclosing bold", () => {
-  suite("endings with no asterisks", () => {
-    test("should autoclose bold", () => {
-      expect(processStreamingMarkdown("**abc")).toBe("**abc**");
-      expect(processStreamingMarkdown("abc **abc")).toBe("abc **abc**");
-    });
-
-    suite("should autoclose, even if the last node isn't a text node", () => {
-      test("inline code", () => {
-        expect(
-          processStreamingMarkdown("* **Asynchronous Function `async`"),
-        ).toBe("* **Asynchronous Function `async`**");
-      });
-
-      test("opening ** is at the end of the text", () => {
-        expect(processStreamingMarkdown("abc **`def` jhk [lmn](opq)")).toBe(
-          "abc **`def` jhk [lmn](opq)**",
-        );
-      });
-
-      test("if there's a space after the **, it should NOT be autoclosed", () => {
-        expect(processStreamingMarkdown("abc ** `def` jhk [lmn](opq)")).toBe(
-          "abc \\*\\* `def` jhk [lmn](opq)",
-        );
-      });
-    });
-
-    test("should autoclose bold, even if the last node isn't a text node", () => {
-      expect(
-        processStreamingMarkdown("* **Asynchronous Function ( `async`"),
-      ).toBe("* **Asynchronous Function ( `async`**");
-    });
-
-    test("whitespace fakeouts should not be modified", () => {
-      expect(processStreamingMarkdown("** abc")).toBe("\\*\\* abc");
-    });
-
-    // TODO(drifkin): arguably this should just be removed entirely, but empty
-    // isn't so bad
-    test("should handle empty bolded items", () => {
-      expect(processStreamingMarkdown("**")).toBe("");
-    });
-  });
-
-  suite("partially closed bolded items", () => {
-    test("simple partial", () => {
-      expect(processStreamingMarkdown("**abc*")).toBe("**abc**");
-    });
-
-    test("partial with non-text node at end", () => {
-      expect(processStreamingMarkdown("**abc`def`*")).toBe("**abc`def`**");
-    });
-
-    test("partial with multiply nested ending nodes", () => {
-      expect(processStreamingMarkdown("**abc[abc](`def`)*")).toBe(
-        "**abc[abc](`def`)**",
-      );
-    });
-
-    test("normal emphasis should not be affected", () => {
-      expect(processStreamingMarkdown("*abc*")).toBe("*abc*");
-    });
-
-    test("normal emphasis with nested code should not be affected", () => {
-      expect(processStreamingMarkdown("*`abc`*")).toBe("*`abc`*");
-    });
-  });
-
-  test.skip("shouldn't autoclose immediately if there's a space before the closing *", () => {
-    expect(processStreamingMarkdown("**abc *")).toBe("**abc**");
-  });
-
-  // skipping for now because this requires partial link completion as well
-  suite.skip("nested blocks that each need autoclosing", () => {
-    test("emph nested in link nested in strong nested in list item", () => {
-      expect(processStreamingMarkdown("* **[abc **def")).toBe(
-        "* **[abc **def**]()**",
-      );
-    });
-
-    test("* **[ab *`def`", () => {
-      expect(processStreamingMarkdown("* **[ab *`def`")).toBe(
-        "* **[ab *`def`*]()**",
-      );
-    });
-  });
-});
-
-suite("numbered list items", () => {
-  test("should remove trailing numbers", () => {
-    expect(processStreamingMarkdown("1. First\n2")).toBe("1. First");
-  });
-
-  test("should remove trailing numbers with breaks before", () => {
-    expect(processStreamingMarkdown("1. First \n2")).toBe("1. First");
-  });
-
-  test("should remove trailing numbers that form a new paragraph", () => {
-    expect(processStreamingMarkdown("1. First\n\n2")).toBe("1. First");
-  });
-
-  test("but should leave list items separated by two newlines", () => {
-    expect(processStreamingMarkdown("1. First\n\n2. S")).toBe(
-      "1. First\n\n2. S",
-    );
-  });
-});
-
-// TODO(drifkin):slop tests ahead, some are decent, but need to manually go
-// through them as I implement
-/*
-describe("StreamingMarkdownContent - processStreamingMarkdown", () => {
-  describe("Ambiguous endings removal", () => {
-    it("should remove list markers at the end", () => {
-      expect(processStreamingMarkdown("Some text\n* ")).toBe("Some text");
-      expect(processStreamingMarkdown("Some text\n*")).toBe("Some text");
-      expect(processStreamingMarkdown("* Item 1\n- ")).toBe("* Item 1");
-      expect(processStreamingMarkdown("* Item 1\n-")).toBe("* Item 1");
-      expect(processStreamingMarkdown("Text\n+ ")).toBe("Text");
-      expect(processStreamingMarkdown("Text\n+")).toBe("Text");
-      expect(processStreamingMarkdown("1. First\n2. ")).toBe("1. First");
-    });
-
-    it("should remove heading markers at the end", () => {
-      expect(processStreamingMarkdown("Some text\n# ")).toBe("Some text");
-      expect(processStreamingMarkdown("Some text\n#")).toBe("Some text\n#"); // # without space is not removed
-      expect(processStreamingMarkdown("# Title\n## ")).toBe("# Title");
-      expect(processStreamingMarkdown("# Title\n##")).toBe("# Title\n##"); // ## without space is not removed
-    });
-
-    it("should remove ambiguous bold markers at the end", () => {
-      expect(processStreamingMarkdown("Text **")).toBe("Text ");
-      expect(processStreamingMarkdown("Some text\n**")).toBe("Some text");
-    });
-
-    it("should remove code block markers at the end", () => {
-      expect(processStreamingMarkdown("Text\n```")).toBe("Text");
-      expect(processStreamingMarkdown("```")).toBe("");
-    });
-
-    it("should remove single backtick at the end", () => {
-      expect(processStreamingMarkdown("Text `")).toBe("Text ");
-      expect(processStreamingMarkdown("`")).toBe("");
-    });
-
-    it("should remove single asterisk at the end", () => {
-      expect(processStreamingMarkdown("Text *")).toBe("Text ");
-      expect(processStreamingMarkdown("*")).toBe("");
-    });
-
-    it("should handle empty content", () => {
-      expect(processStreamingMarkdown("")).toBe("");
-    });
-
-    it("should handle single line removals correctly", () => {
-      expect(processStreamingMarkdown("* ")).toBe("");
-      expect(processStreamingMarkdown("# ")).toBe("");
-      expect(processStreamingMarkdown("**")).toBe("");
-      expect(processStreamingMarkdown("`")).toBe("");
-    });
-
-    it("shouldn't have this regexp capture group bug", () => {
-      expect(
-        processStreamingMarkdown("Here's a shopping list:\n*"),
-      ).not.toContain("0*");
-      expect(processStreamingMarkdown("Here's a shopping list:\n*")).toBe(
-        "Here's a shopping list:",
-      );
-    });
-  });
-
-  describe("List markers", () => {
-    it("should preserve complete list items", () => {
-      expect(processStreamingMarkdown("* Complete item")).toBe(
-        "* Complete item",
-      );
-      expect(processStreamingMarkdown("- Another item")).toBe("- Another item");
-      expect(processStreamingMarkdown("+ Plus item")).toBe("+ Plus item");
-      expect(processStreamingMarkdown("1. Numbered item")).toBe(
-        "1. Numbered item",
-      );
-    });
-
-    it("should handle indented list markers", () => {
-      expect(processStreamingMarkdown(" * ")).toBe(" ");
-      expect(processStreamingMarkdown(" - ")).toBe(" ");
-      expect(processStreamingMarkdown("\t+ ")).toBe("\t");
-    });
-  });
-
-  describe("Heading markers", () => {
-    it("should preserve complete headings", () => {
-      expect(processStreamingMarkdown("# Complete Heading")).toBe(
-        "# Complete Heading",
-      );
-      expect(processStreamingMarkdown("## Subheading")).toBe("## Subheading");
-      expect(processStreamingMarkdown("### H3 Title")).toBe("### H3 Title");
-    });
-
-    it("should not affect # in other contexts", () => {
-      expect(processStreamingMarkdown("C# programming")).toBe("C# programming");
-      expect(processStreamingMarkdown("Issue #123")).toBe("Issue #123");
-    });
-  });
-
-  describe("Bold text", () => {
-    it("should close incomplete bold text", () => {
-      expect(processStreamingMarkdown("This is **bold text")).toBe(
-        "This is **bold text**",
-      );
-      expect(processStreamingMarkdown("Start **bold and more")).toBe(
-        "Start **bold and more**",
-      );
-      expect(processStreamingMarkdown("**just bold")).toBe("**just bold**");
-    });
-
-    it("should not affect complete bold text", () => {
-      expect(processStreamingMarkdown("**complete bold**")).toBe(
-        "**complete bold**",
-      );
-      expect(processStreamingMarkdown("Text **bold** more")).toBe(
-        "Text **bold** more",
-      );
-    });
-
-    it("should handle nested bold correctly", () => {
-      expect(processStreamingMarkdown("**bold** and **another")).toBe(
-        "**bold** and **another**",
-      );
-    });
-  });
-
-  describe("Italic text", () => {
-    it("should close incomplete italic text", () => {
-      expect(processStreamingMarkdown("This is *italic text")).toBe(
-        "This is *italic text*",
-      );
-      expect(processStreamingMarkdown("Start *italic and more")).toBe(
-        "Start *italic and more*",
-      );
-    });
-
-    it("should differentiate between list markers and italic", () => {
-      expect(processStreamingMarkdown("* Item\n* ")).toBe("* Item");
-      expect(processStreamingMarkdown("Some *italic text")).toBe(
-        "Some *italic text*",
-      );
-      expect(processStreamingMarkdown("*just italic")).toBe("*just italic*");
-    });
-
-    it("should not affect complete italic text", () => {
-      expect(processStreamingMarkdown("*complete italic*")).toBe(
-        "*complete italic*",
-      );
-      expect(processStreamingMarkdown("Text *italic* more")).toBe(
-        "Text *italic* more",
-      );
-    });
-  });
-
-  describe("Code blocks", () => {
-    it("should close incomplete code blocks", () => {
-      expect(processStreamingMarkdown("```javascript\nconst x = 42;")).toBe(
-        "```javascript\nconst x = 42;\n```",
-      );
-      expect(processStreamingMarkdown("```\ncode here")).toBe(
-        "```\ncode here\n```",
-      );
-    });
-
-    it("should not affect complete code blocks", () => {
-      expect(processStreamingMarkdown("```\ncode\n```")).toBe("```\ncode\n```");
-      expect(processStreamingMarkdown("```js\nconst x = 1;\n```")).toBe(
-        "```js\nconst x = 1;\n```",
-      );
-    });
-
-    it("should handle nested code blocks correctly", () => {
-      expect(processStreamingMarkdown("```\ncode\n```\n```python")).toBe(
-        "```\ncode\n```\n```python\n```",
-      );
-    });
-
-    it("should not process markdown inside code blocks", () => {
-      expect(processStreamingMarkdown("```\n* not a list\n**not bold**")).toBe(
-        "```\n* not a list\n**not bold**\n```",
-      );
-    });
-  });
-
-  describe("Inline code", () => {
-    it("should close incomplete inline code", () => {
-      expect(processStreamingMarkdown("This is `inline code")).toBe(
-        "This is `inline code`",
-      );
-      expect(processStreamingMarkdown("Use `console.log")).toBe(
-        "Use `console.log`",
-      );
-    });
-
-    it("should not affect complete inline code", () => {
-      expect(processStreamingMarkdown("`complete code`")).toBe(
-        "`complete code`",
-      );
-      expect(processStreamingMarkdown("Use `code` here")).toBe(
-        "Use `code` here",
-      );
-    });
-
-    it("should handle multiple inline codes correctly", () => {
-      expect(processStreamingMarkdown("`code` and `more")).toBe(
-        "`code` and `more`",
-      );
-    });
-
-    it("should not confuse inline code with code blocks", () => {
-      expect(processStreamingMarkdown("```\nblock\n```\n`inline")).toBe(
-        "```\nblock\n```\n`inline`",
-      );
-    });
-  });
-
-  describe("Complex streaming scenarios", () => {
-    it("should handle progressive streaming of a heading", () => {
-      const steps = [
-        { input: "#", expected: "#" }, // # alone is not removed (needs space)
-        { input: "# ", expected: "" },
-        { input: "# H", expected: "# H" },
-        { input: "# Hello", expected: "# Hello" },
-      ];
-      steps.forEach(({ input, expected }) => {
-        expect(processStreamingMarkdown(input)).toBe(expected);
-      });
-    });
-
-    it("should handle progressive streaming of bold text", () => {
-      const steps = [
-        { input: "*", expected: "" },
-        { input: "**", expected: "" },
-        { input: "**b", expected: "**b**" },
-        { input: "**bold", expected: "**bold**" },
-        { input: "**bold**", expected: "**bold**" },
-      ];
-      steps.forEach(({ input, expected }) => {
-        expect(processStreamingMarkdown(input)).toBe(expected);
-      });
-    });
-
-    it("should handle multiline content with various patterns", () => {
-      const multiline = `# Title
-
-This is a paragraph with **bold text** and *italic text*.
-
-* Item 1
-* Item 2
-* `;
-
-      const expected = `# Title
-
-This is a paragraph with **bold text** and *italic text*.
-
-* Item 1
-* Item 2`;
-
-      expect(processStreamingMarkdown(multiline)).toBe(expected);
-    });
-
-    it("should only fix the last line", () => {
-      expect(processStreamingMarkdown("# Complete\n# Another\n# ")).toBe(
-        "# Complete\n# Another",
-      );
-      expect(processStreamingMarkdown("* Item 1\n* Item 2\n* ")).toBe(
-        "* Item 1\n* Item 2",
-      );
-    });
-
-    it("should handle mixed content correctly", () => {
-      const input = `# Header
-
-This has **bold** text and *italic* text.
-
-\`\`\`js
-const x = 42;
-\`\`\`
-
-Now some \`inline code\` and **unclosed bold`;
-
-      const expected = `# Header
-
-This has **bold** text and *italic* text.
-
-\`\`\`js
-const x = 42;
-\`\`\`
-
-Now some \`inline code\` and **unclosed bold**`;
-
-      expect(processStreamingMarkdown(input)).toBe(expected);
-    });
-  });
-
-  describe("Edge cases with escaping", () => {
-    it("should handle escaped asterisks (future enhancement)", () => {
-      // Note: Current implementation doesn't handle escaping
-      // This is a known limitation - escaped characters still trigger closing
-      expect(processStreamingMarkdown("Text \\*not italic")).toBe(
-        "Text \\*not italic*",
-      );
-    });
-
-    it("should handle escaped backticks (future enhancement)", () => {
-      // Note: Current implementation doesn't handle escaping
-      // This is a known limitation - escaped characters still trigger closing
-      expect(processStreamingMarkdown("Text \\`not code")).toBe(
-        "Text \\`not code`",
-      );
-    });
-  });
-
-  describe("Code block edge cases", () => {
-    it("should handle triple backticks in the middle of lines", () => {
-      expect(processStreamingMarkdown("Text ``` in middle")).toBe(
-        "Text ``` in middle\n```",
-      );
-      expect(processStreamingMarkdown("```\nText ``` in code\nmore")).toBe(
-        "```\nText ``` in code\nmore\n```",
-      );
-    });
-
-    it("should properly close code blocks with language specifiers", () => {
-      expect(processStreamingMarkdown("```typescript")).toBe(
-        "```typescript\n```",
-      );
-      expect(processStreamingMarkdown("```typescript\nconst x = 1")).toBe(
-        "```typescript\nconst x = 1\n```",
-      );
-    });
-
-    it("should remove a completely empty partial code block", () => {
-      expect(processStreamingMarkdown("```\n")).toBe("");
-    });
-  });
-});
-
-*/
@@ -1,66 +1,123 @@
|
|||||||
import React from "react";
|
import React from "react";
|
||||||
import Markdown from "react-markdown";
|
import { Streamdown, defaultRemarkPlugins } from "streamdown";
|
||||||
import remarkGfm from "remark-gfm";
|
|
||||||
import remarkMath from "remark-math";
|
|
||||||
import rehypeRaw from "rehype-raw";
|
|
||||||
import rehypeSanitize, { defaultSchema } from "rehype-sanitize";
|
|
||||||
import rehypePrismPlus from "rehype-prism-plus";
|
|
||||||
import rehypeKatex from "rehype-katex";
|
|
||||||
import remarkStreamingMarkdown, {
|
|
||||||
type LastNodeInfo,
|
|
||||||
} from "@/utils/remarkStreamingMarkdown";
|
|
||||||
import type { PluggableList } from "unified";
|
|
||||||
import remarkCitationParser from "@/utils/remarkCitationParser";
|
import remarkCitationParser from "@/utils/remarkCitationParser";
|
||||||
import CopyButton from "./CopyButton";
|
import CopyButton from "./CopyButton";
|
||||||
|
import type { BundledLanguage } from "shiki";
|
||||||
|
import { highlighter } from "@/lib/highlighter";
|
||||||
|
|
||||||
interface StreamingMarkdownContentProps {
|
interface StreamingMarkdownContentProps {
|
||||||
content: string;
|
content: string;
|
||||||
isStreaming?: boolean;
|
isStreaming?: boolean;
|
||||||
size?: "sm" | "md" | "lg";
|
size?: "sm" | "md" | "lg";
|
||||||
onLastNode?: (info: LastNodeInfo) => void;
|
|
||||||
browserToolResult?: any; // TODO: proper type
|
browserToolResult?: any; // TODO: proper type
|
||||||
}
|
}
|
||||||
|
|
||||||
const CodeBlock = React.memo(
|
// Helper to extract text from React nodes
|
||||||
({ children, className, ...props }: React.HTMLAttributes<HTMLPreElement>) => {
|
const extractText = (node: React.ReactNode): string => {
|
||||||
const extractText = React.useCallback((node: React.ReactNode): string => {
|
|
||||||
if (typeof node === "string") return node;
|
if (typeof node === "string") return node;
|
||||||
if (typeof node === "number") return String(node);
|
if (typeof node === "number") return String(node);
|
||||||
if (!node) return "";
|
if (!node) return "";
|
||||||
|
|
||||||
if (React.isValidElement(node)) {
|
if (React.isValidElement(node)) {
|
||||||
if (
|
const props = node.props as any;
|
||||||
node.props &&
|
if (props?.children) {
|
||||||
typeof node.props === "object" &&
|
return extractText(props.children as React.ReactNode);
|
||||||
"children" in node.props
|
|
||||||
) {
|
|
||||||
return extractText(node.props.children as React.ReactNode);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Array.isArray(node)) {
|
if (Array.isArray(node)) {
|
||||||
return node.map(extractText).join("");
|
return node.map(extractText).join("");
|
||||||
}
|
}
|
||||||
|
|
||||||
return "";
|
return "";
|
||||||
}, []);
|
};
|
||||||
|
|
||||||
const language = className?.replace(/language-/, "") || "";
|
const CodeBlock = React.memo(
|
||||||
|
({ children }: React.HTMLAttributes<HTMLPreElement>) => {
|
||||||
|
// Extract code and language from children
|
||||||
|
const codeElement = children as React.ReactElement<{
|
||||||
|
className?: string;
|
||||||
|
children: React.ReactNode;
|
||||||
|
}>;
|
||||||
|
const language =
|
||||||
|
codeElement.props.className?.replace(/language-/, "") || "";
|
||||||
|
const codeText = extractText(codeElement.props.children);
|
||||||
|
|
||||||
|
// Synchronously highlight code using the pre-loaded highlighter
|
||||||
|
const tokens = React.useMemo(() => {
|
||||||
|
if (!highlighter) return null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
return {
|
||||||
|
light: highlighter.codeToTokensBase(codeText, {
|
||||||
|
lang: language as BundledLanguage,
|
||||||
|
theme: "one-light" as any,
|
||||||
|
}),
|
||||||
|
dark: highlighter.codeToTokensBase(codeText, {
|
||||||
|
lang: language as BundledLanguage,
|
||||||
|
theme: "one-dark" as any,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Failed to highlight code:", error);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}, [codeText, language]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="relative bg-neutral-100 dark:bg-neutral-800 rounded-2xl overflow-hidden my-6">
|
<div className="relative bg-neutral-100 dark:bg-neutral-800 rounded-2xl overflow-hidden my-6">
|
||||||
<div className="flex justify-between select-none">
|
<div className="flex select-none">
|
||||||
|
{language && (
|
||||||
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
|
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
|
||||||
{language}
|
{language}
|
||||||
</div>
|
</div>
|
||||||
|
)}
|
||||||
<CopyButton
|
<CopyButton
|
||||||
content={extractText(children)}
|
content={codeText}
|
||||||
showLabels={true}
|
showLabels={true}
|
||||||
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800"
|
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 ml-auto"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<pre className={className} {...props}>
|
{/* Light mode */}
|
||||||
{children}
|
<pre className="dark:hidden m-0 bg-neutral-100 text-sm overflow-x-auto p-4">
|
||||||
|
<code className="font-mono text-sm">
|
||||||
|
{tokens?.light
|
||||||
|
? tokens.light.map((line: any, i: number) => (
|
||||||
|
<React.Fragment key={i}>
|
||||||
|
{line.map((token: any, j: number) => (
|
||||||
|
<span
|
||||||
|
key={j}
|
||||||
|
style={{
|
||||||
|
color: token.color,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{token.content}
|
||||||
|
</span>
|
||||||
|
))}
|
||||||
|
{i < tokens.light.length - 1 && "\n"}
|
||||||
|
</React.Fragment>
|
||||||
|
))
|
||||||
|
: codeText}
|
||||||
|
</code>
|
||||||
|
</pre>
|
||||||
|
{/* Dark mode */}
|
||||||
|
<pre className="hidden dark:block m-0 bg-neutral-800 text-sm overflow-x-auto p-4">
|
||||||
|
<code className="font-mono text-sm">
|
||||||
|
{tokens?.dark
|
||||||
|
? tokens.dark.map((line: any, i: number) => (
|
||||||
|
<React.Fragment key={i}>
|
||||||
|
{line.map((token: any, j: number) => (
|
||||||
|
<span
|
||||||
|
key={j}
|
||||||
|
style={{
|
||||||
|
color: token.color,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{token.content}
|
||||||
|
</span>
|
||||||
|
))}
|
||||||
|
{i < tokens.dark.length - 1 && "\n"}
|
||||||
|
</React.Fragment>
|
||||||
|
))
|
||||||
|
: codeText}
|
||||||
|
</code>
|
||||||
</pre>
|
</pre>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
@@ -68,60 +125,14 @@ const CodeBlock = React.memo(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
|
const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
|
||||||
React.memo(
|
React.memo(({ content, isStreaming = false, size, browserToolResult }) => {
|
||||||
({ content, isStreaming = false, size, onLastNode, browserToolResult }) => {
|
// Build the remark plugins array - keep default GFM and Math, add citations
|
||||||
// Build the remark plugins array
|
|
||||||
const remarkPlugins = React.useMemo(() => {
|
const remarkPlugins = React.useMemo(() => {
|
||||||
const plugins: PluggableList = [
|
return [
|
||||||
remarkGfm,
|
defaultRemarkPlugins.gfm,
|
||||||
[remarkMath, { singleDollarTextMath: false }],
|
defaultRemarkPlugins.math,
|
||||||
remarkCitationParser,
|
remarkCitationParser,
|
||||||
];
|
];
|
||||||
|
|
||||||
// Add streaming plugin when in streaming mode
|
|
||||||
if (isStreaming) {
|
|
||||||
plugins.push([remarkStreamingMarkdown, { debug: true, onLastNode }]);
|
|
||||||
}
|
|
||||||
|
|
||||||
return plugins;
|
|
||||||
}, [isStreaming, onLastNode]);
|
|
||||||
|
|
||||||
// Create a custom sanitization schema that allows math elements
|
|
||||||
const sanitizeSchema = React.useMemo(() => {
|
|
||||||
return {
|
|
||||||
...defaultSchema,
|
|
||||||
attributes: {
|
|
||||||
...defaultSchema.attributes,
|
|
||||||
span: [
|
|
||||||
...(defaultSchema.attributes?.span || []),
|
|
||||||
["className", /^katex/],
|
|
||||||
],
|
|
||||||
div: [
|
|
||||||
...(defaultSchema.attributes?.div || []),
|
|
||||||
["className", /^katex/],
|
|
||||||
],
|
|
||||||
"ol-citation": ["cursor", "start", "end"],
|
|
||||||
},
|
|
||||||
tagNames: [
|
|
||||||
...(defaultSchema.tagNames || []),
|
|
||||||
"math",
|
|
||||||
"mrow",
|
|
||||||
"mi",
|
|
||||||
"mo",
|
|
||||||
"mn",
|
|
||||||
"msup",
|
|
||||||
"msub",
|
|
||||||
"mfrac",
|
|
||||||
"mover",
|
|
||||||
"munder",
|
|
||||||
"msqrt",
|
|
||||||
"mroot",
|
|
||||||
"merror",
|
|
||||||
"mspace",
|
|
||||||
"mpadded",
|
|
||||||
"ol-citation",
|
|
||||||
],
|
|
||||||
};
|
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -144,6 +155,26 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
|
|||||||
prose-pre:my-0
|
prose-pre:my-0
|
||||||
prose-pre:max-w-full
|
prose-pre:max-w-full
|
||||||
prose-pre:pt-1
|
prose-pre:pt-1
|
||||||
|
[&_table]:border-collapse
|
||||||
|
[&_table]:w-full
|
||||||
|
[&_table]:border
|
||||||
|
[&_table]:border-neutral-200
|
||||||
|
[&_table]:rounded-lg
|
||||||
|
[&_table]:overflow-hidden
|
||||||
|
[&_th]:px-3
|
||||||
|
[&_th]:py-2
|
||||||
|
[&_th]:text-left
|
||||||
|
[&_th]:font-semibold
|
||||||
|
[&_th]:border-b
|
||||||
|
[&_th]:border-r
|
||||||
|
[&_th]:border-neutral-200
|
||||||
|
[&_th:last-child]:border-r-0
|
||||||
|
[&_td]:px-3
|
||||||
|
[&_td]:py-2
|
||||||
|
[&_td]:border-r
|
||||||
|
[&_td]:border-neutral-200
|
||||||
|
[&_td:last-child]:border-r-0
|
||||||
|
[&_tbody_tr:not(:last-child)_td]:border-b
|
||||||
[&_code:not(pre_code)]:text-neutral-700
|
[&_code:not(pre_code)]:text-neutral-700
|
||||||
[&_code:not(pre_code)]:bg-neutral-100
|
[&_code:not(pre_code)]:bg-neutral-100
|
||||||
[&_code:not(pre_code)]:font-normal
|
[&_code:not(pre_code)]:font-normal
|
||||||
@@ -160,6 +191,10 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
|
|||||||
dark:prose-strong:text-neutral-200
|
dark:prose-strong:text-neutral-200
|
||||||
dark:prose-pre:text-neutral-200
|
dark:prose-pre:text-neutral-200
|
||||||
dark:prose:pre:text-neutral-200
|
dark:prose:pre:text-neutral-200
|
||||||
|
dark:[&_table]:border-neutral-700
|
||||||
|
dark:[&_thead]:bg-neutral-800
|
||||||
|
dark:[&_th]:border-neutral-700
|
||||||
|
dark:[&_td]:border-neutral-700
|
||||||
dark:[&_code:not(pre_code)]:text-neutral-200
|
dark:[&_code:not(pre_code)]:text-neutral-200
|
||||||
dark:[&_code:not(pre_code)]:bg-neutral-800
|
dark:[&_code:not(pre_code)]:bg-neutral-800
|
||||||
dark:[&_code:not(pre_code)]:font-normal
|
dark:[&_code:not(pre_code)]:font-normal
|
||||||
@@ -172,23 +207,11 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
|
|||||||
content={content}
|
content={content}
|
||||||
isStreaming={isStreaming}
|
isStreaming={isStreaming}
|
||||||
>
|
>
|
||||||
<Markdown
|
<Streamdown
|
||||||
|
parseIncompleteMarkdown={isStreaming}
|
||||||
|
isAnimating={isStreaming}
|
||||||
remarkPlugins={remarkPlugins}
|
remarkPlugins={remarkPlugins}
|
||||||
rehypePlugins={
|
controls={false}
|
||||||
[
|
|
||||||
[rehypeRaw, { allowDangerousHtml: true }],
|
|
||||||
[rehypeSanitize, sanitizeSchema],
|
|
||||||
[rehypePrismPlus, { ignoreMissing: true }],
|
|
||||||
[
|
|
||||||
rehypeKatex,
|
|
||||||
{
|
|
||||||
errorColor: "#000000", // Black instead of red for errors
|
|
||||||
strict: false, // Be more lenient with parsing
|
|
||||||
throwOnError: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
] as PluggableList
|
|
||||||
}
|
|
||||||
components={{
|
components={{
|
||||||
pre: CodeBlock,
|
pre: CodeBlock,
|
||||||
table: ({
|
table: ({
|
@@ -196,38 +219,35 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
                 ...props
               }: React.HTMLAttributes<HTMLTableElement>) => (
                 <div className="overflow-x-auto max-w-full">
-                  <table {...props}>{children}</table>
+                  <table
+                    {...props}
+                    className="border-collapse w-full border border-neutral-200 dark:border-neutral-700 rounded-lg overflow-hidden"
+                  >
+                    {children}
+                  </table>
                 </div>
               ),
-              // @ts-expect-error: custom type
+              // @ts-expect-error: custom citation type
               "ol-citation": ({
                 cursor,
-                // start,
-                // end,
               }: {
                 cursor: number;
                 start: number;
                 end: number;
               }) => {
-                // Check if we have a page_stack and if the cursor is valid
                 const pageStack = browserToolResult?.page_stack;
                 const hasValidPage = pageStack && cursor < pageStack.length;
                 const pageUrl = hasValidPage ? pageStack[cursor] : null;

-                // Extract a readable title from the URL if possible
                 const getPageTitle = (url: string) => {
                   if (url.startsWith("search_results_")) {
-                    const searchTerm = url.substring(
-                      "search_results_".length,
-                    );
+                    const searchTerm = url.substring("search_results_".length);
                     return `Search: ${searchTerm}`;
                   }
-                  // For regular URLs, try to extract domain or use full URL
                   try {
                     const urlObj = new URL(url);
                     return urlObj.hostname;
                   } catch {
                     // If not a valid URL, return as is
                     return url;
                   }
                 };
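For reference, `getPageTitle` above maps the two shapes a page-stack entry can take (example values are illustrative, not from the diff):

    getPageTitle("search_results_vulkan support");   // -> "Search: vulkan support"
    getPageTitle("https://github.com/ollama/ollama"); // -> "github.com"
    getPageTitle("not a url");                        // -> "not a url" (new URL() throws, falls through)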
@@ -238,7 +258,6 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
                   </span>
                 );

-                // If we have a valid page URL, wrap in a link
                 if (pageUrl && pageUrl.startsWith("http")) {
                   return (
                     <a
@@ -253,18 +272,16 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
                   );
                 }

-                // Otherwise, just return the citation without a link
                 return citationElement;
               },
             }}
           >
             {content}
-          </Markdown>
+          </Streamdown>
         </StreamingMarkdownErrorBoundary>
       </div>
     );
-  },
-);
+  });

 interface StreamingMarkdownErrorBoundaryProps {
   content: string;
@@ -73,8 +73,9 @@ export default function Thinking({
   // Calculate max height for smooth animations
   const getMaxHeight = () => {
     if (isCollapsed) {
-      return finishedThinking ? "0px" : "12rem"; // 8rem = 128px (same as max-h-32)
+      return finishedThinking ? "0px" : "12rem";
     }
+    // When expanded, use the content height or grow naturally
     return contentHeight ? `${contentHeight}px` : "none";
   };

@@ -131,10 +132,11 @@ export default function Thinking({
         </div>
         <div
           ref={wrapperRef}
-          className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md overflow-hidden
-            transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2`}
+          className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md
+            transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2
+            ${isCollapsed ? "overflow-hidden" : "overflow-y-auto"}`}
           style={{
-            maxHeight: getMaxHeight(),
+            maxHeight: isCollapsed ? getMaxHeight() : undefined,
             opacity: isCollapsed && finishedThinking ? 0 : 1,
           }}
         >
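The two Thinking hunks above work together: while collapsed, the wrapper clips (`overflow-hidden`) and animates `max-height` toward `getMaxHeight()`; once expanded, `maxHeight` is left `undefined` and the box becomes `overflow-y-auto`, so a long thinking trace scrolls instead of being cut off at the measured height. A condensed sketch of the resulting style logic (`isCollapsed`, `finishedThinking`, and `getMaxHeight` come from the component):

    // Collapsed: animate max-height (0px once finished, 12rem while streaming).
    // Expanded: no max-height cap at all; the content area scrolls on its own.
    const style: React.CSSProperties = {
      maxHeight: isCollapsed ? getMaxHeight() : undefined,
      opacity: isCollapsed && finishedThinking ? 0 : 1,
    };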
@@ -16,793 +16,6 @@
     --text-color: #ffffff;
   }
 }
-@media (prefers-color-scheme: light) {
-  .prose {
-    /**
-     * One Light theme for prism.js
-     * Based on Atom's One Light theme: https://github.com/atom/atom/tree/master/packages/one-light-syntax
-     */
-
-    /**
-     * One Light colours (accurate as of commit eb064bf on 19 Feb 2021)
-     * From colors.less
-     * --mono-1: hsl(230, 8%, 24%);
-     * --mono-2: hsl(230, 6%, 44%);
-     * --mono-3: hsl(230, 4%, 64%)
-     * --hue-1: hsl(198, 99%, 37%);
-     * --hue-2: hsl(221, 87%, 60%);
-     * --hue-3: hsl(301, 63%, 40%);
-     * --hue-4: hsl(119, 34%, 47%);
-     * --hue-5: hsl(5, 74%, 59%);
-     * --hue-5-2: hsl(344, 84%, 43%);
-     * --hue-6: hsl(35, 99%, 36%);
-     * --hue-6-2: hsl(35, 99%, 40%);
-     * --syntax-fg: hsl(230, 8%, 24%);
-     * --syntax-bg: hsl(230, 1%, 98%);
-     * --syntax-gutter: hsl(230, 1%, 62%);
-     * --syntax-guide: hsla(230, 8%, 24%, 0.2);
-     * --syntax-accent: hsl(230, 100%, 66%);
-     * From syntax-variables.less
-     * --syntax-selection-color: hsl(230, 1%, 90%);
-     * --syntax-gutter-background-color-selected: hsl(230, 1%, 90%);
-     * --syntax-cursor-line: hsla(230, 8%, 24%, 0.05);
-     */
-
-    .token.comment,
-    .token.prolog,
-    .token.cdata {
-      color: hsl(230, 4%, 64%);
-    }
-
-    .token.doctype,
-    .token.punctuation,
-    .token.entity {
-      color: hsl(230, 8%, 24%);
-    }
-
-    .token.attr-name,
-    .token.class-name,
-    .token.boolean,
-    .token.constant,
-    .token.number,
-    .token.atrule {
-      color: hsl(35, 99%, 36%);
-    }
-
-    .token.keyword {
-      color: hsl(301, 63%, 40%);
-    }
-
-    .token.property,
-    .token.tag,
-    .token.symbol,
-    .token.deleted,
-    .token.important {
-      color: hsl(5, 74%, 59%);
-    }
-
-    .token.selector,
-    .token.string,
-    .token.char,
-    .token.builtin,
-    .token.inserted,
-    .token.regex,
-    .token.attr-value,
-    .token.attr-value > .token.punctuation {
-      color: hsl(119, 34%, 47%);
-    }
-
-    .token.variable,
-    .token.operator,
-    .token.function {
-      color: hsl(221, 87%, 60%);
-    }
-
-    .token.url {
-      color: hsl(198, 99%, 37%);
-    }
-
-    /* HTML overrides */
-    .token.attr-value > .token.punctuation.attr-equals,
-    .token.special-attr > .token.attr-value > .token.value.css {
-      color: hsl(230, 8%, 24%);
-    }
-
-    /* CSS overrides */
-    .language-css .token.selector {
-      color: hsl(5, 74%, 59%);
-    }
-
-    .language-css .token.property {
-      color: hsl(230, 8%, 24%);
-    }
-
-    .language-css .token.function,
-    .language-css .token.url > .token.function {
-      color: hsl(198, 99%, 37%);
-    }
-
-    .language-css .token.url > .token.string.url {
-      color: hsl(119, 34%, 47%);
-    }
-
-    .language-css .token.important,
-    .language-css .token.atrule .token.rule {
-      color: hsl(301, 63%, 40%);
-    }
-
-    /* JS overrides */
-    .language-javascript .token.operator {
-      color: hsl(301, 63%, 40%);
-    }
-
-    .language-javascript
-      .token.template-string
-      > .token.interpolation
-      > .token.interpolation-punctuation.punctuation {
-      color: hsl(344, 84%, 43%);
-    }
-
-    /* JSON overrides */
-    .language-json .token.operator {
-      color: hsl(230, 8%, 24%);
-    }
-
-    .language-json .token.null.keyword {
-      color: hsl(35, 99%, 36%);
-    }
-
-    /* MD overrides */
-    .language-markdown .token.url,
-    .language-markdown .token.url > .token.operator,
-    .language-markdown .token.url-reference.url > .token.string {
-      color: hsl(230, 8%, 24%);
-    }
-
-    .language-markdown .token.url > .token.content {
-      color: hsl(221, 87%, 60%);
-    }
-
-    .language-markdown .token.url > .token.url,
-    .language-markdown .token.url-reference.url {
-      color: hsl(198, 99%, 37%);
-    }
-
-    .language-markdown .token.blockquote.punctuation,
-    .language-markdown .token.hr.punctuation {
-      color: hsl(230, 4%, 64%);
-      font-style: italic;
-    }
-
-    .language-markdown .token.code-snippet {
-      color: hsl(119, 34%, 47%);
-    }
-
-    .language-markdown .token.bold .token.content {
-      color: hsl(35, 99%, 36%);
-    }
-
-    .language-markdown .token.italic .token.content {
-      color: hsl(301, 63%, 40%);
-    }
-
-    .language-markdown .token.strike .token.content,
-    .language-markdown .token.strike .token.punctuation,
-    .language-markdown .token.list.punctuation,
-    .language-markdown .token.title.important > .token.punctuation {
-      color: hsl(5, 74%, 59%);
-    }
-
-    /* General */
-    .token.bold {
-      font-weight: bold;
-    }
-
-    .token.comment,
-    .token.italic {
-      font-style: italic;
-    }
-
-    .token.entity {
-      cursor: help;
-    }
-
-    .token.namespace {
-      opacity: 0.8;
-    }
-
-    /* Plugin overrides */
-    /* Selectors should have higher specificity than those in the plugins' default stylesheets */
-
-    /* Show Invisibles plugin overrides */
-    .token.token.tab:not(:empty):before,
-    .token.token.cr:before,
-    .token.token.lf:before,
-    .token.token.space:before {
-      color: hsla(230, 8%, 24%, 0.2);
-    }
-
-    /* Toolbar plugin overrides */
-    /* Space out all buttons and move them away from the right edge of the code block */
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item {
-      margin-right: 0.4em;
-    }
-
-    /* Styling the buttons */
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > button,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > a,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > span {
-      background: hsl(230, 1%, 90%);
-      color: hsl(230, 6%, 44%);
-      padding: 0.1em 0.4em;
-      border-radius: 0.3em;
-    }
-
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:hover,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:focus,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:hover,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:focus,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:hover,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:focus {
-      background: hsl(230, 1%, 78%); /* custom: darken(--syntax-bg, 20%) */
-      color: hsl(230, 8%, 24%);
-    }
-
-    /* Line Highlight plugin overrides */
-    /* The highlighted line itself */
-    .line-highlight.line-highlight {
-      background: hsla(230, 8%, 24%, 0.05);
-    }
-
-    /* Default line numbers in Line Highlight plugin */
-    .line-highlight.line-highlight:before,
-    .line-highlight.line-highlight[data-end]:after {
-      background: hsl(230, 1%, 90%);
-      color: hsl(230, 8%, 24%);
-      padding: 0.1em 0.6em;
-      border-radius: 0.3em;
-      box-shadow: 0 2px 0 0 rgba(0, 0, 0, 0.2); /* same as Toolbar plugin default */
-    }
-
-    /* Hovering over a linkable line number (in the gutter area) */
-    /* Requires Line Numbers plugin as well */
-    pre[id].linkable-line-numbers.linkable-line-numbers
-      span.line-numbers-rows
-      > span:hover:before {
-      background-color: hsla(230, 8%, 24%, 0.05);
-    }
-
-    /* Line Numbers and Command Line plugins overrides */
-    /* Line separating gutter from coding area */
-    .line-numbers.line-numbers .line-numbers-rows,
-    .command-line .command-line-prompt {
-      border-right-color: hsla(230, 8%, 24%, 0.2);
-    }
-
-    /* Stuff in the gutter */
-    .line-numbers .line-numbers-rows > span:before,
-    .command-line .command-line-prompt > span:before {
-      color: hsl(230, 1%, 62%);
-    }
-
-    /* Match Braces plugin overrides */
-    /* Note: Outline colour is inherited from the braces */
-    .rainbow-braces .token.token.punctuation.brace-level-1,
-    .rainbow-braces .token.token.punctuation.brace-level-5,
-    .rainbow-braces .token.token.punctuation.brace-level-9 {
-      color: hsl(5, 74%, 59%);
-    }
-
-    .rainbow-braces .token.token.punctuation.brace-level-2,
-    .rainbow-braces .token.token.punctuation.brace-level-6,
-    .rainbow-braces .token.token.punctuation.brace-level-10 {
-      color: hsl(119, 34%, 47%);
-    }
-
-    .rainbow-braces .token.token.punctuation.brace-level-3,
-    .rainbow-braces .token.token.punctuation.brace-level-7,
-    .rainbow-braces .token.token.punctuation.brace-level-11 {
-      color: hsl(221, 87%, 60%);
-    }
-
-    .rainbow-braces .token.token.punctuation.brace-level-4,
-    .rainbow-braces .token.token.punctuation.brace-level-8,
-    .rainbow-braces .token.token.punctuation.brace-level-12 {
-      color: hsl(301, 63%, 40%);
-    }
-
-    /* Diff Highlight plugin overrides */
-    /* Taken from https://github.com/atom/github/blob/master/styles/variables.less */
-    pre.diff-highlight > code .token.token.deleted:not(.prefix),
-    pre > code.diff-highlight .token.token.deleted:not(.prefix) {
-      background-color: hsla(353, 100%, 66%, 0.15);
-    }
-
-    pre.diff-highlight > code .token.token.deleted:not(.prefix)::-moz-selection,
-    pre.diff-highlight
-      > code
-      .token.token.deleted:not(.prefix)
-      *::-moz-selection,
-    pre > code.diff-highlight .token.token.deleted:not(.prefix)::-moz-selection,
-    pre
-      > code.diff-highlight
-      .token.token.deleted:not(.prefix)
-      *::-moz-selection {
-      background-color: hsla(353, 95%, 66%, 0.25);
-    }
-
-    pre.diff-highlight > code .token.token.deleted:not(.prefix)::selection,
-    pre.diff-highlight > code .token.token.deleted:not(.prefix) *::selection,
-    pre > code.diff-highlight .token.token.deleted:not(.prefix)::selection,
-    pre > code.diff-highlight .token.token.deleted:not(.prefix) *::selection {
-      background-color: hsla(353, 95%, 66%, 0.25);
-    }
-
-    pre.diff-highlight > code .token.token.inserted:not(.prefix),
-    pre > code.diff-highlight .token.token.inserted:not(.prefix) {
-      background-color: hsla(137, 100%, 55%, 0.15);
-    }
-
-    pre.diff-highlight
-      > code
-      .token.token.inserted:not(.prefix)::-moz-selection,
-    pre.diff-highlight
-      > code
-      .token.token.inserted:not(.prefix)
-      *::-moz-selection,
-    pre
-      > code.diff-highlight
-      .token.token.inserted:not(.prefix)::-moz-selection,
-    pre
-      > code.diff-highlight
-      .token.token.inserted:not(.prefix)
-      *::-moz-selection {
-      background-color: hsla(135, 73%, 55%, 0.25);
-    }
-
-    pre.diff-highlight > code .token.token.inserted:not(.prefix)::selection,
-    pre.diff-highlight > code .token.token.inserted:not(.prefix) *::selection,
-    pre > code.diff-highlight .token.token.inserted:not(.prefix)::selection,
-    pre > code.diff-highlight .token.token.inserted:not(.prefix) *::selection {
-      background-color: hsla(135, 73%, 55%, 0.25);
-    }
-
-    /* Previewers plugin overrides */
-    /* Based on https://github.com/atom-community/atom-ide-datatip/blob/master/styles/atom-ide-datatips.less and https://github.com/atom/atom/blob/master/packages/one-light-ui */
-    /* Border around popup */
-    .prism-previewer.prism-previewer:before,
-    .prism-previewer-gradient.prism-previewer-gradient div {
-      border-color: hsl(0, 0, 95%);
-    }
-
-    /* Angle and time should remain as circles and are hence not included */
-    .prism-previewer-color.prism-previewer-color:before,
-    .prism-previewer-gradient.prism-previewer-gradient div,
-    .prism-previewer-easing.prism-previewer-easing:before {
-      border-radius: 0.3em;
-    }
-
-    /* Triangles pointing to the code */
-    .prism-previewer.prism-previewer:after {
-      border-top-color: hsl(0, 0, 95%);
-    }
-
-    .prism-previewer-flipped.prism-previewer-flipped.after {
-      border-bottom-color: hsl(0, 0, 95%);
-    }
-
-    /* Background colour within the popup */
-    .prism-previewer-angle.prism-previewer-angle:before,
-    .prism-previewer-time.prism-previewer-time:before,
-    .prism-previewer-easing.prism-previewer-easing {
-      background: hsl(0, 0%, 100%);
-    }
-
-    /* For angle, this is the positive area (eg. 90deg will display one quadrant in this colour) */
-    /* For time, this is the alternate colour */
-    .prism-previewer-angle.prism-previewer-angle circle,
-    .prism-previewer-time.prism-previewer-time circle {
-      stroke: hsl(230, 8%, 24%);
-      stroke-opacity: 1;
-    }
-
-    /* Stroke colours of the handle, direction point, and vector itself */
-    .prism-previewer-easing.prism-previewer-easing circle,
-    .prism-previewer-easing.prism-previewer-easing path,
-    .prism-previewer-easing.prism-previewer-easing line {
-      stroke: hsl(230, 8%, 24%);
-    }
-
-    /* Fill colour of the handle */
-    .prism-previewer-easing.prism-previewer-easing circle {
-      fill: transparent;
-    }
-  }
-}
-
-@media (prefers-color-scheme: dark) {
-  .prose {
-    .token.comment,
-    .token.prolog,
-    .token.cdata {
-      color: hsl(220, 10%, 40%);
-    }
-
-    .token.doctype,
-    .token.punctuation,
-    .token.entity {
-      color: hsl(220, 14%, 71%);
-    }
-
-    .token.attr-name,
-    .token.class-name,
-    .token.boolean,
-    .token.constant,
-    .token.number,
-    .token.atrule {
-      color: hsl(29, 54%, 61%);
-    }
-
-    .token.keyword {
-      color: hsl(286, 60%, 67%);
-    }
-
-    .token.property,
-    .token.tag,
-    .token.symbol,
-    .token.deleted,
-    .token.important {
-      color: hsl(355, 65%, 65%);
-    }
-
-    .token.selector,
-    .token.string,
-    .token.char,
-    .token.builtin,
-    .token.inserted,
-    .token.regex,
-    .token.attr-value,
-    .token.attr-value > .token.punctuation {
-      color: hsl(95, 38%, 62%);
-    }
-
-    .token.variable,
-    .token.operator,
-    .token.function {
-      color: hsl(207, 82%, 66%);
-    }
-
-    .token.url {
-      color: hsl(187, 47%, 55%);
-    }
-
-    /* HTML overrides */
-    .token.attr-value > .token.punctuation.attr-equals,
-    .token.special-attr > .token.attr-value > .token.value.css {
-      color: hsl(220, 14%, 71%);
-    }
-
-    /* CSS overrides */
-    .language-css .token.selector {
-      color: hsl(355, 65%, 65%);
-    }
-
-    .language-css .token.property {
-      color: hsl(220, 14%, 71%);
-    }
-
-    .language-css .token.function,
-    .language-css .token.url > .token.function {
-      color: hsl(187, 47%, 55%);
-    }
-
-    .language-css .token.url > .token.string.url {
-      color: hsl(95, 38%, 62%);
-    }
-
-    .language-css .token.important,
-    .language-css .token.atrule .token.rule {
-      color: hsl(286, 60%, 67%);
-    }
-
-    /* JS overrides */
-    .language-javascript .token.operator {
-      color: hsl(286, 60%, 67%);
-    }
-
-    .language-javascript
-      .token.template-string
-      > .token.interpolation
-      > .token.interpolation-punctuation.punctuation {
-      color: hsl(5, 48%, 51%);
-    }
-
-    /* JSON overrides */
-    .language-json .token.operator {
-      color: hsl(220, 14%, 71%);
-    }
-
-    .language-json .token.null.keyword {
-      color: hsl(29, 54%, 61%);
-    }
-
-    /* MD overrides */
-    .language-markdown .token.url,
-    .language-markdown .token.url > .token.operator,
-    .language-markdown .token.url-reference.url > .token.string {
-      color: hsl(220, 14%, 71%);
-    }
-
-    .language-markdown .token.url > .token.content {
-      color: hsl(207, 82%, 66%);
-    }
-
-    .language-markdown .token.url > .token.url,
-    .language-markdown .token.url-reference.url {
-      color: hsl(187, 47%, 55%);
-    }
-
-    .language-markdown .token.blockquote.punctuation,
-    .language-markdown .token.hr.punctuation {
-      color: hsl(220, 10%, 40%);
-      font-style: italic;
-    }
-
-    .language-markdown .token.code-snippet {
-      color: hsl(95, 38%, 62%);
-    }
-
-    .language-markdown .token.bold .token.content {
-      color: hsl(29, 54%, 61%);
-    }
-
-    .language-markdown .token.italic .token.content {
-      color: hsl(286, 60%, 67%);
-    }
-
-    .language-markdown .token.strike .token.content,
-    .language-markdown .token.strike .token.punctuation,
-    .language-markdown .token.list.punctuation,
-    .language-markdown .token.title.important > .token.punctuation {
-      color: hsl(355, 65%, 65%);
-    }
-
-    /* General */
-    .token.bold {
-      font-weight: bold;
-    }
-
-    .token.comment,
-    .token.italic {
-      font-style: italic;
-    }
-
-    .token.entity {
-      cursor: help;
-    }
-
-    .token.namespace {
-      opacity: 0.8;
-    }
-
-    /* Plugin overrides */
-    /* Selectors should have higher specificity than those in the plugins' default stylesheets */
-
-    /* Show Invisibles plugin overrides */
-    .token.token.tab:not(:empty):before,
-    .token.token.cr:before,
-    .token.token.lf:before,
-    .token.token.space:before {
-      color: hsla(220, 14%, 71%, 0.15);
-      text-shadow: none;
-    }
-
-    /* Toolbar plugin overrides */
-    /* Space out all buttons and move them away from the right edge of the code block */
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item {
-      margin-right: 0.4em;
-    }
-
-    /* Styling the buttons */
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > button,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > a,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > span {
-      background: hsl(220, 13%, 26%);
-      color: hsl(220, 9%, 55%);
-      padding: 0.1em 0.4em;
-      border-radius: 0.3em;
-    }
-
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:hover,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:focus,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:hover,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:focus,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:hover,
-    div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:focus {
-      background: hsl(220, 13%, 28%);
-      color: hsl(220, 14%, 71%);
-    }
-
-    /* Line Highlight plugin overrides */
-    /* The highlighted line itself */
-    .line-highlight.line-highlight {
-      background: hsla(220, 100%, 80%, 0.04);
-    }
-
-    /* Default line numbers in Line Highlight plugin */
-    .line-highlight.line-highlight:before,
-    .line-highlight.line-highlight[data-end]:after {
-      background: hsl(220, 13%, 26%);
-      color: hsl(220, 14%, 71%);
-      padding: 0.1em 0.6em;
-      border-radius: 0.3em;
-      box-shadow: 0 2px 0 0 rgba(0, 0, 0, 0.2); /* same as Toolbar plugin default */
-    }
-
-    /* Hovering over a linkable line number (in the gutter area) */
-    /* Requires Line Numbers plugin as well */
-    pre[id].linkable-line-numbers.linkable-line-numbers
-      span.line-numbers-rows
-      > span:hover:before {
-      background-color: hsla(220, 100%, 80%, 0.04);
-    }
-
-    /* Line Numbers and Command Line plugins overrides */
-    /* Line separating gutter from coding area */
-    .line-numbers.line-numbers .line-numbers-rows,
-    .command-line .command-line-prompt {
-      border-right-color: hsla(220, 14%, 71%, 0.15);
-    }
-
-    /* Stuff in the gutter */
-    .line-numbers .line-numbers-rows > span:before,
-    .command-line .command-line-prompt > span:before {
-      color: hsl(220, 14%, 45%);
-    }
-
-    /* Match Braces plugin overrides */
-    /* Note: Outline colour is inherited from the braces */
-    .rainbow-braces .token.token.punctuation.brace-level-1,
-    .rainbow-braces .token.token.punctuation.brace-level-5,
-    .rainbow-braces .token.token.punctuation.brace-level-9 {
-      color: hsl(355, 65%, 65%);
-    }
-
-    .rainbow-braces .token.token.punctuation.brace-level-2,
-    .rainbow-braces .token.token.punctuation.brace-level-6,
-    .rainbow-braces .token.token.punctuation.brace-level-10 {
-      color: hsl(95, 38%, 62%);
-    }
-
-    .rainbow-braces .token.token.punctuation.brace-level-3,
-    .rainbow-braces .token.token.punctuation.brace-level-7,
-    .rainbow-braces .token.token.punctuation.brace-level-11 {
-      color: hsl(207, 82%, 66%);
-    }
-
-    .rainbow-braces .token.token.punctuation.brace-level-4,
-    .rainbow-braces .token.token.punctuation.brace-level-8,
-    .rainbow-braces .token.token.punctuation.brace-level-12 {
-      color: hsl(286, 60%, 67%);
-    }
-
-    /* Diff Highlight plugin overrides */
-    /* Taken from https://github.com/atom/github/blob/master/styles/variables.less */
-    pre.diff-highlight > code .token.token.deleted:not(.prefix),
-    pre > code.diff-highlight .token.token.deleted:not(.prefix) {
-      background-color: hsla(353, 100%, 66%, 0.15);
-    }
-
-    pre.diff-highlight > code .token.token.deleted:not(.prefix)::-moz-selection,
-    pre.diff-highlight
-      > code
-      .token.token.deleted:not(.prefix)
-      *::-moz-selection,
-    pre > code.diff-highlight .token.token.deleted:not(.prefix)::-moz-selection,
-    pre
-      > code.diff-highlight
-      .token.token.deleted:not(.prefix)
-      *::-moz-selection {
-      background-color: hsla(353, 95%, 66%, 0.25);
-    }
-
-    pre.diff-highlight > code .token.token.deleted:not(.prefix)::selection,
-    pre.diff-highlight > code .token.token.deleted:not(.prefix) *::selection,
-    pre > code.diff-highlight .token.token.deleted:not(.prefix)::selection,
-    pre > code.diff-highlight .token.token.deleted:not(.prefix) *::selection {
-      background-color: hsla(353, 95%, 66%, 0.25);
-    }
-
-    pre.diff-highlight > code .token.token.inserted:not(.prefix),
-    pre > code.diff-highlight .token.token.inserted:not(.prefix) {
-      background-color: hsla(137, 100%, 55%, 0.15);
-    }
-
-    pre.diff-highlight
-      > code
-      .token.token.inserted:not(.prefix)::-moz-selection,
-    pre.diff-highlight
-      > code
-      .token.token.inserted:not(.prefix)
-      *::-moz-selection,
-    pre
-      > code.diff-highlight
-      .token.token.inserted:not(.prefix)::-moz-selection,
-    pre
-      > code.diff-highlight
-      .token.token.inserted:not(.prefix)
-      *::-moz-selection {
-      background-color: hsla(135, 73%, 55%, 0.25);
-    }
-
-    pre.diff-highlight > code .token.token.inserted:not(.prefix)::selection,
-    pre.diff-highlight > code .token.token.inserted:not(.prefix) *::selection,
-    pre > code.diff-highlight .token.token.inserted:not(.prefix)::selection,
-    pre > code.diff-highlight .token.token.inserted:not(.prefix) *::selection {
-      background-color: hsla(135, 73%, 55%, 0.25);
-    }
-
-    /* Previewers plugin overrides */
-    /* Based on https://github.com/atom-community/atom-ide-datatip/blob/master/styles/atom-ide-datatips.less and https://github.com/atom/atom/blob/master/packages/one-dark-ui */
-    /* Border around popup */
-    .prism-previewer.prism-previewer:before,
-    .prism-previewer-gradient.prism-previewer-gradient div {
-      border-color: hsl(224, 13%, 17%);
-    }
-
-    /* Angle and time should remain as circles and are hence not included */
-    .prism-previewer-color.prism-previewer-color:before,
-    .prism-previewer-gradient.prism-previewer-gradient div,
-    .prism-previewer-easing.prism-previewer-easing:before {
-      border-radius: 0.3em;
-    }
-
-    /* Triangles pointing to the code */
-    .prism-previewer.prism-previewer:after {
-      border-top-color: hsl(224, 13%, 17%);
-    }
-
-    .prism-previewer-flipped.prism-previewer-flipped.after {
-      border-bottom-color: hsl(224, 13%, 17%);
-    }
-
-    /* Background colour within the popup */
-    .prism-previewer-angle.prism-previewer-angle:before,
-    .prism-previewer-time.prism-previewer-time:before,
-    .prism-previewer-easing.prism-previewer-easing {
-      background: hsl(219, 13%, 22%);
-    }
-
-    /* For angle, this is the positive area (eg. 90deg will display one quadrant in this colour) */
-    /* For time, this is the alternate colour */
-    .prism-previewer-angle.prism-previewer-angle circle,
-    .prism-previewer-time.prism-previewer-time circle {
-      stroke: hsl(220, 14%, 71%);
-      stroke-opacity: 1;
-    }
-
-    /* Stroke colours of the handle, direction point, and vector itself */
-    .prism-previewer-easing.prism-previewer-easing circle,
-    .prism-previewer-easing.prism-previewer-easing path,
-    .prism-previewer-easing.prism-previewer-easing line {
-      stroke: hsl(220, 14%, 71%);
-    }
-
-    /* Fill colour of the handle */
-    .prism-previewer-easing.prism-previewer-easing circle {
-      fill: transparent;
-    }
-  }
-}
-
-.prose pre {
-  contain: layout style;
-}
-
-/* Or more aggressively */
-.prose pre code {
-  contain: layout style paint;
-}
-
 /* messaging-style typing indicator animation */
 @keyframes typing {
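All of the Prism token styling deleted above is superseded by the Shiki-based highlighter added in app/ui/app/src/lib/highlighter.ts below, which carries the same One Light / One Dark palettes as theme definitions rather than as CSS overrides.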
app/ui/app/src/lib/highlighter.ts (new file, 156 additions)
@@ -0,0 +1,156 @@
+import { createHighlighter } from "shiki";
+import type { ThemeRegistration } from "shiki";
+
+const oneLightTheme: ThemeRegistration = {
+  name: "one-light",
+  type: "light",
+  colors: {
+    "editor.background": "#fafafa",
+    "editor.foreground": "#383a42",
+  },
+  tokenColors: [
+    {
+      scope: ["comment", "punctuation.definition.comment"],
+      settings: { foreground: "#a0a1a7" },
+    },
+    {
+      scope: ["keyword", "storage.type", "storage.modifier"],
+      settings: { foreground: "#a626a4" },
+    },
+    { scope: ["string", "string.quoted"], settings: { foreground: "#50a14f" } },
+    {
+      scope: ["function", "entity.name.function", "support.function"],
+      settings: { foreground: "#4078f2" },
+    },
+    {
+      scope: [
+        "constant.numeric",
+        "constant.language",
+        "constant.character",
+        "number",
+      ],
+      settings: { foreground: "#c18401" },
+    },
+    {
+      scope: ["variable", "support.variable"],
+      settings: { foreground: "#e45649" },
+    },
+    {
+      scope: ["entity.name.tag", "entity.name.type", "entity.name.class"],
+      settings: { foreground: "#e45649" },
+    },
+    {
+      scope: ["entity.other.attribute-name"],
+      settings: { foreground: "#c18401" },
+    },
+    {
+      scope: ["keyword.operator", "operator"],
+      settings: { foreground: "#a626a4" },
+    },
+    { scope: ["punctuation"], settings: { foreground: "#383a42" } },
+    {
+      scope: ["markup.heading"],
+      settings: { foreground: "#e45649", fontStyle: "bold" },
+    },
+    {
+      scope: ["markup.bold"],
+      settings: { foreground: "#c18401", fontStyle: "bold" },
+    },
+    {
+      scope: ["markup.italic"],
+      settings: { foreground: "#a626a4", fontStyle: "italic" },
+    },
+  ],
+};
+
+const oneDarkTheme: ThemeRegistration = {
+  name: "one-dark",
+  type: "dark",
+  colors: {
+    "editor.background": "#282c34",
+    "editor.foreground": "#abb2bf",
+  },
+  tokenColors: [
+    {
+      scope: ["comment", "punctuation.definition.comment"],
+      settings: { foreground: "#5c6370" },
+    },
+    {
+      scope: ["keyword", "storage.type", "storage.modifier"],
+      settings: { foreground: "#c678dd" },
+    },
+    { scope: ["string", "string.quoted"], settings: { foreground: "#98c379" } },
+    {
+      scope: ["function", "entity.name.function", "support.function"],
+      settings: { foreground: "#61afef" },
+    },
+    {
+      scope: [
+        "constant.numeric",
+        "constant.language",
+        "constant.character",
+        "number",
+      ],
+      settings: { foreground: "#d19a66" },
+    },
+    {
+      scope: ["variable", "support.variable"],
+      settings: { foreground: "#e06c75" },
+    },
+    {
+      scope: ["entity.name.tag", "entity.name.type", "entity.name.class"],
+      settings: { foreground: "#e06c75" },
+    },
+    {
+      scope: ["entity.other.attribute-name"],
+      settings: { foreground: "#d19a66" },
+    },
+    {
+      scope: ["keyword.operator", "operator"],
+      settings: { foreground: "#c678dd" },
+    },
+    { scope: ["punctuation"], settings: { foreground: "#abb2bf" } },
+    {
+      scope: ["markup.heading"],
+      settings: { foreground: "#e06c75", fontStyle: "bold" },
+    },
+    {
+      scope: ["markup.bold"],
+      settings: { foreground: "#d19a66", fontStyle: "bold" },
+    },
+    {
+      scope: ["markup.italic"],
+      settings: { foreground: "#c678dd", fontStyle: "italic" },
+    },
+  ],
+};
+
+export let highlighter: Awaited<ReturnType<typeof createHighlighter>> | null =
+  null;
+
+export const highlighterPromise = createHighlighter({
+  themes: [oneLightTheme, oneDarkTheme],
+  langs: [
+    "javascript",
+    "typescript",
+    "python",
+    "bash",
+    "shell",
+    "json",
+    "html",
+    "css",
+    "tsx",
+    "jsx",
+    "go",
+    "rust",
+    "java",
+    "c",
+    "cpp",
+    "sql",
+    "yaml",
+    "markdown",
+  ],
+}).then((h) => {
+  highlighter = h;
+  return h;
+});
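A usage sketch for the new module (assumed, not part of the diff): once `highlighterPromise` resolves, a block can be rendered with both registered themes so light/dark switching does not require re-highlighting.

    import { highlighterPromise } from "./highlighter";

    const h = await highlighterPromise;
    // Shiki's dual-theme mode emits CSS variables for light and dark in one pass.
    const html = h.codeToHtml('console.log("hi");', {
      lang: "typescript",
      themes: { light: "one-light", dark: "one-dark" },
    });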
@@ -1,24 +0,0 @@
-import { remark } from "remark";
-import remarkStringify from "remark-stringify";
-import remarkStreamingMarkdown from "./remarkStreamingMarkdown";
-
-/**
- * Process markdown content for streaming display using the remark plugin.
- * This is primarily used for testing the remark plugin with string inputs/outputs.
- */
-export function processStreamingMarkdown(content: string): string {
-  if (!content) return content;
-
-  const result = remark()
-    .use(remarkStreamingMarkdown, { debug: false })
-    .use(remarkStringify)
-    .processSync(content);
-
-  // remove trailing newline to keep tests cleaner
-  let output = result.toString();
-  if (output.endsWith("\n")) {
-    output = output.slice(0, -1);
-  }
-
-  return output;
-}
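Illustrative round-trips the deleted helper performed, inferred from the plugin rules in the file below (not test cases taken from the diff):

    processStreamingMarkdown("some **bold tex"); // the unclosed bold is closed so it renders while streaming
    processStreamingMarkdown("* item\n* ");      // the trailing empty bullet is dropped until more tokens arrive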
@@ -1,447 +0,0 @@
-import { parents, type Proxy } from "unist-util-parents";
-import type { Plugin } from "unified";
-import type {
-  Emphasis,
-  Node,
-  Parent,
-  Root,
-  RootContent,
-  Text,
-  Strong,
-  PhrasingContent,
-  Paragraph,
-} from "mdast";
-import { u } from "unist-builder";
-
-declare module "unist" {
-  interface Node {
-    /** Added by `unist-util-parents` (or your own walk). */
-    parent?: Proxy & Parent;
-  }
-}
-
-// interface SimpleTextRule {
-//   pattern: RegExp;
-//   transform: (matches: RegExpExecArray[], lastNode: Proxy) => void;
-// }
-
-// const simpleTextRules: SimpleTextRule[] = [
-//   // TODO(drifkin): generalize this for `__`/`_`/`~~`/`~` etc.
-//   {
-//     pattern: /(\*\*)(?=\S|$)/g,
-//     transform: (matchesIterator, lastNode) => {
-//       const textNode = lastNode.node as Text;
-
-//       const matches = [...matchesIterator];
-//       const lastMatch = matches[matches.length - 1];
-//       const origValue = textNode.value;
-//       const start = lastMatch.index;
-//       const sep = lastMatch[1];
-
-//       const before = origValue.slice(0, start);
-//       const after = origValue.slice(start + sep.length);
-
-//       if (lastNode.parent) {
-//         const index = (lastNode.parent.node as Parent).children.indexOf(
-//           lastNode.node as RootContent,
-//         );
-//         const shouldRemove = before.length === 0;
-//         if (!shouldRemove) {
-//           textNode.value = before;
-//         }
-
-//         const newNode = u("strong", {
-//           children: [u("text", { value: after })],
-//         });
-//         (lastNode.parent.node as Parent).children.splice(
-//           index + (shouldRemove ? 0 : 1),
-//           shouldRemove ? 1 : 0,
-//           newNode,
-//         );
-//       }
-//     },
-//   },
-// ];
-
-interface Options {
-  debug?: boolean;
-  onLastNode?: (info: LastNodeInfo) => void;
-}
-
-export interface LastNodeInfo {
-  path: string[];
-  type: string;
-  value?: string;
-  lastChars?: string;
-  fullNode: Node;
-}
-
-/**
- * Removes `child` from `parent` in-place.
- * @returns `true` if the child was found and removed; `false` otherwise.
- */
-export function removeChildFromParent(
-  child: RootContent,
-  parent: Node,
-): boolean {
-  if (!isParent(parent)) return false; // parent isn't a Parent → nothing to do
-
-  const idx = parent.children.indexOf(child);
-  if (idx < 0) return false; // not a child → nothing to remove
-
-  parent.children.splice(idx, 1);
-  return true; // removal successful
-}
-
-/** Narrow a generic `Node` to a `Parent` (i.e. one that really has children). */
-function isParent(node: Node): node is Parent {
-  // A `Parent` always has a `children` array; make sure it's an array first.
-  return Array.isArray((node as Partial<Parent>).children);
-}
-
-/**
- * Follow "last-child" pointers until you reach a leaf.
- * Returns the right-most, deepest node in source order.
- */
-export function findRightmostDeepestNode(root: Node): Node {
-  let current: Node = root;
-
-  // While the current node *is* a Parent and has at least one child…
-  while (isParent(current) && current.children.length > 0) {
-    const lastIndex = current.children.length - 1;
-    current = current.children[lastIndex];
-  }
-
-  return current; // Leaf: no further children
-}
-
-const remarkStreamingMarkdown: Plugin<[Options?], Root> = () => {
-  return (tree) => {
-    const treeWithParents = parents(tree);
-    const lastNode = findRightmostDeepestNode(treeWithParents) as Proxy;
-
-    const parentNode = lastNode.parent;
-    const grandparentNode = parentNode?.parent;
-
-    let ruleMatched = false;
-
-    // handling `* *` -> ``
-    //
-    // if the last node is part of a <list item (otherwise empty)> ->
-    // <list (otherwise empty)> -> <list item (last node, empty)>, then we need to
-    // remove everything up to and including the first list item. This happens
-    // when we have `* *`, which can become a bolded list item OR a horizontal
-    // line
-    if (
-      lastNode.type === "listItem" &&
-      parentNode &&
-      grandparentNode &&
-      parentNode.type === "list" &&
-      grandparentNode.type === "listItem" &&
-      parentNode.children.length === 1 &&
-      grandparentNode.children.length === 1
-    ) {
-      ruleMatched = true;
-      if (grandparentNode.parent) {
-        removeChildFromParent(
-          grandparentNode.node as RootContent,
-          grandparentNode.parent.node,
-        );
-      }
-      // Handle `*` -> ``:
-      //
-      // if the last node is just an empty list item, we need to remove it
-      // because it could become something else (e.g., a horizontal line)
-    } else if (
-      lastNode.type === "listItem" &&
-      parentNode &&
-      parentNode.type === "list"
-    ) {
-      ruleMatched = true;
-      removeChildFromParent(lastNode.node as RootContent, parentNode.node);
-    } else if (lastNode.type === "thematicBreak") {
-      ruleMatched = true;
-      const parent = lastNode.parent;
-      if (parent) {
-        removeChildFromParent(lastNode.node as RootContent, parent.node);
-      }
-    } else if (lastNode.type === "text") {
-      const textNode = lastNode.node as Text;
-      if (textNode.value.endsWith("**")) {
-        ruleMatched = true;
-        textNode.value = textNode.value.slice(0, -2);
-        // if there's a newline then a number, this is very very likely a
-        // numbered list item. Let's just hide it until the period comes (or
-        // other text disambiguates it)
-      } else {
-        const match = textNode.value.match(/^([0-9]+)$/m);
-        if (match) {
-          const number = match[1];
-          textNode.value = textNode.value.slice(0, -number.length - 1);
-          ruleMatched = true;
-          // if the text node is now empty, then we might want to remove other
-          // elements, like a now-empty containing paragraph, or a break that
-          // might disappear once more tokens come in
-          if (textNode.value.length === 0) {
-            if (
-              lastNode.parent?.type === "paragraph" &&
-              lastNode.parent.children.length === 1
-            ) {
-              // remove the whole paragraph if it's now empty (otherwise it'll
-              // cause an extra newline that might not last)
-              removeChildFromParent(
-                lastNode.parent.node as Paragraph,
-                lastNode.parent.parent?.node as Node,
-              );
-            } else {
-              const prev = prevSibling(lastNode);
-              if (prev?.type === "break") {
-                removeChildFromParent(
-                  prev.node as RootContent,
-                  lastNode.parent?.node as Node,
-                );
-                removeChildFromParent(
-                  lastNode.node as RootContent,
-                  lastNode.parent?.node as Node,
-                );
-              }
-            }
-          }
-        }
-      }
-    }
-
-    if (ruleMatched) {
-      return tree;
-    }
-
-    // we need to
-    // a case like
-    // - *def `abc` [abc **def**](abc)*
-    // is pretty tricky, because if we land just after def, then we actually
-    // have two separate tags to process at two different parents. Maybe we
-    // need to keep iterating up until we find a paragraph, but process each
-    // parent on the way up. Hmm, well actually after `def` we won't even be a proper link yet
-    // TODO(drifkin): it's really if the last node's parent is a paragraph, for which the following is a sub-case where the lastNode is a text node.
-    // And instead of just processing simple text rules, they need to operate on the whole paragraph
-    // like `**[abc](def)` needs to become `**[abc](def)**`
-
-    // if we're just text at the end, then we should remove some ambiguous characters
-
-    if (lastNode.parent) {
-      const didChange = processParent(lastNode.parent as Parent & Proxy);
-      if (didChange) {
-        // TODO(drifkin): need to fix up the tree, but not sure lastNode will still exist? Check all the transforms to see if it's safe to find the last node again
-        //
-        // need to regen the tree w/ parents since reparenting could've happened
-        // treeWithParents = parents(tree);
-      }
-    }
-
-    const grandparent = lastNode.parent?.parent;
-    // TODO(drifkin): let's go arbitrarily high up the tree, but limiting it
-    // to 2 levels for now until I think more about the stop condition
-    if (grandparent) {
-      processParent(grandparent as Parent & Proxy);
-    }
-
-    // console.log("ruleMatched", ruleMatched);
-
-    // } else if (lastNode.parent?.type === "paragraph") {
-    //   console.log("!!! paragraph");
-    //   console.log("lastNode.parent", lastNode.parent);
-
-    //   // Handle `**abc*` -> `**abc**`:
-    //   // We detect this when the last child is an emphasis node, and it's preceded by a text node that ends with `*`
-    //   const paragraph = lastNode.parent as Proxy & Paragraph;
-    //   if (paragraph.children.length >= 2) {
-    //     const lastChild = paragraph.children[paragraph.children.length - 1];
-    //     if (lastChild.type === "emphasis") {
-    //       const sibling = paragraph.children[paragraph.children.length - 2];
-    //       if (sibling.type === "text") {
-    //         const siblingText = sibling as Text & Proxy;
-    //         if (siblingText.value.endsWith("*")) {
-    //           ruleMatched = true;
-    //           const textNode = (lastNode as Proxy).node as Text;
-    //           textNode.value = textNode.value.slice(0, -1);
-    //           paragraph.node.type = "strong";
-    //         }
-    //       }
-    //     }
-    //   }
-    // } else if (lastNode.type === "text") {
-    //   // Handle `**abc*` -> `**abc**`:
-    //   //
-    //   // this gets parsed as a text node ending in `*` followed by an emphasis
-    //   // node. So if we're in text, we need to check if our parent is emphasis,
-    //   // and then get our parent's sibling before it and check if it ends with
-    //   // `*`
-    //   const parent = lastNode.parent;
-    //   if (parent && parent.type === "emphasis") {
-    //     const grandparent = parent.parent;
-    //     if (grandparent) {
-    //       const index = (grandparent.node as Parent).children.indexOf(
-    //         parent.node as RootContent,
-    //       );
-    //       if (index > 0) {
-    //         const prevNode = grandparent.children[index - 1];
-    //         if (
-    //           prevNode.type === "text" &&
-    //           (prevNode as Text).value.endsWith("*")
-    //         ) {
-    //           ruleMatched = true;
-    //           const textNode = (prevNode as Proxy).node as Text;
-    //           textNode.value = textNode.value.slice(0, -1);
-    //           parent.node.type = "strong";
-    //         }
-    //       }
-    //     }
-    //   }
-
-    //   if (!ruleMatched) {
-    //     // if the last node is just text, then we process it in order to fix up certain unclosed items
-    //     // e.g., `**abc` -> `**abc**`
-    //     const textNode = lastNode.node as Text;
-    //     for (const rule of simpleTextRules) {
-    //       const matchesIterator = textNode.value.matchAll(rule.pattern);
-    //       const matches = [...matchesIterator];
-    //       if (matches.length > 0) {
-    //         rule.transform(matches, lastNode);
-    //         ruleMatched = true;
-    //         break;
-    //       }
-    //     }
-    //   }
-    // } else if (!ruleMatched) {
-    //   // console.log("no rule matched", lastNode);
-    // }
-
-    return tree;
-  };
-};
-
-function processParent(parent: Parent & Proxy): boolean {
-  if (parent.type === "emphasis") {
-    // Handle `**abc*` -> `**abc**`:
-    // We detect this when we end with an emphasis node, and it's preceded by
-    // a text node that ends with `*`
-    // TODO(drifkin): the last node can be more deeply nested (e.g., a code
-    // literal in a link), so we probably need to walk up the tree until we
-    // find an emphasis node or a block? For now we'll just go up one layer to
-    // catch the most common cases
-    const emphasisNode = parent as Emphasis & Proxy;
-    const grandparent = emphasisNode.parent;
-    if (grandparent) {
-      const indexOfEmphasisNode = (grandparent.node as Parent).children.indexOf(
-        emphasisNode.node as RootContent,
-      );
-      if (indexOfEmphasisNode >= 0) {
-        const nodeBefore = grandparent.children[indexOfEmphasisNode - 1] as
-          | (Node & Proxy)
-          | undefined;
-        if (nodeBefore?.type === "text") {
-          const textNode = nodeBefore.node as Text;
-          if (textNode.value.endsWith("*")) {
-            const strBefore = textNode.value.slice(0, -1);
-            textNode.value = strBefore;
-            const strongNode = u("strong", {
-              children: emphasisNode.children,
-            });
-            (grandparent.node as Parent).children.splice(
-              indexOfEmphasisNode,
-              1,
-              strongNode,
-            );
-            return true;
-          }
-        }
-      }
-    }
-  }
-
-  // Let's check if we have any bold items to close
-  for (let i = parent.children.length - 1; i >= 0; i--) {
-    const child = parent.children[i];
-    if (child.type === "text") {
-      const textNode = child as Text & Proxy;
-      const sep = "**";
-      const index = textNode.value.lastIndexOf(sep);
-      if (index >= 0) {
-        let isValidOpening = false;
-        if (index + sep.length < textNode.value.length) {
-          const charAfter = textNode.value[index + sep.length];
-          if (!isWhitespace(charAfter)) {
-            isValidOpening = true;
-          }
-        } else {
-          if (i < parent.children.length - 1) {
-            // TODO(drifkin): I'm not sure that this check is strict enough.
-            // We're trying to detect cases like `**[abc]()` where the char
-            // after the opening ** is indeed a non-whitespace character. We're
-            // using the heuristic that there's another item after the current
-            // one, but I'm not sure if that is good enough. In a well
-            // constructed tree, there aren't two text nodes in a row, so this
-            // _seems_ good, but I should think through it more
-            isValidOpening = true;
-          }
-        }
-
-        if (isValidOpening) {
-          // TODO(drifkin): close the bold
-          const strBefore = textNode.value.slice(0, index);
-          const strAfter = textNode.value.slice(index + sep.length);
-          (textNode.node as Text).value = strBefore;
-          // TODO(drifkin): the node above could be empty in which case we probably want to delete it
-          const children: PhrasingContent[] = [
-            ...(strAfter.length > 0 ? [u("text", { value: strAfter })] : []),
-          ];
-          const strongNode: Strong = u("strong", {
-            children,
-          });
-          const nodesAfter = (parent.node as Parent).children.splice(
-            i + 1,
-            parent.children.length - i - 1,
-            strongNode,
-          );
-          // TODO(drifkin): this cast seems iffy, should see if we can cast the
-          // parent instead, which would also help us check some of our
-          // assumptions
-          strongNode.children.push(...(nodesAfter as PhrasingContent[]));
-          return true;
-        }
-      }
-    }
-  }
-
-  return false;
-}
-
-function prevSibling(node: Node & Proxy): (Node & Proxy) | null {
-  const parent = node.parent;
-  if (parent) {
-    const index = parent.children.indexOf(node);
-    return parent.children[index - 1] as Node & Proxy;
-  }
-  return null;
-}
-
-function isWhitespace(str: string) {
-  return str.trim() === "";
-}
-
-// function debugPrintTreeNoPos(tree: Node) {
-//   console.log(
-//     JSON.stringify(
-//       tree,
-//       (key, value) => {
-//         if (key === "position") {
-//           return undefined;
-//         }
-//         return value;
-//       },
-//       2,
-//     ),
-//   );
-// }
-
-export default remarkStreamingMarkdown;
@@ -1794,13 +1794,14 @@ func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, ava
 
 	var thinkValue *api.ThinkValue
 	if think != nil {
+		// Only set Think if it's actually requesting thinking
 		if boolValue, ok := think.(bool); ok {
-			thinkValue = &api.ThinkValue{
-				Value: boolValue,
+			if boolValue {
+				thinkValue = &api.ThinkValue{Value: boolValue}
 			}
 		} else if stringValue, ok := think.(string); ok {
-			thinkValue = &api.ThinkValue{
-				Value: stringValue,
+			if stringValue != "" && stringValue != "none" {
+				thinkValue = &api.ThinkValue{Value: stringValue}
 			}
 		}
 	}
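The hunk above tightens when a `ThinkValue` is attached: a `false` boolean, an empty string, or the string `"none"` now leaves `thinkValue` nil. A standalone Go sketch of the resulting behavior (the `thinkValueFor` helper and the local `ThinkValue` type are illustrative stand-ins, not part of the change):

```go
package main

import "fmt"

// ThinkValue stands in for api.ThinkValue from the hunk above.
type ThinkValue struct{ Value any }

// thinkValueFor mirrors the new gating: only inputs that actually request
// thinking produce a non-nil result.
func thinkValueFor(think any) *ThinkValue {
	switch v := think.(type) {
	case bool:
		if v {
			return &ThinkValue{Value: v}
		}
	case string:
		if v != "" && v != "none" {
			return &ThinkValue{Value: v}
		}
	}
	return nil
}

func main() {
	for _, in := range []any{true, false, "high", "none", ""} {
		fmt.Printf("%#v -> %#v\n", in, thinkValueFor(in))
	}
}
```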
@@ -110,9 +110,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
 
 	for name, mxfp4 := range mxfp4s {
 		dims := mxfp4.blocks.Shape()
+		if !strings.HasSuffix(name, ".weight") {
+			name = name + ".weight"
+		}
 		if strings.Contains(name, "ffn_down_exps") {
 			out = append(out, &ggml.Tensor{
-				Name:     name + ".weight",
+				Name:     name,
 				Kind:     uint32(ggml.TensorTypeMXFP4),
 				Shape:    []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
 				WriterTo: mxfp4,
@@ -121,12 +124,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
 			// gate_up_exps is interleaved, need to split into gate_exps and up_exps
 			// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
 			out = append(out, &ggml.Tensor{
-				Name:     strings.Replace(name, "gate_up", "gate", 1) + ".weight",
+				Name:     strings.Replace(name, "gate_up", "gate", 1),
 				Kind:     uint32(ggml.TensorTypeMXFP4),
 				Shape:    []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
 				WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
 			}, &ggml.Tensor{
-				Name:     strings.Replace(name, "gate_up", "up", 1) + ".weight",
+				Name:     strings.Replace(name, "gate_up", "up", 1),
 				Kind:     uint32(ggml.TensorTypeMXFP4),
 				Shape:    []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
 				WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
@@ -2,10 +2,12 @@ package convert
 
 import (
 	"cmp"
+	"errors"
 	"io"
 	"iter"
 	"path"
 	"slices"
+	"strconv"
 	"strings"
 
 	"github.com/pdevine/tensor"
@@ -94,6 +96,26 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
 		return matched
 	})
 
+	slices.SortStableFunc(matched, func(a, b Tensor) int {
+		x := strings.Split(a.Name(), ".")
+		y := strings.Split(b.Name(), ".")
+		if len(x) != len(y) {
+			return cmp.Compare(len(x), len(y))
+		}
+
+		vals := make([]int, len(x))
+		for i := range x {
+			vals[i] = strings.Compare(x[i], y[i])
+			m, err := strconv.ParseInt(x[i], 0, 0)
+			n, err2 := strconv.ParseInt(y[i], 0, 0)
+			if errors.Join(err, err2) == nil {
+				vals[i] = cmp.Compare(m, n)
+			}
+		}
+
+		return cmp.Or(vals...)
+	})
+
 	if len(matched) > 0 {
 		out = append(out, &ggml.Tensor{
 			Name: merges[i].name,
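The comparator added above sorts tensor names segment by segment, falling back to numeric comparison whenever both segments parse as integers, so `layer.2` orders before `layer.10`. A self-contained sketch of the same idea on plain strings (the names here are made up for illustration):

```go
package main

import (
	"cmp"
	"errors"
	"fmt"
	"slices"
	"strconv"
	"strings"
)

func main() {
	names := []string{"layer.10.weight", "layer.2.weight", "layer.1.weight"}
	slices.SortStableFunc(names, func(a, b string) int {
		x, y := strings.Split(a, "."), strings.Split(b, ".")
		if len(x) != len(y) {
			return cmp.Compare(len(x), len(y))
		}
		vals := make([]int, len(x))
		for i := range x {
			// Default to lexicographic order, then prefer numeric order
			// when both segments are integers.
			vals[i] = strings.Compare(x[i], y[i])
			m, err := strconv.ParseInt(x[i], 0, 0)
			n, err2 := strconv.ParseInt(y[i], 0, 0)
			if errors.Join(err, err2) == nil {
				vals[i] = cmp.Compare(m, n)
			}
		}
		// cmp.Or returns the first non-zero comparison result.
		return cmp.Or(vals...)
	})
	fmt.Println(names) // [layer.1.weight layer.2.weight layer.10.weight]
}
```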
@@ -3,8 +3,10 @@ package convert
 import (
 	"bytes"
 	"encoding/binary"
+	"fmt"
 	"io"
 	"iter"
+	"math/rand/v2"
 	"slices"
 	"strings"
 	"testing"
@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) {
 		}
 	})
 }
+
+func TestMergeOrder(t *testing.T) {
+	for range 8 {
+		t.Run("", func(t *testing.T) {
+			tensors := make([]Tensor, 16)
+			for i := range tensors {
+				tensors[i] = &fakeTensor{
+					name:  fmt.Sprintf("layer.%d.weight", i),
+					shape: []uint64{1},
+					data:  []float32{float32(i)},
+				}
+			}
+
+			rand.Shuffle(len(tensors), func(i, j int) {
+				tensors[i], tensors[j] = tensors[j], tensors[i]
+			})
+
+			matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
+			if len(unmatched) != 0 {
+				t.Error("expected no remaining tensors, got", len(unmatched))
+			}
+
+			if len(matched) != 1 {
+				t.Error("expected 1 merged tensor, got", len(matched))
+			}
+
+			var b bytes.Buffer
+			if _, err := matched[0].WriteTo(&b); err != nil {
+				t.Fatal(err)
+			}
+
+			var f32s [16]float32
+			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+				t.Fatal(err)
+			}
+
+			if !slices.IsSorted(f32s[:]) {
+				t.Errorf("merged tensor data is not in order: %+v", f32s)
+			}
+		})
+	}
+}
@@ -94,6 +94,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 				continue
 			} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
 				continue
+			} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
+				slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
+				continue
 			}
 			dirs = []string{ml.LibOllamaPath, dir}
 		} else {
@@ -13,9 +13,23 @@ Embeddings turn text into numeric vectors you can store in a vector database, se
 
 ## Generate embeddings
 
-Use `/api/embed` with a single string.
 
 <Tabs>
+  <Tab title="CLI">
+    Generate embeddings directly from the command line:
+
+    ```shell
+    ollama run embeddinggemma "Hello world"
+    ```
+
+    You can also pipe text to generate embeddings:
+
+    ```shell
+    echo "Hello world" | ollama run embeddinggemma
+    ```
+
+    Output is a JSON array.
+
+  </Tab>
   <Tab title="cURL">
     ```shell
    curl -X POST http://localhost:11434/api/embed \
@@ -68,6 +68,15 @@ To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following c
 docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
 ```
 
+## Vulkan Support
+
+Vulkan is bundled into the `ollama/ollama` image.
+
+```shell
+docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 -e OLLAMA_VULKAN=1 --name ollama ollama/ollama
+```
+
+
 ## Run model locally
 
 Now you can run a model:
@@ -79,3 +88,4 @@ docker exec -it ollama ollama run llama3.2
 ## Try different models
 
 More models can be found on the [Ollama library](https://ollama.com/library).
+
@@ -63,6 +63,10 @@
     {
       "source": "/api/openai",
       "destination": "/api/openai-compatibility"
+    },
+    {
+      "source": "/api",
+      "destination": "/api/introduction"
     }
   ],
   "navigation": {
@@ -130,7 +134,7 @@
       {
         "group": "API Reference",
         "pages": [
-          "/api/index",
+          "/api/introduction",
           "/api/authentication",
           "/api/streaming",
           "/api/usage",
@@ -223,7 +223,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
 
 ## How can I use Ollama in Visual Studio Code?
 
-There is already a large collection of plugins available for VSCode as well as other editors that leverage Ollama. See the list of [extensions & plugins](https://github.com/ollama/ollama#extensions--plugins) at the bottom of the main repository readme.
+There is already a large collection of plugins available for VS Code as well as other editors that leverage Ollama. See the list of [extensions & plugins](https://github.com/ollama/ollama#extensions--plugins) at the bottom of the main repository readme.
 
 ## How do I use Ollama with GPU acceleration in Docker?
 
44
docs/gpu.mdx
@@ -52,7 +52,11 @@ sudo modprobe nvidia_uvm`
 
 ## AMD Radeon
 
-Ollama supports the following AMD GPUs:
+Ollama supports the following AMD GPUs via the ROCm library:
 
+> [!NOTE]
+> Additional AMD GPU support is provided by the Vulkan library - see below.
+
+
 ### Linux Support
 
@@ -121,6 +125,42 @@ In some Linux distributions, SELinux can prevent containers from
 accessing the AMD GPU devices. On the host system you can run
 `sudo setsebool container_use_devices=1` to allow containers to use devices.
 
-### Metal (Apple GPUs)
+## Metal (Apple GPUs)
 
 Ollama supports GPU acceleration on Apple devices via the Metal API.
 
+
+## Vulkan GPU Support
+
+> [!NOTE]
+> Vulkan is currently an experimental feature. To enable it, you must set OLLAMA_VULKAN=1 for the Ollama server as
+> described in the [FAQ](faq.md#how-do-i-configure-ollama-server).
+
+Additional GPU support on Windows and Linux is provided via
+[Vulkan](https://www.vulkan.org/). On Windows, most GPU vendors' drivers come
+bundled with Vulkan support and require no additional setup steps. Most Linux
+distributions require installing additional components, and you may have
+multiple options for Vulkan drivers between Mesa and GPU vendor specific packages:
+
+- Linux Intel GPU Instructions - https://dgpu-docs.intel.com/driver/client/overview.html
+- Linux AMD GPU Instructions - https://amdgpu-install.readthedocs.io/en/latest/install-script.html#specifying-a-vulkan-implementation
+
+For AMD GPUs on some Linux distributions, you may need to add the `ollama` user to the `render` group.
+
+The Ollama scheduler leverages available VRAM data reported by the GPU libraries to
+make optimal scheduling decisions. Vulkan requires additional capabilities or
+running as root to expose this available VRAM data. If neither root access nor this
+capability is granted, Ollama will use approximate model sizes
+to make best-effort scheduling decisions.
+
+```bash
+sudo setcap cap_perfmon+ep /usr/local/bin/ollama
+```
+
+### GPU Selection
+
+To select specific Vulkan GPU(s), you can set the environment variable
+`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
+described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
+encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
+by setting `GGML_VK_VISIBLE_DEVICES=-1`.
@@ -4,7 +4,7 @@ title: VS Code
 
 ## Install
 
-Install [VSCode](https://code.visualstudio.com/download).
+Install [VS Code](https://code.visualstudio.com/download).
 
 ## Usage with Ollama
 
@@ -12,7 +12,7 @@ Install [VSCode](https://code.visualstudio.com/download).
 <div style={{ display: 'flex', justifyContent: 'center' }}>
   <img
     src="/images/vscode-sidebar.png"
-    alt="VSCode chat Sidebar"
+    alt="VS Code chat Sidebar"
     width="75%"
   />
 </div>
@@ -20,7 +20,7 @@ Install [VSCode](https://code.visualstudio.com/download).
 <div style={{ display: 'flex', justifyContent: 'center' }}>
   <img
     src="/images/vscode-models.png"
-    alt="VSCode model picker"
+    alt="VS Code model picker"
     width="75%"
   />
 </div>
@@ -28,7 +28,7 @@ Install [VSCode](https://code.visualstudio.com/download).
 <div style={{ display: 'flex', justifyContent: 'center' }}>
   <img
     src="/images/vscode-model-options.png"
-    alt="VSCode model options dropdown"
+    alt="VS Code model options dropdown"
     width="75%"
   />
 </div>
@@ -2,12 +2,15 @@ openapi: 3.1.0
 info:
   title: Ollama API
   version: 0.1.0
+  license:
+    name: MIT
+    url: https://opensource.org/licenses/MIT
   description: |
     OpenAPI specification for the Ollama HTTP API
 
 servers:
   - url: http://localhost:11434
-    description: Local Ollama instance
+    description: Ollama
+security: []
 components:
   securitySchemes:
     bearerAuth:
@@ -93,8 +96,11 @@ components:
           type: boolean
           default: true
         think:
-          type: boolean
-          description: When true, returns separate thinking output in addition to content
+          oneOf:
+            - type: boolean
+            - type: string
+              enum: [high, medium, low]
+          description: When true, returns separate thinking output in addition to content. Can be a boolean (true/false) or a string ("high", "medium", "low") for supported models.
         raw:
           type: boolean
           description: When true, returns the raw response from the model without any prompt templating
@@ -271,8 +277,11 @@ components:
           type: boolean
           default: true
         think:
-          type: boolean
-          description: When true, returns separate thinking output in addition to content
+          oneOf:
+            - type: boolean
+            - type: string
+              enum: [high, medium, low]
+          description: When true, returns separate thinking output in addition to content. Can be a boolean (true/false) or a string ("high", "medium", "low") for supported models.
         keep_alive:
           oneOf:
             - type: string
@@ -310,7 +319,6 @@ components:
           type: array
           items:
             type: string
-          nullable: true
           description: Optional base64-encoded images in the response
         done:
           type: boolean
@@ -367,7 +375,6 @@ components:
           type: array
           items:
             type: string
-          nullable: true
           description: Partial base64-encoded images, when present
         done:
           type: boolean
@@ -543,6 +550,9 @@ components:
         license:
           type: string
           description: The license of the model
+        modified_at:
+          type: string
+          description: Last modified timestamp in ISO 8601 format
         details:
           type: object
           description: High-level model details
@@ -622,6 +632,9 @@ components:
         size_vram:
           type: integer
           description: VRAM usage in bytes
+        context_length:
+          type: integer
+          description: Context length for the running model
     PsResponse:
       type: object
       properties:
@@ -1275,6 +1288,9 @@ paths:
             example:
               source: gemma3
              destination: gemma3-backup
+      responses:
+        "200":
+          description: Model successfully copied
  /api/pull:
    post:
      summary: Pull a model
@@ -1382,16 +1398,7 @@ paths:
              model: gemma3
      responses:
        "200":
-          description: Deletion status updates.
-          content:
-            application/json:
-              schema:
-                $ref: "#/components/schemas/StatusResponse"
-              example:
-                status: "success"
-            application/x-ndjson:
-              schema:
-                $ref: "#/components/schemas/StatusEvent"
+          description: Model successfully deleted
  /api/version:
    get:
      summary: Get version
@@ -196,8 +196,6 @@ var (
	NoPrune = Bool("OLLAMA_NOPRUNE")
	// SchedSpread allows scheduling models across all GPUs.
	SchedSpread = Bool("OLLAMA_SCHED_SPREAD")
-	// IntelGPU enables experimental Intel GPU detection.
-	IntelGPU = Bool("OLLAMA_INTEL_GPU")
	// MultiUserCache optimizes prompt caching for multi-user scenarios
	MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
	// Enable the new Ollama engine
@@ -206,6 +204,8 @@ var (
	ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
	// Auth enables authentication between the Ollama client and server
	UseAuth = Bool("OLLAMA_AUTH")
+	// Enable Vulkan backend
+	EnableVulkan = Bool("OLLAMA_VULKAN")
 )
 
 func String(s string) func() string {
@@ -314,7 +314,7 @@ func AsMap() map[string]EnvVar {
		ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"}
		ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"}
		ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
-		ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}
+		ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
	}
 
	return ret
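A minimal sketch of the `Bool` helper pattern these variables use (assumed semantics modeled on the surrounding code; the real helper lives in the envconfig package and its parsing rules may differ):

```go
package main

import (
	"fmt"
	"os"
)

// Bool returns a func that re-reads the named environment variable and
// reports whether it is set to a truthy value. (Assumption: this mirrors
// envconfig.Bool.)
func Bool(name string) func() bool {
	return func() bool {
		v := os.Getenv(name)
		return v != "" && v != "0" && v != "false"
	}
}

var EnableVulkan = Bool("OLLAMA_VULKAN")

func main() {
	os.Setenv("OLLAMA_VULKAN", "1")
	if EnableVulkan() {
		fmt.Println("experimental Vulkan backend enabled")
	}
}
```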
@@ -797,73 +797,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
	return
 }
 
-func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
-	if llm.KV().Uint("vision.block_count") == 0 {
-		return
-	}
-
-	for name, layer := range llm.Tensors().GroupLayers() {
-		if name == "v" || strings.HasPrefix(name, "v.") {
-			for _, tensor := range layer {
-				weights += tensor.Size()
-			}
-		}
-	}
-
-	imageSize := uint64(llm.KV().Uint("vision.image_size"))
-	patchSize := uint64(llm.KV().Uint("vision.patch_size"))
-	if patchSize == 0 {
-		slog.Warn("unknown patch size for vision model")
-		return
-	}
-
-	numChannels := uint64(llm.KV().Uint("vision.num_channels"))
-
-	numPatches := (imageSize / patchSize) * (imageSize / patchSize)
-	if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
-		numPatches++
-	}
-
-	headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
-	embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
-
-	switch llm.KV().Architecture() {
-	case "mllama":
-		numPaddedPatches := numPatches + 8 - (numPatches%8)%8
-
-		maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
-
-		graphSize = 4 * (8 +
-			imageSize*imageSize*numChannels*maxNumTiles +
-			embeddingLength*numPatches*maxNumTiles +
-			9*embeddingLength*numPaddedPatches*maxNumTiles +
-			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
-	case "gemma3", "mistral3":
-		graphSize = 4 * (imageSize*imageSize*numChannels +
-			embeddingLength*patchSize +
-			numPatches*numPatches*headCount)
-	case "qwen25vl":
-		maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
-
-		numPatches := maxPixels / (patchSize * patchSize)
-
-		graphSize = 4 * (maxPixels*numChannels + // Original image storage
-			// Normalized pixels
-			maxPixels*numChannels +
-			// Patches storage (numPatches * channels * patchSize^2)
-			numPatches*numChannels*patchSize*patchSize +
-			// Self-attention calculations
-			numPatches*numPatches*headCount +
-			// Additional buffer for processing
-			embeddingLength*numPatches)
-	case "llama4":
-		// vision graph is computed independently in the same schedule
-		// and is negligible compared to the worst case text graph
-	}
-
-	return weights, graphSize
-}
-
 // SupportsKVCacheType checks if the requested cache type is supported
 func (f GGML) SupportsKVCacheType(cacheType string) bool {
	if cacheType == "" || cacheType == "f16" {
@@ -14,6 +14,23 @@ import (
	"github.com/ollama/ollama/api"
 )
 
+func assertBytesMatchToken(t *testing.T, label, token string, ints []int) {
+	t.Helper()
+
+	raw := []byte(token)
+	if len(ints) != len(raw) {
+		t.Errorf("%s expected %d bytes for token %q, got %d (%v)", label, len(raw), token, len(ints), ints)
+		return
+	}
+
+	for i, b := range raw {
+		if ints[i] != int(b) {
+			t.Errorf("%s byte[%d] mismatch for token %q: got %d want %d", label, i, token, ints[i], int(b))
+			return
+		}
+	}
+}
+
 func TestAPIGenerate(t *testing.T) {
	initialTimeout := 60 * time.Second
	streamTimeout := 30 * time.Second
@@ -381,3 +398,182 @@ func TestAPIShowModel(t *testing.T) {
		t.Errorf("%s missing modified_at: %#v", modelName, resp)
	}
 }
+
+func TestAPIGenerateLogprobs(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	if err := PullIfMissing(ctx, client, smol); err != nil {
+		t.Fatalf("pull failed %s", err)
+	}
+
+	enableLogprobs := true
+	noStream := false
+
+	tests := []struct {
+		name        string
+		logprobs    *bool
+		topLogprobs int
+		expectCount int
+	}{
+		{
+			name:        "no_logprobs",
+			logprobs:    nil,
+			topLogprobs: 0,
+			expectCount: 0,
+		},
+		{
+			name:        "logprobs_only",
+			logprobs:    &enableLogprobs,
+			topLogprobs: 0,
+			expectCount: 1,
+		},
+		{
+			name:        "logprobs_with_top_5",
+			logprobs:    &enableLogprobs,
+			topLogprobs: 5,
+			expectCount: 1,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			req := api.GenerateRequest{
+				Model:       smol,
+				Prompt:      "Why is the sky blue?",
+				Stream:      &noStream,
+				Logprobs:    test.logprobs != nil && *test.logprobs,
+				TopLogprobs: test.topLogprobs,
+				Options: map[string]interface{}{
+					"temperature": 0,
+					"seed":        123,
+					"num_predict": 10,
+				},
+			}
+
+			var response api.GenerateResponse
+			err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
+				if resp.Done {
+					response = resp
+				}
+				return nil
+			})
+			if err != nil {
+				t.Fatalf("generate failed: %s", err)
+			}
+
+			// Check logprobs based on expectation
+			if test.expectCount == 0 {
+				if len(response.Logprobs) > 0 {
+					t.Errorf("expected no logprobs but got %d", len(response.Logprobs))
+				}
+			} else {
+				if len(response.Logprobs) == 0 {
+					t.Errorf("expected logprobs but got none")
+				}
+
+				// Validate each logprob entry
+				for i, lp := range response.Logprobs {
+					if lp.Token == "" {
+						t.Errorf("logprob[%d] has empty token", i)
+					}
+					if lp.Logprob > 0 {
+						t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob)
+					}
+					assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d]", i), lp.Token, lp.Bytes)
+
+					// Check top_logprobs if requested
+					if test.topLogprobs > 0 {
+						if len(lp.TopLogprobs) == 0 {
+							t.Errorf("logprob[%d] expected top_logprobs but got none", i)
+						}
+						if len(lp.TopLogprobs) > test.topLogprobs {
+							t.Errorf("logprob[%d] has %d top_logprobs, expected max %d", i, len(lp.TopLogprobs), test.topLogprobs)
+						}
+
+						// Verify top_logprobs are sorted by probability (descending)
+						for j := 1; j < len(lp.TopLogprobs); j++ {
+							if lp.TopLogprobs[j-1].Logprob < lp.TopLogprobs[j].Logprob {
+								t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob)
+							}
+						}
+						for j, top := range lp.TopLogprobs {
+							assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
+						}
+					} else if len(lp.TopLogprobs) > 0 {
+						t.Errorf("logprob[%d] has top_logprobs but none were requested", i)
+					}
+				}
+			}
+		})
+	}
+}
+
+func TestAPIChatLogprobs(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	if err := PullIfMissing(ctx, client, smol); err != nil {
+		t.Fatalf("pull failed %s", err)
+	}
+
+	enableLogprobs := true
+	noStream := false
+
+	req := api.ChatRequest{
+		Model: smol,
+		Messages: []api.Message{
+			{Role: "user", Content: "Say hello in one word"},
+		},
+		Stream:      &noStream,
+		Logprobs:    enableLogprobs,
+		TopLogprobs: 3,
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+			"num_predict": 5,
+		},
+	}
+
+	var response api.ChatResponse
+	err := client.Chat(ctx, &req, func(resp api.ChatResponse) error {
+		if resp.Done {
+			response = resp
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatalf("chat failed: %s", err)
+	}
+
+	if len(response.Logprobs) == 0 {
+		t.Fatal("expected logprobs in response but got none")
+	}
+
+	t.Logf("received %d logprobs for chat response", len(response.Logprobs))
+
+	for i, lp := range response.Logprobs {
+		if lp.Token == "" {
+			t.Errorf("logprob[%d] has empty token", i)
+		}
+		if lp.Logprob > 0 {
+			t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob)
+		}
+		assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d]", i), lp.Token, lp.Bytes)
+		if len(lp.TopLogprobs) == 0 {
+			t.Errorf("logprob[%d] expected top_logprobs but got none", i)
+		}
+		if len(lp.TopLogprobs) > 3 {
+			t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs))
+		}
+		for j, top := range lp.TopLogprobs {
+			assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
+		}
+	}
+}
@@ -63,8 +63,13 @@ func BackendInit() {
	C.llama_backend_init()
 }
 
-func EnumerateGPUs() []ml.DeviceID {
-	var ids []ml.DeviceID
+type Devices struct {
+	ml.DeviceID
+	LlamaID uint64
+}
+
+func EnumerateGPUs() []Devices {
+	var ids []Devices
 
	for i := range C.ggml_backend_dev_count() {
		device := C.ggml_backend_dev_get(i)
@@ -74,9 +79,12 @@ func EnumerateGPUs() []ml.DeviceID {
			C.GGML_BACKEND_DEVICE_TYPE_IGPU:
			var props C.struct_ggml_backend_dev_props
			C.ggml_backend_dev_get_props(device, &props)
-			ids = append(ids, ml.DeviceID{
+			ids = append(ids, Devices{
+				DeviceID: ml.DeviceID{
					ID:      C.GoString(props.id),
					Library: C.GoString(props.library),
+				},
+				LlamaID: uint64(i),
			})
		}
	}
@@ -217,7 +225,21 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
	return embeddings
 }
 
+// GetLogitsIth gets the logits for the ith token
+func (c *Context) GetLogitsIth(i int) []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int32_t(i)))
+	if logits == nil {
+		return nil
+	}
+
+	vocabSize := c.Model().NumVocab()
+	result := make([]float32, vocabSize)
+	_ = copy(result, unsafe.Slice((*float32)(logits), vocabSize))
+	return result
+}
+
 type ModelParams struct {
+	Devices      []uint64
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
@@ -241,6 +263,21 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.vocab_only = C.bool(params.VocabOnly)
 
+	var devices []C.ggml_backend_dev_t
+	for _, llamaID := range params.Devices {
+		devices = append(devices, C.ggml_backend_dev_get(C.size_t(llamaID)))
+	}
+	if len(devices) > 0 {
+		devices = append(devices, C.ggml_backend_dev_t(C.NULL))
+		devicesData := &devices[0]
+
+		var devicesPin runtime.Pinner
+		devicesPin.Pin(devicesData)
+		defer devicesPin.Unpin()
+
+		cparams.devices = devicesData
+	}
+
	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]
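The new `GetLogitsIth` above copies the full vocabulary-sized logit vector out of llama.cpp's buffer. A hedged Go sketch of how a caller might consume it for greedy sampling (`lc` and the token index are assumptions, not names from the diff; error handling elided):

```go
// logits is nil if logits were not computed for this position in the batch.
logits := lc.GetLogitsIth(0) // assumed: lc is a *llama.Context after a successful decode
if logits != nil {
	best := 0
	for id, l := range logits {
		if l > logits[best] {
			best = id
		}
	}
	fmt.Printf("greedy next token: id=%d logit=%f\n", best, logits[best])
}
```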
@@ -0,0 +1,32 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Jeff Bolz <jbolz@nvidia.com>
+Date: Wed, 29 Oct 2025 03:53:04 -0500
+Subject: [PATCH] vulkan: Call ggml_vk_buffer_write_2d from ggml_vk_buffer_copy
+ (#16793)
+
+This lets the copy to the destination device use the host-visible
+vidmem optimization.
+---
+ ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +----
+ 1 file changed, 1 insertion(+), 4 deletions(-)
+
+diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+index 221e29509..18b7cbccf 100644
+--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+@@ -5654,14 +5654,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
+         VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
+         // Copy device to device
+         ggml_vk_ensure_sync_staging_buffer(src->device, size);
+-        ggml_vk_ensure_sync_staging_buffer(dst->device, size);
+
+         // Copy to src staging buffer
+         ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
+-        // memcpy to dst staging buffer
+-        memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
+         // Copy to dst buffer
+-        ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
++        ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
+     }
+ }
+
@@ -0,0 +1,657 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Jeff Bolz <jbolz@nvidia.com>
+Date: Wed, 29 Oct 2025 08:44:29 -0500
+Subject: [PATCH] vulkan: Update topk_moe fusion to handle gpt's late softmax
+ (#16656)
+
+* vulkan: Update topk_moe fusion to handle gpt's late softmax
+
+Based on #16649.
+
+* Add ggml_check_edges
+
+* Add sync logging to show fusion effects
+
+* handle clamp added in #16655
+
+* Update ggml/src/ggml-impl.h
+
+Co-authored-by: Diego Devesa <slarengh@gmail.com>
+---
+ ggml/src/ggml-impl.h                          |  16 +
+ ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 304 +++++++++++-------
+ .../ggml-vulkan/vulkan-shaders/topk_moe.comp  |  90 ++++--
+ 3 files changed, 272 insertions(+), 138 deletions(-)
+
+diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
+index 639d551a2..e5c446d1d 100644
+--- a/ggml/src/ggml-impl.h
++++ b/ggml/src/ggml-impl.h
+@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release();
+ #endif
+
+ #ifdef __cplusplus
++#include <array>
+ #include <initializer_list>
+ #include <vector>
+
+@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
+     return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
+ }
+
++// Return true if the edges in the graph match expectations.
++inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
++                             int start_idx,
++                             std::initializer_list<std::array<int, 3>> edges) {
++    for (const auto & edge : edges) {
++        int dst_node = edge[0];
++        int src_idx = edge[1];
++        int src_node = edge[2];
++        if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
++            return false;
++        }
++    }
++    return true;
++}
++
+ // expose GGUF internals for test code
+ GGML_API size_t gguf_type_size(enum gguf_type type);
+ GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
+diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+index 53b57c179..b2855b078 100644
+--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+@@ -387,12 +387,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
+ static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
+ static constexpr uint32_t num_topk_moe_pipelines = 10;
+
+-static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+-                                           GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+-                                           GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
+-static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+-                                       GGML_OP_VIEW, GGML_OP_GET_ROWS };
++static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
++                                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
++                                                                             GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
++                                                                             GGML_OP_RESHAPE };
++static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax     { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
++                                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS };
++static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax      { GGML_OP_ARGSORT, GGML_OP_VIEW,
++                                                                             GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
++                                                                             GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
++
++//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
++//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
++//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
++//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
++//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
++//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
++//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
++//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
++//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
++//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
++static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
++    { 1, 0, 0 }, // reshape->src[0] == softmax
++    { 2, 0, 0 }, // argsort->src[0] == softmax
++    { 3, 0, 2 }, // view->src[0] == argsort
++    { 4, 0, 1 }, // get_rows->src[0] == reshape
++    { 4, 1, 3 }, // get_rows->src[1] == view
++    { 5, 0, 4 }, // reshape->src[0] == get_rows
++    { 6, 0, 5 }, // sum_rows->src[0] == reshape
++    { 7, 0, 6 }, // clamp->src[0] == sum_rows
++    { 8, 0, 5 }, // div->src[0] == reshape
++    { 8, 1, 7 }, // div->src[1] == clamp
++    { 9, 0, 8 }, // reshape->src[0] == div
++};
++
++// same as early_softmax_norm but ending after the get_rows
++static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
++    { 1, 0, 0 }, // reshape->src[0] == softmax
++    { 2, 0, 0 }, // argsort->src[0] == softmax
++    { 3, 0, 2 }, // view->src[0] == argsort
++    { 4, 0, 1 }, // get_rows->src[0] == reshape
++    { 4, 1, 3 }, // get_rows->src[1] == view
++};
+
++//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
++//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
++//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
++//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
++//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
++//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
++static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
++    { 1, 0, 0 }, // view->src[0] == argsort
++    { 2, 1, 1 }, // get_rows->src[1] == view
++    { 3, 0, 2 }, // reshape->src[0] == get_rows
++    { 4, 0, 3 }, // soft_max->src[0] == reshape
++    { 5, 0, 4 }, // reshape->src[0] == soft_max
++};
++
++enum topk_moe_mode {
++    TOPK_MOE_EARLY_SOFTMAX,
++    TOPK_MOE_EARLY_SOFTMAX_NORM,
++    TOPK_MOE_LATE_SOFTMAX,
++    TOPK_MOE_COUNT,
++};
++
++static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
++    topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
++                         num == topk_moe_early_softmax.size() - 1      ? TOPK_MOE_EARLY_SOFTMAX :
++                                                                         TOPK_MOE_LATE_SOFTMAX;
++    return mode;
++}
+
+ struct vk_device_struct {
+     std::recursive_mutex mutex;
+@@ -607,8 +671,7 @@ struct vk_device_struct {
+
+     vk_pipeline pipeline_flash_attn_split_k_reduce;
+
+-    // [2] is {!norm, norm}
+-    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
++    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
+
+     std::vector<vk_pipeline_ref> all_pipelines;
+
+@@ -956,6 +1019,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
+ struct vk_op_topk_moe_push_constants {
+     uint32_t n_rows;
+     uint32_t n_expert_used;
++    float clamp_min;
++    float clamp_max;
+ };
+
+ struct vk_op_add_id_push_constants {
+@@ -3806,8 +3871,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
+     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+
+     for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
+-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
+-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
++        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
++        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
++        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
+     }
+
+     for (auto &c : compiles) {
+@@ -8085,8 +8151,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
+         if (ctx->num_additional_fused_ops) {
+             uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+             GGML_ASSERT(idx < num_topk_moe_pipelines);
+-            bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+-            return ctx->device->pipeline_topk_moe[idx][with_norm];
++            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
++            return ctx->device->pipeline_topk_moe[idx][mode];
+         }
+
+         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
+@@ -8141,6 +8207,13 @@
+             return nullptr;
+         }
+     case GGML_OP_ARGSORT:
++        if (ctx->num_additional_fused_ops) {
++            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
++            GGML_ASSERT(idx < num_topk_moe_pipelines);
++            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
++            return ctx->device->pipeline_topk_moe[idx][mode];
++        }
++
+         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
+             uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+             return ctx->device->pipeline_argsort_f32[idx];
+@@ -9676,10 +9749,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
+
+ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
+
+-    bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
++    topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+     ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
+-    ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+-    ggml_tensor * ids = cgraph->nodes[node_idx + 3];
++    ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
++                            (mode == TOPK_MOE_EARLY_SOFTMAX)      ? cgraph->nodes[node_idx + 4] :
++                                                                    cgraph->nodes[node_idx + 5];
++    ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
+
+     GGML_ASSERT(logits->type == GGML_TYPE_F32);
+     GGML_ASSERT(weights->type == GGML_TYPE_F32);
+@@ -9738,9 +9813,14 @@
+         GGML_ASSERT(d_ids != nullptr);
+     }
+
+-    vk_op_topk_moe_push_constants pc;
++    vk_op_topk_moe_push_constants pc {};
+     pc.n_rows = n_rows;
+     pc.n_expert_used = n_expert_used;
++    if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
++        ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
++        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
++        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
++    }
+
+     GGML_ASSERT(n_expert_used <= n_experts);
+
+@@ -11335,7 +11415,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
+             }
+         }
+     }
++
++#define ENABLE_SYNC_LOGGING 0
++
+     if (need_sync) {
++#if ENABLE_SYNC_LOGGING
++        std::cerr << "sync" << std::endl;
++#endif
+         ctx->unsynced_nodes_written.clear();
+         ctx->unsynced_nodes_read.clear();
+         ggml_vk_sync_buffers(ctx, compute_ctx);
+@@ -11353,6 +11439,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
+             }
+         }
+     }
++#if ENABLE_SYNC_LOGGING
++    if (!dryrun) {
++        for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
++            auto *n = cgraph->nodes[node_idx + i];
++            std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
++            if (n->op == GGML_OP_GLU) {
++                std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
++            }
++            std::cerr << std::endl;
++        }
++    }
++#endif
+
+     switch (node->op) {
+     case GGML_OP_REPEAT:
+@@ -11531,7 +11629,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
+
+         break;
+     case GGML_OP_ARGSORT:
+-        ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
++        if (ctx->num_additional_fused_ops) {
++            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
++        } else {
++            ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
++        }
+
+         break;
+     case GGML_OP_SUM:
+@@ -12329,30 +12431,27 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
+ }
+
+ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
+-                                      int node_idx, bool with_norm) {
++                                      int node_idx, topk_moe_mode mode) {
+
+-    if (with_norm) {
+-        if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
+-            return false;
+-        }
+-        for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
+-            if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
+-                return false;
+-            }
+-        }
+-    } else {
+-        if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
+-            return false;
+-        }
+-        for (size_t i = 0; i < topk_moe.size(); ++i) {
+-            if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
+-                return false;
+-            }
+-        }
+-    }
++    const ggml_tensor * softmax;
++    const ggml_tensor * weights;
+
+-    const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
+-    const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
++    switch (mode) {
++    case TOPK_MOE_EARLY_SOFTMAX_NORM:
++        softmax = cgraph->nodes[node_idx + 0];
++        weights = cgraph->nodes[node_idx + 9];
++        break;
++    case TOPK_MOE_EARLY_SOFTMAX:
++        softmax = cgraph->nodes[node_idx + 0];
++        weights = cgraph->nodes[node_idx + 4];
++        break;
++    case TOPK_MOE_LATE_SOFTMAX:
++        softmax = cgraph->nodes[node_idx + 4];
++        weights = cgraph->nodes[node_idx + 5];
++        break;
++    default:
++        return false;
++    }
+
+     const float * op_params = (const float *)softmax->op_params;
+
+@@ -12378,60 +12477,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
+         return false;
+     }
+
+-    // Check that the nodes don't have any unexpected uses
+-    const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
+-    const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+-    const ggml_tensor * view = cgraph->nodes[node_idx + 3];
+-    const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+-    const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
+-    const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
+-    const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
+-    const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
+-
+-    // softmax is used by reshape and argsort
+-    if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
+-        reshape1->src[0] != softmax ||
+-        argsort->src[0] != softmax) {
+-        return false;
+-    }
+-    // reshape is used by get_rows
+-    if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
+-        get_rows->src[0] != reshape1) {
+-        return false;
+-    }
+-    // argsort is used by view
+-    if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
+-        view->src[0] != argsort) {
+-        return false;
+-    }
+-    // view is written (via argsort), we can skip checking it
+-
+-    if (with_norm) {
+-        // get_rows is used by reshape
+-        if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
+-            reshape5->src[0] != get_rows) {
+-            return false;
+-        }
+-
+-        // reshape is used by sum_rows and div
+-        if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
+-            sum_rows->src[0] != reshape5 ||
+-            div->src[0] != reshape5) {
+-            return false;
+-        }
+-
+-        // sum_rows is used by div
+-        if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
+-            div->src[1] != sum_rows) {
+-            return false;
+-        }
+-
+-        // div/reshape are written
+-        if (reshape8->src[0] != div) {
+-            return false;
+-        }
+-    }
+-
+     if (!ctx->device->subgroup_arithmetic ||
+         !ctx->device->subgroup_shuffle ||
+         !ctx->device->subgroup_require_full_support ||
+@@ -12517,10 +12562,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||||
|
ctx->num_additional_fused_ops = num_adds - 1;
|
||||||
|
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||||
|
ctx->num_additional_fused_ops = 1;
|
||||||
|
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
|
||||||
|
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
|
||||||
|
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
|
||||||
|
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
|
||||||
|
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
|
||||||
|
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
|
||||||
|
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
|
||||||
|
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
|
||||||
|
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
|
||||||
|
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
|
||||||
|
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
|
||||||
|
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
|
||||||
|
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
|
||||||
|
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
|
||||||
|
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
|
||||||
|
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
||||||
|
@@ -12618,10 +12671,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||||
|
ctx->num_additional_fused_ops = num_adds - 1;
|
||||||
|
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||||
|
ctx->num_additional_fused_ops = 1;
|
||||||
|
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
|
||||||
|
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
|
||||||
|
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
|
||||||
|
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
|
||||||
|
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
|
||||||
|
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
|
||||||
|
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
|
||||||
|
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
|
||||||
|
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
|
||||||
|
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
|
||||||
|
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
|
||||||
|
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
|
||||||
|
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
|
||||||
|
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
|
||||||
|
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
|
||||||
|
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@@ -12754,25 +12815,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
||||||
|
while (first_unused < graph->n_nodes) {
|
||||||
|
std::vector<int> current_set;
|
||||||
|
|
||||||
|
- // Avoid reordering topk_moe_norm
|
||||||
|
- if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
|
||||||
|
- bool is_topk_moe_norm = true;
|
||||||
|
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
|
||||||
|
- if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
|
||||||
|
- is_topk_moe_norm = false;
|
||||||
|
+ // Check for fusion patterns and avoid reordering them
|
||||||
|
+ auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
|
||||||
|
+ if (start + (int)pattern.size() <= graph->n_nodes) {
|
||||||
|
+ bool is_pattern = true;
|
||||||
|
+ for (size_t j = 0; j < pattern.size(); ++j) {
|
||||||
|
+ if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
|
||||||
|
+ is_pattern = false;
|
||||||
|
+ }
|
||||||
|
}
|
||||||
|
+ return is_pattern;
|
||||||
|
}
|
||||||
|
- if (is_topk_moe_norm) {
|
||||||
|
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
|
||||||
|
+ return false;
|
||||||
|
+ };
|
||||||
|
+
|
||||||
|
+ auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
|
||||||
|
+ if (match_pattern(pattern, first_unused)) {
|
||||||
|
+ for (size_t j = 0; j < pattern.size(); ++j) {
|
||||||
|
new_order.push_back(graph->nodes[first_unused + j]);
|
||||||
|
used[first_unused + j] = true;
|
||||||
|
}
|
||||||
|
while (first_unused < graph->n_nodes && used[first_unused]) {
|
||||||
|
first_unused++;
|
||||||
|
}
|
||||||
|
- continue;
|
||||||
|
+ return true;
|
||||||
|
}
|
||||||
|
+ return false;
|
||||||
|
+ };
|
||||||
|
+
|
||||||
|
+ if (keep_pattern(topk_moe_early_softmax_norm)) {
|
||||||
|
+ continue;
|
||||||
|
+ }
|
||||||
|
+ if (keep_pattern(topk_moe_early_softmax)) {
|
||||||
|
+ continue;
|
||||||
|
}
|
||||||
|
+ if (keep_pattern(topk_moe_late_softmax)) {
|
||||||
|
+ continue;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
// First, grab the next unused node.
|
||||||
|
current_set.push_back(first_unused);
|
||||||
|
|
||||||
|
@@ -12790,6 +12870,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
||||||
|
if (is_empty(graph->nodes[j])) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
+ // Don't pull forward nodes from fusion patterns
|
||||||
|
+ if (match_pattern(topk_moe_early_softmax_norm, j) ||
|
||||||
|
+ match_pattern(topk_moe_early_softmax, j) ||
|
||||||
|
+ match_pattern(topk_moe_late_softmax, j)) {
|
||||||
|
+ continue;
|
||||||
|
+ }
|
||||||
|
bool ok = true;
|
||||||
|
for (int c = first_unused; c < j; ++c) {
|
||||||
|
if (!used[c] &&
|
||||||
|
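The reordering guard above is worth pausing on: `match_pattern` probes whether a window of nodes lines up with a fusion pattern without committing to it, and `keep_pattern` only then emits the whole window atomically. A minimal Go sketch of the matching step (hypothetical names, with strings standing in for the real `ggml_op` enums and `cgraph` node array):

```go
package main

import "fmt"

// matchPattern mirrors the lambda added in ggml_vk_graph_optimize: a fusion
// pattern only matches when every op in the window lines up and none of the
// nodes has already been emitted into the reordered graph.
func matchPattern(ops []string, pattern []string, used []bool, start int) bool {
	if start+len(pattern) > len(ops) {
		return false
	}
	for j, op := range pattern {
		if ops[start+j] != op || used[start+j] {
			return false
		}
	}
	return true
}

func main() {
	ops := []string{"SOFT_MAX", "RESHAPE", "ARGSORT", "VIEW", "GET_ROWS"}
	used := make([]bool, len(ops))
	fmt.Println(matchPattern(ops, []string{"SOFT_MAX", "RESHAPE", "ARGSORT"}, used, 0)) // true
	used[2] = true
	fmt.Println(matchPattern(ops, []string{"SOFT_MAX", "RESHAPE", "ARGSORT"}, used, 0)) // false: node already consumed
}
```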
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index 9e56d5f8a..bc1c278bf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
 {
     uint n_rows;
     uint n_expert_used;
+    float clamp_min;
+    float clamp_max;
 };
 
 layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
 layout(constant_id = 0) const uint WARP_SIZE = 32;
 layout(constant_id = 1) const uint n_experts = 512;
 layout(constant_id = 2) const bool with_norm = true;
+layout(constant_id = 3) const bool late_softmax = false;
 
 const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
@@ -25,53 +28,72 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
 layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
 layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
 
-void main() {
-    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
-    if (row >= n_rows) {
-        return;
-    }
+const float INFINITY = 1.0 / 0.0;
 
-    const uint logits_offset = n_experts * row;
-    const uint weights_offset = n_expert_used * row;
-    const uint ids_offset = n_experts * row;
-
-    float logits_r[experts_per_thread];
-
-    const float INFINITY = 1.0 / 0.0;
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
+    float max_val = -INFINITY;
 
     [[unroll]]
-    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
-        const uint expert = i + gl_LocalInvocationID.x;
-        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
+    for (int i = 0; i < experts_per_thread; i++) {
+        const uint idx = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            max_val = max(max_val, vals[i]);
+        }
     }
 
-    float max_val = logits_r[0];
+    max_val = subgroupMax(max_val);
+
+    float sum = 0.f;
 
     [[unroll]]
-    for (int i = 1; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        max_val = max(val, max_val);
+    for (int i = 0; i < experts_per_thread; i++) {
+        const uint idx = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            const float val = exp(vals[i] - max_val);
+            vals[i] = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
     }
 
-    max_val = subgroupMax(max_val);
+    sum = subgroupAdd(sum);
 
-    float wt[experts_per_thread];
-    float tmp = 0.f;
+    const float inv_sum = 1.0f / sum;
 
     [[unroll]]
     for (int i = 0; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        wt[i] = exp(val - max_val);
-        tmp += wt[i];
+        const uint idx = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            vals[i] *= inv_sum;
+        }
     }
+}
 
-    tmp = subgroupAdd(tmp);
+void main() {
+    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+    if (row >= n_rows) {
+        return;
+    }
 
-    const float inv_sum = 1.0f / tmp;
+    const uint logits_offset = n_experts * row;
+    const uint weights_offset = n_expert_used * row;
+    const uint ids_offset = n_experts * row;
+
+    float wt[experts_per_thread];
 
     [[unroll]]
-    for (int i = 0; i < experts_per_thread; i++) {
-        wt[i] = wt[i] * inv_sum;
+    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+        const uint expert = i + gl_LocalInvocationID.x;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+    }
+
+    if (!late_softmax) {
+        softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
     }
 
     // at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@ void main() {
 
     float output_weights[experts_per_thread];
 
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        output_weights[i] = 0.f;
+    }
+
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
         uint max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@ void main() {
 
     if (with_norm) {
         wt_sum = subgroupAdd(wt_sum);
+        wt_sum = clamp(wt_sum, clamp_min, clamp_max);
         const float inv_sum = 1.0f / wt_sum;
 
         [[unroll]]
@@ -129,6 +157,10 @@ void main() {
         }
     }
 
+    if (late_softmax) {
+        softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+    }
+
     [[unroll]]
     for (uint i = 0; i < experts_per_thread; ++i) {
         uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
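The fused shader above folds softmax, top-k selection, optional clamped normalization, and the delayed-softmax variant into a single dispatch, so a scalar reference is handy when validating it. This standalone Go sketch (not part of the patch; the float64 math, function names, and clamp defaults are illustrative assumptions) computes the same three routing modes one row at a time:

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// softmax over a slice, numerically stabilized by subtracting the max.
func softmax(x []float64) []float64 {
	maxV := math.Inf(-1)
	for _, v := range x {
		maxV = math.Max(maxV, v)
	}
	out := make([]float64, len(x))
	sum := 0.0
	for i, v := range x {
		out[i] = math.Exp(v - maxV)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

// topkMoE is a scalar reference for the three fused modes: early softmax,
// early softmax with clamped normalization, and late softmax over only the
// selected experts.
func topkMoE(logits []float64, k int, norm bool, clampMin, clampMax float64, late bool) ([]float64, []int) {
	scores := logits
	if !late {
		scores = softmax(logits)
	}
	ids := make([]int, len(scores))
	for i := range ids {
		ids[i] = i
	}
	sort.Slice(ids, func(a, b int) bool { return scores[ids[a]] > scores[ids[b]] })
	ids = ids[:k]
	w := make([]float64, k)
	for i, id := range ids {
		w[i] = scores[id]
	}
	if norm {
		sum := 0.0
		for _, v := range w {
			sum += v
		}
		sum = math.Min(math.Max(sum, clampMin), clampMax) // the new clamp on wt_sum
		for i := range w {
			w[i] /= sum
		}
	}
	if late {
		w = softmax(w)
	}
	return w, ids
}

func main() {
	w, ids := topkMoE([]float64{0.1, 2.0, -1.0, 0.7}, 2, true, 1e-9, math.MaxFloat64, false)
	fmt.Println(w, ids)
}
```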
1242
llama/patches/0032-vulkan-Fuse-rope-set_rows-16769.patch
Normal file
@@ -0,0 +1,85 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Thu, 30 Oct 2025 01:27:41 -0500
Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp             |  4 ++++
 ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | 16 ++++++++++++----
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index aaf4334b5..3604ceb04 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {
 
 struct vk_op_argsort_push_constants {
     uint32_t ncols;
+    uint32_t nrows;
     int32_t order;
 };
 
@@ -8710,6 +8711,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         break;
     case GGML_OP_ARGSORT:
         elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+        elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
         break;
     case GGML_OP_IM2COL:
         {
@@ -9952,9 +9954,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
     int32_t * op_params = (int32_t *)dst->op_params;
 
     uint32_t ncols = src0->ne[0];
+    uint32_t nrows = ggml_nrows(src0);
 
     ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
         ncols,
+        nrows,
         op_params[0],
     }, dryrun);
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
index c81b84452..c4e68bc02 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
 
 layout (push_constant) uniform parameter {
     uint ncols;
+    uint nrows;
     uint order;
 } p;
 
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
     dst_row[idx1] = tmp;
 }
 
-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
     // bitonic sort
     const int col = int(gl_LocalInvocationID.x);
-    const uint row = gl_WorkGroupID.y;
 
     const uint row_offset = row * p.ncols;
 
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
 
 void main() {
     if (p.ncols == BLOCK_SIZE) {
-        argsort(false);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(false, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     } else {
-        argsort(true);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(true, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     }
 }
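The core of this fix is a grid-stride loop: the dispatch clamps the workgroup count to the device's `maxComputeWorkGroupCount`, and each workgroup then walks rows with that stride so any `nrows` is covered. A sequential Go sketch of the same coverage argument (illustrative only; on the GPU the outer loop's iterations run concurrently):

```go
package main

import "fmt"

// strideRows mimics the shader fix: a capped number of workgroups each walk
// rows with a fixed stride, so every row is processed exactly once no matter
// how many rows there are.
func strideRows(nrows, maxGroups int, process func(row int)) {
	groups := maxGroups
	if nrows < groups {
		groups = nrows
	}
	for g := 0; g < groups; g++ { // one iteration per dispatched workgroup
		for row := g; row < nrows; row += groups { // the in-shader while loop
			process(row)
		}
	}
}

func main() {
	strideRows(10, 4, func(row int) { fmt.Println("sorting row", row) })
}
```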
@@ -0,0 +1,77 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam <picard12@live.de>
Date: Fri, 31 Oct 2025 08:14:49 +0100
Subject: [PATCH] vulkan: fix shmem overrun in mmq id shader (#16873)

* vulkan: fix shmem overrun in mmq id shader

* metal : fix mul_mm_id

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---
 ggml/src/ggml-metal/ggml-metal-device.cpp                    | 2 +-
 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp             | 4 ++++
 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl | 2 +-
 tests/test-backend-ops.cpp                                   | 3 +++
 4 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 758116342..c78082ac3 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
     char name[256];
 
     snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_ne02=%d", base, ne02);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 8b238ac4b..d955b4fc7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;
 
 #include "mul_mmq_shmem_types.glsl"
 
+#ifdef MUL_MAT_ID
+#define BK_STEP 1
+#else
 #ifndef BK_STEP
 #define BK_STEP 4
 #endif
+#endif
 
 // Shared memory cache
 shared block_a_cache buf_a[BM * BK_STEP];
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
index 72fec4404..1c0f5306f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -27,7 +27,7 @@ struct block_a_cache {
 #elif defined(DATA_A_Q8_0)
 #define QUANT_R_MMQ 1
 // AMD likes 4, Intel likes 1 and Nvidia likes 2
-#define BK_STEP 1
+// #define BK_STEP 1
 struct block_a_cache {
     int32_t qs[32/4];
     FLOAT_TYPE dm;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 657b6cc2f..1f8dda383 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6722,6 +6722,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
     test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
 
+    // gpt-oss issue with Vulkan mmq_id
+    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
+
    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
            for (int n_mats : {4, 8}) {
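The overrun this patch fixes comes from the shared-memory cache being sized as BM * BK_STEP blocks; pinning BK_STEP to 1 for the MUL_MAT_ID variant shrinks the allocation back under the device budget. A toy Go check of that sizing arithmetic (all constants here are made-up placeholders, not the shader's real values):

```go
package main

import "fmt"

// shmemBytes computes the shared-memory footprint of the block cache:
// BM tiles times BK_STEP steps times the size of one cached block.
func shmemBytes(bm, bkStep, blockBytes int) int {
	return bm * bkStep * blockBytes
}

func main() {
	const limit = 32 * 1024 // e.g. a 32 KiB shared-memory budget
	fmt.Println(shmemBytes(64, 4, 144) > limit) // true: BK_STEP=4 overruns
	fmt.Println(shmemBytes(64, 1, 144) > limit) // false: BK_STEP=1 fits
}
```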
@@ -0,0 +1,80 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Masato Nakasaka <masato.nakasaka@intel.com>
Date: Fri, 31 Oct 2025 16:18:59 +0900
Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not
 supported (#16796)

* Experimenting crash fix

* added assert for aborting and fixed comment

* changed to check if a pipeline is empty or not

* Moved function in class definition

* replaced with is_empty

* Modified is_empty to check only unaligned pipelines
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3604ceb04..80185d9f0 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
 struct vk_matmul_pipeline_struct {
     vk_pipeline l, m, s;
     vk_pipeline a_l, a_m, a_s;
+    // Returns true when all unaligned pipelines are null.
+    // We only check for unaligned variants since one of the unaligned pipelines must exist
+    // while aligned pipelines are optional
+    bool is_empty() const {
+        return l == nullptr && m == nullptr && s == nullptr;
+    }
 };
-
 typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
 
 struct vk_matmul_pipeline2 {
@@ -5080,7 +5085,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     if (src1_type == GGML_TYPE_Q8_1) {
         vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
 
-        if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+        if (pipelines->is_empty()) {
             return nullptr;
         }
 
@@ -5229,7 +5234,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
     if (src1_type == GGML_TYPE_Q8_1) {
         vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
 
-        if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+        if (pipelines->is_empty()) {
             return nullptr;
         }
 
@@ -5264,16 +5269,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         return nullptr;
     }
 
+    vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
     // XXX TODO 'prec' is not actually allowed in mul_mat_id.
     bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
-    bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
-    bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
+    bool support_fp16acc = !mmp.f16acc->is_empty();
+    bool support_fp32acc = !mmp.f32acc->is_empty();
 
     if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
-        return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
+        return mmp.f16acc;
     } else {
         GGML_ASSERT(support_fp32acc);
-        return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
+        return mmp.f32acc;
     }
 }
 
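The selection logic this patch converges on is small enough to restate directly: prefer the FP16 accumulator when the device wants it and that pipeline was actually compiled, otherwise fall back to FP32, and fail loudly if neither variant exists. A hedged Go restatement (illustrative booleans standing in for the real pipeline structs):

```go
package main

import "fmt"

// pickAcc mirrors the fallback in the patch above: prefer the FP16
// accumulator when the device prefers it and it exists, otherwise fall
// back to FP32, and report an error when neither variant was compiled.
func pickAcc(preferFP16, haveFP16, haveFP32 bool) (string, error) {
	if haveFP16 && (preferFP16 || !haveFP32) {
		return "f16acc", nil
	}
	if !haveFP32 {
		return "", fmt.Errorf("no matmul pipeline compiled for this type")
	}
	return "f32acc", nil
}

func main() {
	acc, err := pickAcc(true, false, true) // device prefers fp16 but only fp32 exists
	fmt.Println(acc, err)
}
```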
516
llm/memory.go
@@ -1,516 +0,0 @@
package llm

import (
	"fmt"
	"log/slog"
	"os"
	"slices"
	"sort"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
)

// pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits
// The list of GPUs returned will always be the same brand (library)
// If the model can not be fit fully within the available GPU(s) nil is returned
func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, gpus []ml.DeviceInfo, numParallel int) []ml.DeviceInfo {
	for _, gl := range ml.ByLibrary(gpus) {
		sgl := append(make([]ml.DeviceInfo, 0, len(gl)), gl...)

		// TODO - potentially sort by performance capability, existing models loaded, etc.
		// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
		// Note: at present, this will favor most current available VRAM descending and ignoring faster GPU speed in mixed setups
		sort.Sort(sort.Reverse(ml.ByFreeMemory(sgl)))

		if !envconfig.SchedSpread() {
			// Try to pack into as few GPUs as possible, starting from 1 GPU
			for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
				gpuSubset := sgl[:numGPUs]
				ok, estimatedVRAM := predictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)

				if ok {
					slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
						"model", modelPath,
						"library", sgl[0].Library,
						"parallel", numParallel,
						"required", format.HumanBytes2(estimatedVRAM),
						"gpus", numGPUs)
					return gpuSubset
				}
			}
		} else {
			// TODO future refinements
			// - if multiple Libraries, see if any single GPU in any Library will fit
			// - try subsets of GPUs instead of just falling back to 1 or all in a family

			// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
			if ok, estimatedVRAM := predictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
				slog.Info("new model will fit in available VRAM, loading",
					"model", modelPath,
					"library", sgl[0].Library,
					"parallel", numParallel,
					"required", format.HumanBytes2(estimatedVRAM),
					"gpus", len(sgl))
				return sgl
			}
		}
	}
	return nil
}

// If multiple Libraries are detected, pick the Library which loads the most layers for the model
func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []string, opts api.Options, gpus []ml.DeviceInfo, numParallel int) []ml.DeviceInfo {
	byLibrary := ml.ByLibrary(gpus)
	if len(byLibrary) <= 1 {
		return gpus
	}
	var bestEstimate uint64
	var bestFit int
	for i, gl := range byLibrary {
		_, estimatedVRAM := predictServerFit(gl, f, adapters, projectors, opts, numParallel)
		if estimatedVRAM > bestEstimate {
			bestEstimate = estimatedVRAM
			bestFit = i
		}
	}
	return byLibrary[bestFit]
}

// This algorithm looks for a complete fit to determine if we need to unload other models
func predictServerFit(allGpus []ml.DeviceInfo, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
	// Split up the GPUs by type and try them
	var estimatedVRAM uint64
	for _, gpus := range ml.ByLibrary(allGpus) {
		var layerCount int
		estimate := estimateGPULayers(gpus, f, projectors, opts, numParallel)
		layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
		if opts.NumGPU < 0 {
			if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
				return true, estimatedVRAM
			}
		} else {
			if layerCount > 0 && layerCount >= opts.NumGPU {
				return true, estimatedVRAM
			}
		}
	}
	return false, estimatedVRAM
}

func verifyCPUFit(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, systemInfo ml.SystemInfo, numParallel int) bool {
	estimate := estimateGPULayers(nil, f, projectors, opts, numParallel)
	if estimate.TotalSize > systemInfo.FreeMemory {
		return false
	}
	slog.Info("new model will fit in available system memory for CPU inference, loading",
		"model", modelPath,
		"parallel", numParallel,
		"required", format.HumanBytes2(estimate.TotalSize),
	)
	return true
}

type MemoryEstimate struct {
	// How many layers we predict we can load
	Layers int

	// The size of the graph which occupies the main GPU
	Graph uint64

	// How much VRAM will be allocated given the number of layers we predict
	VRAMSize uint64

	// The total size of the model if loaded into VRAM. If all layers are loaded, VRAMSize == TotalSize
	TotalSize uint64

	// For multi-GPU scenarios, this provides the tensor split parameter
	TensorSplit []int

	// For multi-GPU scenarios, this is the size in bytes per GPU
	GPUSizes []uint64

	// internal fields for logging purposes
	inferenceLibrary    string
	layersRequested     int
	layersModel         int
	availableList       []string
	kv                  uint64
	allocationsList     []string
	memoryWeights       uint64
	memoryLayerOutput   uint64
	graphFullOffload    uint64
	graphPartialOffload uint64

	projectorWeights, projectorGraph uint64
}

// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
// The GPUs provided must all be the same Library
func estimateGPULayers(gpus []ml.DeviceInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
	// Graph size for a partial offload, applies to all GPUs
	var graphPartialOffload uint64

	// Graph size when all layers are offloaded, applies to all GPUs
	var graphFullOffload uint64

	// Final graph offload once we know full or partial
	var graphOffload uint64

	// Projectors loaded into GPU0 only
	var llamaEngineProjectorWeights uint64

	// Projectors loaded with output layer
	var ollamaEngineProjectorWeights uint64
	var ollamaEngineProjectorGraph uint64

	// Conditional output size on GPU 0
	var memoryLayerOutput uint64

	// The sizes of a layer
	var layerSize uint64

	// The sum of all the layer sizes (just for logging)
	var memoryWeights uint64

	// True if all the layers are loaded
	var fullyLoaded bool

	// Overflow that didn't fit into the GPU
	var overflow uint64

	overhead := envconfig.GpuOverhead()
	availableList := make([]string, len(gpus))
	libraries := []string{}
	for i, gpu := range gpus {
		availableList[i] = format.HumanBytes2(gpu.FreeMemory)
		if !slices.Contains(libraries, gpu.Library) {
			libraries = append(libraries, gpu.Library)
		}
	}
	if len(libraries) == 0 {
		libraries = []string{"cpu"}
	}
	slog.Debug("evaluating", "library", strings.Join(libraries, ","), "gpu_count", len(gpus), "available", availableList)

	for _, projector := range projectors {
		llamaEngineProjectorWeights += projectorMemoryRequirements(projector)
	}
	if llamaEngineProjectorWeights == 0 {
		ollamaEngineProjectorWeights, ollamaEngineProjectorGraph = f.VisionGraphSize()
	}

	layers := f.Tensors().GroupLayers()
	// add one layer worth of memory as a buffer
	if blk0, ok := layers["blk.0"]; ok {
		layerSize = blk0.Size()
	} else {
		slog.Warn("model missing blk.0 layer size")
	}

	useFlashAttention := envconfig.FlashAttention(f.FlashAttention()) &&
		ml.FlashAttentionSupported(gpus) &&
		f.SupportsFlashAttention()

	var kvct string
	if useFlashAttention {
		requested := strings.ToLower(envconfig.KvCacheType())
		if f.SupportsKVCacheType(requested) {
			kvct = requested
		}
	}

	kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)

	if len(kv) > 0 {
		layerSize += kv[0]
	}

	var kvTotal uint64
	for _, kvLayer := range kv {
		kvTotal += kvLayer
	}

	if graphPartialOffload == 0 {
		headsKV := f.KV().HeadCountKVMin()
		if headsKV == 0 {
			headsKV = 1
		}
		gqa := f.KV().HeadCountMax() / headsKV
		graphPartialOffload = gqa * kvTotal / 6
	}
	if graphFullOffload == 0 {
		graphFullOffload = graphPartialOffload
	}

	// on metal there's no partial offload overhead
	if len(gpus) > 0 && gpus[0].Library == "Metal" {
		graphPartialOffload = graphFullOffload
	} else if len(gpus) > 1 {
		// multigpu should always use the partial graph size
		graphFullOffload = graphPartialOffload
	}

	// Output layer handled at the end if we have space
	if layer, ok := layers["output_norm"]; ok {
		memoryLayerOutput += layer.Size()
	}
	if layer, ok := layers["output"]; ok {
		memoryLayerOutput += layer.Size()
	} else if layer, ok := layers["token_embd"]; ok {
		memoryLayerOutput += layer.Size()
	}

	gpuZeroOverhead := llamaEngineProjectorWeights

	// Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
	var layerCount int
	tensorSplit := make([]int, len(gpus))
	gpuAllocations := make([]uint64, len(gpus))
	type gs struct {
		i int
		g *ml.DeviceInfo
	}
	gpusWithSpace := []gs{}
	for i := range gpus {
		var gzo uint64
		if len(gpusWithSpace) == 0 {
			gzo = gpuZeroOverhead
		}
		// Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least more layer
		if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory()+2*layerSize {
			slog.Debug("gpu has too little memory to allocate any layers",
				"id", gpus[i].ID,
				"library", gpus[i].Library,
				"compute", gpus[i].Compute(),
				"driver", fmt.Sprintf("%d.%d", gpus[i].DriverMajor, gpus[i].DriverMinor),
				"name", gpus[i].Name,
				"total", format.HumanBytes2(gpus[i].TotalMemory),
				"available", format.HumanBytes2(gpus[i].FreeMemory),
				"minimum_memory", gpus[i].MinimumMemory,
				"layer_size", format.HumanBytes2(layerSize),
				"gpu_zer_overhead", format.HumanBytes2(gzo),
				"partial_offload", format.HumanBytes2(graphPartialOffload),
				"full_offload", format.HumanBytes2(graphFullOffload),
			)
			continue
		}
		gpusWithSpace = append(gpusWithSpace, gs{i, &gpus[i]})
		gpuAllocations[i] += gpus[i].MinimumMemory() + layerSize // We hold off on graph until we know partial vs. full
	}

	var gpuZeroID int
	if len(gpusWithSpace) > 0 {
		gpuZeroID = gpusWithSpace[0].i
		gpuAllocations[gpuZeroID] += gpuZeroOverhead
	} else {
		overflow += gpuZeroOverhead
	}

	// For all the layers, find where they can fit on the GPU(s)
	for i := int(f.KV().BlockCount()) - 1; i >= 0; i-- {
		// Some models have inconsistent layer sizes
		if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
			layerSize = blk.Size()
			layerSize += kv[i]
			memoryWeights += blk.Size()
		}

		if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
			// Stop allocating on GPU(s) once we hit the users target NumGPU
			overflow += layerSize
			continue
		}

		// distribute the layers across the GPU(s) that have space
		for j := len(gpusWithSpace); j > 0; j-- {
			g := gpusWithSpace[i%j]
			used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
			if g.g.FreeMemory > overhead+used+layerSize {
				gpuAllocations[g.i] += layerSize
				tensorSplit[g.i]++
				layerCount++
				break
			} else {
				gpusWithSpace = append(gpusWithSpace[:i%j], gpusWithSpace[i%j+1:]...)
			}
		}

		if len(gpusWithSpace) == 0 {
			overflow += layerSize
		}
	}
	if layerCount >= int(f.KV().BlockCount()) {
		fullyLoaded = true
	}

	// Determine if we need to consider output then find where it fits
	memoryLastLayer := memoryLayerOutput + ollamaEngineProjectorWeights + ollamaEngineProjectorGraph
	if memoryLastLayer > 0 {
		if opts.NumGPU < 0 || layerCount < opts.NumGPU {
			for j := len(gpusWithSpace); j > 0; j-- {
				g := gpusWithSpace[layerCount%j]
				used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
				if g.g.FreeMemory > overhead+used+memoryLastLayer {
					gpuAllocations[g.i] += memoryLastLayer
					tensorSplit[g.i]++
					layerCount++
					break
				}
			}
		}

		if layerCount < int(f.KV().BlockCount())+1 {
			fullyLoaded = false
			overflow += memoryLastLayer
		}
	}

	// Add the applicable (full or partial) graph allocations
	for i := range gpus {
		if tensorSplit[i] <= 0 {
			continue
		}
		if fullyLoaded {
			gpuAllocations[i] += graphFullOffload
		} else {
			gpuAllocations[i] += graphPartialOffload
		}
	}
	if fullyLoaded {
		graphOffload = graphFullOffload
	} else {
		graphOffload = graphPartialOffload
	}

	// Summaries for the log
	var memoryRequiredPartial, memoryRequiredTotal uint64
	for i := range gpuAllocations {
		memoryRequiredPartial += gpuAllocations[i]
	}
	memoryRequiredTotal = memoryRequiredPartial + overflow

	allocationsList := []string{}
	for _, a := range gpuAllocations {
		allocationsList = append(allocationsList, format.HumanBytes2(a))
	}

	estimate := MemoryEstimate{
		TotalSize: memoryRequiredTotal,
		Layers:    0,
		Graph:     0,
		VRAMSize:  0,
		GPUSizes:  []uint64{},

		inferenceLibrary:    strings.Join(libraries, ","),
		layersRequested:     opts.NumGPU,
		layersModel:         int(f.KV().BlockCount()) + 1,
		availableList:       availableList,
		kv:                  kvTotal,
		allocationsList:     allocationsList,
		memoryWeights:       memoryWeights,
		memoryLayerOutput:   memoryLayerOutput,
		graphFullOffload:    graphFullOffload,
		graphPartialOffload: graphPartialOffload,
		projectorWeights:    llamaEngineProjectorWeights + ollamaEngineProjectorWeights,
		projectorGraph:      ollamaEngineProjectorGraph,
	}

	if len(gpus) == 0 {
		return estimate
	}
	if layerCount == 0 {
		slog.Debug("insufficient VRAM to load any model layers")
		return estimate
	}
	estimate.Layers = layerCount
	estimate.Graph = graphOffload
	estimate.VRAMSize = memoryRequiredPartial
	estimate.TotalSize = memoryRequiredTotal
	estimate.TensorSplit = tensorSplit
	estimate.GPUSizes = gpuAllocations
	return estimate
}

func (m MemoryEstimate) LogValue() slog.Value {
	attrs := []slog.Attr{
		slog.String("library", m.inferenceLibrary),
		slog.Group(
			"layers",
			// requested number of layers to offload
			"requested", m.layersRequested,
			// The number of layers the model has (including output)
			"model", m.layersModel,
			// estimated number of layers that can be offloaded
			"offload", m.Layers,
			// multi-gpu split for tensors
			"split", m.TensorSplit,
		),
		slog.Group(
			"memory",
			// memory available by GPU for offloading
			"available", m.availableList,
			"gpu_overhead", format.HumanBytes2(envconfig.GpuOverhead()),
			slog.Group(
				"required",
				// memory required for full offloading
				"full", format.HumanBytes2(m.TotalSize),
				// memory required to offload layers.estimate layers
				"partial", format.HumanBytes2(m.VRAMSize),
				// memory of KV cache
				"kv", format.HumanBytes2(m.kv),
				// Allocations across the GPUs
				"allocations", m.allocationsList,
			),
			slog.Group(
				"weights",
				// memory of the weights
				"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
				// memory of repeating layers
				"repeating", format.HumanBytes2(m.memoryWeights),
				// memory of non-repeating layers
				"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
			),
			slog.Group(
				"graph",
				// memory of graph when fully offloaded
				"full", format.HumanBytes2(m.graphFullOffload),
				// memory of graph when not fully offloaded
				"partial", format.HumanBytes2(m.graphPartialOffload),
			),
		),
	}

	if m.projectorWeights > 0 {
		attrs = append(attrs, slog.Group(
			"projector",
			"weights", format.HumanBytes2(m.projectorWeights),
			"graph", format.HumanBytes2(m.projectorGraph),
		))
	}

	return slog.GroupValue(attrs...)
}

func projectorMemoryRequirements(filename string) (weights uint64) {
	file, err := os.Open(filename)
	if err != nil {
		return 0
	}
	defer file.Close()

	ggml, err := ggml.Decode(file, 1024)
	if err != nil {
		return 0
	}

	for _, layer := range ggml.Tensors().GroupLayers() {
		weights += layer.Size()
	}

	return weights
}
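For readers tracing where these estimates went, the heart of the removed estimator is the greedy round-robin placement loop above. A condensed Go sketch of just that loop (simplified to a single uniform layer size; the real code also accounts for KV cache, graph, and projector memory):

```go
package main

import "fmt"

// distribute mimics the removed estimator's core loop: walk layers from the
// top of the model down and round-robin them over GPUs that still have room,
// dropping GPUs from the rotation once a layer no longer fits.
func distribute(layerSize uint64, nLayers int, free []uint64) (split []int, overflow uint64) {
	split = make([]int, len(free))
	active := make([]int, len(free)) // indexes of GPUs still accepting layers
	for i := range free {
		active[i] = i
	}
	for i := nLayers - 1; i >= 0; i-- {
		placed := false
		for j := len(active); j > 0; j-- {
			g := active[i%j]
			if free[g] >= layerSize {
				free[g] -= layerSize
				split[g]++
				placed = true
				break
			}
			active = append(active[:i%j], active[i%j+1:]...) // GPU is full
		}
		if !placed {
			overflow += layerSize // layer spills to system memory
		}
	}
	return split, overflow
}

func main() {
	split, overflow := distribute(100, 10, []uint64{450, 350})
	fmt.Println(split, overflow) // [4 3] with 300 bytes overflowing to CPU
}
```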
@@ -1,130 +0,0 @@
package llm

import (
	"bytes"
	"fmt"
	"os"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
)

func TestEstimateGPULayers(t *testing.T) {
	t.Setenv("OLLAMA_DEBUG", "1")
	t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
	t.Setenv("OLLAMA_CONTEXT_LENGTH", "2048")

	modelName := "dummy"
	f, err := os.CreateTemp(t.TempDir(), modelName)
	require.NoError(t, err)
	defer f.Close()
	inputLayerCount := 5

	tensors := []*ggml.Tensor{
		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
		{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
		{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
		{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
		{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
	}
	assert.Len(t, tensors, inputLayerCount+1)
	err = ggml.WriteGGUF(f, ggml.KV{
		"general.architecture":          "llama",
		"llama.context_length":          uint32(32),
		"llama.embedding_length":        uint32(4096),
		"llama.block_count":             uint32(inputLayerCount),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(32),
		"tokenizer.ggml.tokens":         []string{" "},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, tensors)
	require.NoError(t, err)

	ggml, err := LoadModel(f.Name(), 0)
	if err != nil {
		t.Fatal(err)
	}

	// Simple CPU scenario
	gpus := []ml.DeviceInfo{}
	projectors := []string{}
	opts := api.DefaultOptions()
	t.Run("cpu", func(t *testing.T) {
		estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1)
		assert.Equal(t, 0, estimate.Layers)
		assert.Equal(t, uint64(0), estimate.Graph)
	})

	// derived from the dummy ggml file above
	graphPartialOffload := uint64(202377216)
	graphFullOffload := uint64(171968512)
	layerSize := uint64(33554436)
	projectorSize := uint64(0)
	memoryLayerOutput := uint64(4)

	// Dual CUDA scenario with asymmetry
	gpuMinimumMemory := uint64(457 * format.MebiByte)
	gpus = []ml.DeviceInfo{
		{
			DeviceID: ml.DeviceID{
				Library: "CUDA",
			},
		},
		{
			DeviceID: ml.DeviceID{
				Library: "CUDA",
			},
		},
	}
	// Nested array: GPU0 layer space, GPU1 layer space, expected gpu0, expected gpu1
	for i, s := range []struct {
		layer0, layer1   uint64
		expect0, expect1 int
	}{
		{1, 1, 1, 1},
		{2, 1, 2, 1},
		{2, 2, 2, 2},
		{1, 2, 1, 2},
		{3, 3, 3, 3},
		{4, 4, 3, 3},
		{6, 6, 3, 3},
		{0, 3, 0, 3},
	} {
		t.Run(fmt.Sprintf("%v", s), func(t *testing.T) {
			gpus[0].FreeMemory = 0
			gpus[1].FreeMemory = 0
			gpus[0].FreeMemory += projectorSize
			if s.layer0 > 0 {
				gpus[0].FreeMemory += memoryLayerOutput
			} else {
				gpus[1].FreeMemory += memoryLayerOutput
			}
			gpus[0].FreeMemory += gpuMinimumMemory + layerSize + s.layer0*layerSize + 1
			gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
			gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
			gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
			estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1)
			assert.Equal(t, s.expect0+s.expect1, estimate.Layers, "scenario %d: %v", i, s)
			assert.Equal(t, []int{s.expect0, s.expect1}, estimate.TensorSplit, "scenario %d: %v", i, s)
			var layerSums uint64
			for _, b := range estimate.GPUSizes {
				layerSums += b
			}
			if estimate.Layers < inputLayerCount+1 {
				assert.Less(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
				assert.Equal(t, estimate.VRAMSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
			} else {
				assert.Equal(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
				assert.Equal(t, estimate.TotalSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
			}
		})
	}
}
443
llm/server.go
@@ -89,20 +89,16 @@ type llmServer struct {
|
|||||||
done chan error // Channel to signal when the process exits
|
done chan error // Channel to signal when the process exits
|
||||||
status *StatusWriter
|
status *StatusWriter
|
||||||
options api.Options
|
options api.Options
|
||||||
numParallel int
|
|
||||||
modelPath string
|
modelPath string
|
||||||
|
|
||||||
loadRequest LoadRequest // Parameters used to initialize the runner
|
loadRequest LoadRequest // Parameters used to initialize the runner
|
||||||
|
mem *ml.BackendMemory // Memory allocations for this model
|
||||||
|
|
||||||
// llamaModel is an instance of the cgo llama.cpp model definition
|
// llamaModel is an instance of the cgo llama.cpp model definition
|
||||||
// nil if this server is running the new engine
|
// nil if this server is running the new engine
|
||||||
llamaModel *llama.Model
|
llamaModel *llama.Model
|
||||||
llamaModelLock *sync.Mutex
|
llamaModelLock *sync.Mutex
|
||||||
|
|
||||||
// textProcessor handles text encoding/decoding for the model in the Ollama engine
|
|
||||||
// nil if this server is running the llama.cpp based engine
|
|
||||||
textProcessor model.TextProcessor
|
|
||||||
|
|
||||||
totalLayers uint64
|
totalLayers uint64
|
||||||
loadStart time.Time // Record how long it took the model to load
|
loadStart time.Time // Record how long it took the model to load
|
||||||
loadProgress float32
|
loadProgress float32
|
||||||
@@ -114,14 +110,12 @@ type llamaServer struct {
|
|||||||
llmServer
|
llmServer
|
||||||
|
|
||||||
ggml *ggml.GGML
|
ggml *ggml.GGML
|
||||||
gpus []ml.DeviceInfo // The set of GPUs covered by the memory estimate
|
|
||||||
estimate MemoryEstimate
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ollamaServer struct {
|
type ollamaServer struct {
|
||||||
llmServer
|
llmServer
|
||||||
|
|
||||||
mem *ml.BackendMemory
|
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadModel will load a model from disk. The model must be in the GGML format.
|
// LoadModel will load a model from disk. The model must be in the GGML format.
|
||||||
@@ -245,8 +239,6 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
|||||||
loadRequest: loadRequest,
|
loadRequest: loadRequest,
|
||||||
llamaModel: llamaModel,
|
llamaModel: llamaModel,
|
||||||
llamaModelLock: &sync.Mutex{},
|
llamaModelLock: &sync.Mutex{},
|
||||||
textProcessor: textProcessor,
|
|
||||||
numParallel: numParallel,
|
|
||||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||||
totalLayers: f.KV().BlockCount() + 1,
|
totalLayers: f.KV().BlockCount() + 1,
|
||||||
loadStart: time.Now(),
|
loadStart: time.Now(),
|
||||||
@@ -281,7 +273,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
     }()

     if textProcessor != nil {
-        return &ollamaServer{llmServer: s}, nil
+        return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
     } else {
         return &llamaServer{llmServer: s, ggml: f}, nil
     }
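With textProcessor moved off the shared llmServer, this branch is now the only place the engine choice is made. A minimal sketch of the dispatch shape (the helper name and `any` return are ours; the types come from this diff):

```go
// Sketch only: the constructor picks the concrete server type once,
// keeping engine-specific state (textProcessor vs. ggml) off llmServer.
func newServer(s llmServer, tp model.TextProcessor, f *ggml.GGML) any {
	if tp != nil {
		return &ollamaServer{llmServer: s, textProcessor: tp} // Ollama engine
	}
	return &llamaServer{llmServer: s, ggml: f} // llama.cpp engine
}
```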
@@ -463,76 +455,173 @@ type LoadResponse struct {

 var ErrLoadRequiredFull = errors.New("unable to load full model on GPU")

-func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
-    systemTotalMemory := systemInfo.TotalMemory
-    systemFreeMemory := systemInfo.FreeMemory
-    systemSwapFreeMemory := systemInfo.FreeSwap
-    slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
+func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
+    slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)

-    if len(gpus) == 0 || s.options.NumGPU == 0 {
-        if !verifyCPUFit(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, systemInfo, s.numParallel) {
-            slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
-            return nil, fmt.Errorf("model requires more system memory than is currently available %w", ErrLoadRequiredFull)
-        }
-    } else {
-        g := pickBestFullFitByLibrary(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
-        if g == nil {
-            if !requireFull {
-                g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
-            } else {
-                slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
-                return nil, ErrLoadRequiredFull
-            }
-        }
-        gpus = g
-    }
-
-    s.estimate = estimateGPULayers(gpus, s.ggml, []string{s.loadRequest.ProjectorPath}, s.options, s.numParallel)
-
-    if len(gpus) >= 1 {
-        switch {
-        case s.options.NumGPU == 0:
-            gpus = []ml.DeviceInfo{}
-        case gpus[0].Library == "Metal" && s.estimate.VRAMSize > systemInfo.TotalMemory:
-            // disable partial offloading when model is greater than total system memory as this
-            // can lead to locking up the system
-            s.options.NumGPU = 0
-            gpus = []ml.DeviceInfo{}
-        case gpus[0].Library != "Metal" && s.estimate.Layers == 0:
-            // Don't bother loading into the GPU if no layers can fit
-            gpus = []ml.DeviceInfo{}
-        case s.options.NumGPU < 0 && s.estimate.Layers > 0:
-            s.options.NumGPU = s.estimate.Layers
+    gpus := append(make([]ml.DeviceInfo, 0, len(systemGPUs)), systemGPUs...)
+
+    // Synthesize memory allocation information based on our estimates
+    s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{
+        Name:    "CPU",
+        Weights: make([]uint64, s.totalLayers),
+        Cache:   make([]uint64, s.totalLayers),
+    }, GPUs: make([]ml.DeviceMemory, len(gpus))}
+
+    for i := range s.mem.GPUs {
+        s.mem.GPUs[i].Name = gpus[i].Name
+        s.mem.GPUs[i].DeviceID = gpus[i].DeviceID
+        s.mem.GPUs[i].Weights = make([]uint64, s.totalLayers)
+        s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
+    }
+
+    kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
+        s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)
+
+    // Use the size of one layer as a buffer
+    layers := s.ggml.Tensors().GroupLayers()
+    if blk0, ok := layers["blk.0"]; ok {
+        for i := range gpus {
+            gpus[i].FreeMemory -= blk0.Size() + kv[0]
+        }
+    } else {
+        slog.Warn("model missing blk.0 layer size")
+    }
+
+    // Assign all the layers to the CPU for now, they will get reassigned later
+    for i := range s.ggml.KV().BlockCount() {
+        if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
+            s.mem.CPU.Weights[i] = blk.Size()
+            s.mem.CPU.Cache[i] += kv[i]
+        }
+    }
+
+    // We historically haven't included InputWeights in the model size
+    var outputWeights uint64
+    if layer, ok := layers["output_norm"]; ok {
+        outputWeights += layer.Size()
+    }
+    if layer, ok := layers["output"]; ok {
+        outputWeights += layer.Size()
+    } else if layer, ok := layers["token_embd"]; ok {
+        outputWeights += layer.Size()
+    }
+    s.mem.CPU.Weights[s.totalLayers-1] = outputWeights
+
+    // The vision projector is always loaded on the first GPU if available.
+    // This can't be assigned by us, so just subtract it from free space
+    projectorGPU := -1
+    var projectorWeights uint64
+    if len(gpus) > 0 {
+        for _, projector := range s.loadRequest.LoraPath {
+            projectorWeights += projectorMemoryRequirements(projector)
+        }
+
+        // llama.cpp uses the first discrete GPU if available, otherwise the first iGPU
+        firstIntegrated := -1
+        for i := range gpus {
+            if !gpus[i].Integrated {
+                projectorGPU = i
+                break
+            }
+            if firstIntegrated == -1 {
+                firstIntegrated = i
+            }
+        }
+        if projectorGPU == -1 {
+            projectorGPU = firstIntegrated
+        }
+
+        gpus[projectorGPU].FreeMemory -= projectorWeights
+    }
+
+    var kvTotal uint64
+    for _, kvLayer := range kv {
+        kvTotal += kvLayer
+    }
+
+    if graphPartialOffload == 0 {
+        headsKV := s.ggml.KV().HeadCountKVMin()
+        if headsKV == 0 {
+            headsKV = 1
+        }
+        gqa := s.ggml.KV().HeadCountMax() / headsKV
+        graphPartialOffload = gqa * kvTotal / 6
+    }
+    if graphFullOffload == 0 {
+        graphFullOffload = graphPartialOffload
+    }
+
+    // On Metal there's no partial offload overhead
+    if len(gpus) > 0 && gpus[0].Library == "Metal" {
+        graphPartialOffload = graphFullOffload
+    }
+
+    // Create a layout based on the memory data that we've built. The compute graph
+    // for GPUs is iteratively assigned based on the number of GPUs that are required.
+    var gpuLayers ml.GPULayersList
+    for {
+        prevGPULayers := gpuLayers
+
+        var err error
+        gpuLayers, err = s.createLayout(systemInfo, gpus, s.mem, requireFull, 0)
+        if err != nil {
+            return nil, err
+        }
+
+        if len(gpuLayers) > len(prevGPULayers) {
+            for _, gl := range gpuLayers {
+                for i := range s.mem.GPUs {
+                    if gl.DeviceID == s.mem.GPUs[i].DeviceID {
+                        s.mem.GPUs[i].Graph = max(graphPartialOffload, graphFullOffload)
+                        break
+                    }
+                }
+            }
         } else {
-            s.options.NumGPU = 0
+            break
         }
-    }
-
-    // On linux and windows, over-allocating CPU memory will almost always result in an error
-    // Darwin has fully dynamic swap so has no direct concept of free swap space
-    if runtime.GOOS != "darwin" {
-        systemMemoryRequired := s.estimate.TotalSize - s.estimate.VRAMSize
-        available := systemInfo.FreeMemory + systemInfo.FreeSwap
-        if systemMemoryRequired > available {
-            slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap))
-            return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
-        }
-    }
+    }

-    slog.Info("offload", "", s.estimate)
+    // This maintains the historical assignment of graph sizes, though it isn't fully accurate
+    graphSize := graphFullOffload
+    if gpuLayers.Sum() < int(s.totalLayers) {
+        graphSize = graphPartialOffload
+    }

-    s.gpus = gpus
-    s.loadRequest.GPULayers = createGPULayers(s.estimate, s.ggml, gpus, s.options.NumGPU)
+    // For all layers that we have assigned to GPUs, move them in the memory data so
+    // that it is reported accurately
+    for _, gl := range gpuLayers {
+        for i := range s.mem.GPUs {
+            if gl.DeviceID == s.mem.GPUs[i].DeviceID {
+                for _, l := range gl.Layers {
+                    s.mem.GPUs[i].Weights[l] = s.mem.CPU.Weights[l]
+                    s.mem.GPUs[i].Cache[l] = s.mem.CPU.Cache[l]

-    // Mmap is only supported on the llama engine
-    if s.textProcessor == nil {
+                    s.mem.CPU.Weights[l] = 0
+                    s.mem.CPU.Cache[l] = 0
+                }
+
+                s.mem.GPUs[i].Graph = graphSize
+                break
+            }
+        }
+    }
+
+    if projectorGPU > 0 && len(s.mem.GPUs[projectorGPU].Weights) > 0 {
+        s.mem.GPUs[projectorGPU].Weights[s.totalLayers-1] += projectorWeights
+    }
+
+    slog.Debug("memory", "estimate", s.mem)
+    s.mem.Log(slog.LevelInfo)
+
+    // The llama engine uses mmap by default
     s.loadRequest.UseMmap = true

     // mmap has issues with partial offloading on metal
     for _, g := range gpus {
         if g.Library == "Metal" &&
             uint64(s.options.NumGPU) > 0 &&
-            uint64(s.options.NumGPU) < s.ggml.KV().BlockCount()+1 {
+            uint64(s.options.NumGPU) < s.totalLayers {
             s.options.UseMMap = new(bool)
             *s.options.UseMMap = false
         }
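The fallback graph estimate in this hunk is easy to check in isolation. A standalone sketch of the same arithmetic (the function name and signature are ours, not the diff's):

```go
// estimateGraph mirrors the fallback above: when the model reports no
// partial-offload graph size, scale the total KV cache by the
// grouped-query-attention factor and divide by 6 (an empirical heuristic).
func estimateGraph(headCount, headCountKV uint64, kv []uint64) uint64 {
	if headCountKV == 0 {
		headCountKV = 1 // avoid dividing by zero for models without KV heads
	}
	var kvTotal uint64
	for _, kvLayer := range kv {
		kvTotal += kvLayer // sum per-layer KV cache sizes
	}
	gqa := headCount / headCountKV
	return gqa * kvTotal / 6
}
```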
@@ -542,90 +631,50 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus [
     // Linux with a model larger than free space, mmap leads to thrashing
     // For CPU loads we want the memory to be allocated, not FS cache
     if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
-        (runtime.GOOS == "linux" && systemInfo.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) ||
+        (runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
         (len(gpus) == 0 && s.options.UseMMap == nil) ||
         (len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
         (s.options.UseMMap != nil && !*s.options.UseMMap) {
         s.loadRequest.UseMmap = false
     }
-    }

     if err := s.waitUntilRunnerLaunched(ctx); err != nil {
         return nil, err
     }

+    s.loadRequest.GPULayers = gpuLayers
     resp, err := s.initModel(ctx, s.loadRequest, LoadOperationCommit)
     if err != nil {
         return nil, err
     }

-    // On the Ollama engine, we can print out a summary of the memory allocations.
-    // We don't have this for the llama engine but it does something similar itself.
-    if s.textProcessor != nil {
-        resp.Memory.Log(slog.LevelInfo)
-    }
-
     if !resp.Success {
-        slog.Warn("failed to allocate memory for model", "memory", resp.Memory)
         return nil, errors.New("failed to allocate memory for model")
     }

     // The llama engine does its memory allocations together with model loading, so we
     // need to wait until it is done to ensure that we have accurate memory data before
     // loading the next model
-    if s.textProcessor == nil {
     return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx)
-    } else {
-        return uniqueDeviceIDs(s.loadRequest.GPULayers), nil
-    }
 }

-// createGPULayers maps from the tensor splits assigned by the memory estimates to explicit assignment
-// of particular layers onto GPUs
-func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus []ml.DeviceInfo, numGPU int) ml.GPULayersList {
-    if numGPU <= 0 || len(gpus) == 0 {
-        return nil
-    }
-
-    gpuLayers := make(ml.GPULayersList, len(gpus))
-    for i := range gpuLayers {
-        gpuLayers[i].DeviceID = gpus[i].DeviceID
-    }
-
-    var sum float32
-    splits := make([]float32, len(estimate.TensorSplit))
-    // cumulative sum of all splits
-    for i := range splits {
-        sum += float32(estimate.TensorSplit[i])
-        splits[i] = sum
-    }
-
-    if sum <= 0 {
-        return nil
-    }
-
-    // normalize splits
-    for i := range splits {
-        splits[i] /= sum
-    }
-
-    blocks := int(ggml.KV().BlockCount())
-    gpuRangeStart := max(0, blocks-numGPU)
-    gpuRangeStop := min(gpuRangeStart+numGPU, blocks+1)
-    for i := range blocks + 1 {
-        if i < gpuRangeStart || i >= gpuRangeStop {
-            continue
-        }
-
-        index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
-        if index < 0 || index >= len(gpus) {
-            continue
-        }
-
-        gpuLayers[index].Layers = append(gpuLayers[index].Layers, i)
-    }
-
-    return gpuLayers
-}
+func projectorMemoryRequirements(filename string) (weights uint64) {
+    file, err := os.Open(filename)
+    if err != nil {
+        return 0
+    }
+    defer file.Close()
+
+    ggml, err := ggml.Decode(file, 1024)
+    if err != nil {
+        return 0
+    }
+
+    for _, layer := range ggml.Tensors().GroupLayers() {
+        weights += layer.Size()
+    }
+
+    return weights
+}

 // Load finds the optimal layout of layers to offload on GPUs based on no initial information about the size of the model
@@ -652,23 +701,6 @@ func (s *ollamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus

     slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)

-    systemTotalMemory := systemInfo.TotalMemory
-    systemFreeMemory := systemInfo.FreeMemory
-    systemSwapFreeMemory := systemInfo.FreeSwap
-    slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
-
-    for _, gpu := range gpus {
-        available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory()
-        if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory() {
-            available = 0
-        }
-        slog.Info("gpu memory", "id", gpu.ID, "library", gpu.Library,
-            "available", format.HumanBytes2(available),
-            "free", format.HumanBytes2(gpu.FreeMemory),
-            "minimum", format.HumanBytes2(gpu.MinimumMemory()),
-            "overhead", format.HumanBytes2(envconfig.GpuOverhead()))
-    }
-
     pastAllocations := make(map[uint64]struct{})
     var backoff float32
@@ -834,25 +866,22 @@ func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID {
 // - Calculating how much space each GPU has available for layers, based on free memory and space occupied by the graph
 // - Assigning layers
 // - Ensuring that we don't exceed limits, such as requirements about partial offloading or system memory
-func (s *ollamaServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) {
+func (s *llmServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) {
     if memory == nil {
         memory = &ml.BackendMemory{CPU: ml.DeviceMemory{
             Weights: make([]uint64, s.totalLayers),
             Cache:   make([]uint64, s.totalLayers),
         }}
     }
-    gpuLayers, layers, err := s.buildLayout(systemGPUs, memory, requireFull, backoff)
-    if err != nil {
-        return nil, err
-    }
-
-    err = s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
+    gpuLayers, layers := s.buildLayout(systemGPUs, memory, requireFull, backoff)
+    err := s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
     if err != nil {
         return nil, err
     }
     return gpuLayers, nil
 }

-func (s *ollamaServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, []uint64, error) {
+func (s *llmServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, []uint64) {
     gpus := append(make([]ml.DeviceInfo, 0, len(systemGPUs)), systemGPUs...)
     sort.Sort(sort.Reverse(ml.ByFreeMemory(gpus)))
@@ -910,11 +939,11 @@ func (s *ollamaServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.Backen
             gpuLayers = libraryGpuLayers
         }
     }
-    return gpuLayers, layers, nil
+    return gpuLayers, layers
 }

 // verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
-func (s *ollamaServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
+func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
     // These sizes will only increase as we go through additional iterations and get additional information.
     cpuSize := memory.InputWeights + memory.CPU.Graph
     var vramSize uint64
@@ -942,11 +971,13 @@ nextLayer:

     if requireFull {
         if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
+            slog.Info("model requires more memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
             return ErrLoadRequiredFull
         }

         if cpuSize > systemInfo.FreeMemory {
-            return ErrLoadRequiredFull
+            slog.Info("model requires more system memory than is currently available, evicting a model to make space", "required", cpuSize, "free", systemInfo.FreeMemory)
+            return fmt.Errorf("model requires more system memory than is currently available %w", ErrLoadRequiredFull)
         }
     }
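Note the `%w` verb: the new message still wraps `ErrLoadRequiredFull`, so schedulers that previously matched the bare sentinel keep working via `errors.Is`. A hedged sketch of a caller (the surrounding variables are assumed from the scheduler's scope):

```go
// Sketch: the wrapped error still satisfies errors.Is against the sentinel.
gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, true, 0)
if errors.Is(err, ErrLoadRequiredFull) {
	// Not enough room for a full load: evict another model and retry.
}
_ = gpuLayers
```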
@@ -976,6 +1007,13 @@ nextLayer:

 // assignLayers packs the maximum number of layers onto the smallest set of GPUs and comes up with a layer assignment
 func assignLayers(layers []uint64, gpus []ml.DeviceInfo, requireFull bool, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) {
+    // If the user is manually overriding parameters, treat all GPUs equally so they split according to VRAM
+    if requestedLayers >= 0 || envconfig.SchedSpread() {
+        for i := range gpus {
+            gpus[i].Integrated = false
+        }
+    }
+
     // If we can't fit everything then prefer offloading layers other than the output layer
     for range 2 {
         // requestedLayers may be -1 if nothing was requested
@@ -1008,8 +1046,10 @@ func assignLayers(layers []uint64, gpus []ml.DeviceInfo, requireFull bool, reque

 // findBestFit binary searches to find the smallest capacity factor that can fit
 // the max number of layers. The capacity factor is multiplied by the free space on
-// each GPU and a small one will force even balancing.
+// each GPU and a small one will force even balancing. Higher performance GPUs are
+// used first.
 func findBestFit(layers []uint64, gpus []ml.DeviceInfo, requestedLayers int, forceRequest bool) (gpuLayers ml.GPULayersList) {
+    for _, gl := range ml.ByPerformance(gpus) {
     var high float32 = 1
     var low float32 = 0

@@ -1018,15 +1058,12 @@ func findBestFit(layers []uint64, gpus []ml.DeviceInfo, requestedLayers int, for
         high = 1000
     }

-    bestAssignments := greedyFit(layers, gpus, high, requestedLayers)
+    bestAssignments := greedyFit(layers, gl, high, requestedLayers)
     maxNumGPU := bestAssignments.Sum()
-    if maxNumGPU == 0 {
-        return bestAssignments
-    }

     for high-low > 1e-6 {
         mid := (low + high) / 2
-        assignments := greedyFit(layers, gpus, mid, requestedLayers)
+        assignments := greedyFit(layers, gl, mid, requestedLayers)
         if assignments.Sum() == maxNumGPU {
             high = mid
             bestAssignments = assignments
@@ -1034,7 +1071,13 @@ func findBestFit(layers []uint64, gpus []ml.DeviceInfo, requestedLayers int, for
             low = mid
         }
     }
-    return bestAssignments
+
+    layers = layers[:len(layers)-bestAssignments.Sum()]
+    requestedLayers -= bestAssignments.Sum()
+    gpuLayers = append(bestAssignments, gpuLayers...)
+    }
+
+    return gpuLayers
 }

 // greedyFit assigns layers incrementally to GPUs, spilling over as each runs out of free space
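The per-performance-group loop wraps the existing binary search without changing it. Isolated, the search looks like this (`fits` is a stand-in for the number of layers greedyFit can place at a given capacity factor):

```go
// smallestFactor binary searches for the tightest capacity factor that
// still places the maximum number of layers, mirroring findBestFit.
func smallestFactor(fits func(factor float32) int) float32 {
	var low, high float32 = 0, 1
	maxLayers := fits(high) // the most layers any factor can place
	for high-low > 1e-6 {
		mid := (low + high) / 2
		if fits(mid) == maxLayers {
			high = mid // still fits the max: try packing tighter
		} else {
			low = mid // too tight: relax the factor
		}
	}
	return high
}
```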
@@ -1362,6 +1405,12 @@ type CompletionRequest struct {
     Grammar  string // set before sending the request to the subprocess
     Shift    bool
     Truncate bool
+
+    // Logprobs specifies whether to include log probabilities in the response
+    Logprobs bool
+
+    // TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
+    TopLogprobs int
 }

 // DoneReason represents the reason why a completion response is done
@@ -1387,6 +1436,18 @@ func (d DoneReason) String() string {
     }
 }

+// TokenLogprob represents log probability information for a single token alternative.
+type TokenLogprob struct {
+    Token   string  `json:"token"`
+    Logprob float64 `json:"logprob"`
+}
+
+// Logprob contains log probability information for a generated token.
+type Logprob struct {
+    TokenLogprob
+    TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
+}
+
 type CompletionResponse struct {
     Content    string     `json:"content"`
     DoneReason DoneReason `json:"done_reason"`
@@ -1395,6 +1456,9 @@ type CompletionResponse struct {
     PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
     EvalCount          int           `json:"eval_count"`
     EvalDuration       time.Duration `json:"eval_duration"`
+
+    // Logprobs contains log probability information if requested
+    Logprobs []Logprob `json:"logprobs,omitempty"`
 }

 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
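Because Logprob embeds TokenLogprob, the embedded fields flatten into the same JSON object. A self-contained illustration (local copies of the types above, values invented):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local mirrors of the types added above, for illustration only.
type TokenLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
}

type Logprob struct {
	TokenLogprob
	TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}

func main() {
	lp := Logprob{
		TokenLogprob: TokenLogprob{Token: "Hi", Logprob: -0.12},
		TopLogprobs: []TokenLogprob{
			{Token: "Hi", Logprob: -0.12},
			{Token: "Hello", Logprob: -2.31},
		},
	}
	b, _ := json.Marshal(lp)
	// Prints: {"token":"Hi","logprob":-0.12,"top_logprobs":[...]}
	fmt.Println(string(b))
}
```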
@@ -1531,6 +1595,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
             if c.Content != "" {
                 fn(CompletionResponse{
                     Content:  c.Content,
+                    Logprobs: c.Logprobs,
                 })
             }

@@ -1623,68 +1688,59 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
     return e.Embedding, nil
 }

-type TokenizeRequest struct {
-    Content string `json:"content"`
-}
-
-type TokenizeResponse struct {
-    Tokens []int `json:"tokens"`
-}
-
-func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
+func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
     s.llamaModelLock.Lock()
     defer s.llamaModelLock.Unlock()

-    if s.llamaModel != nil {
-        return s.llamaModel.Tokenize(content, false, true)
-    }
-    if s.textProcessor != nil {
+    if s.llamaModel == nil {
+        return nil, fmt.Errorf("no tokenizer configured")
+    }
+
+    return s.llamaModel.Tokenize(content, false, true)
+}
+
+func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
     tokens, err := s.textProcessor.Encode(content, false)
     if err != nil {
         return nil, err
     }

     toks := make([]int, len(tokens))
     for i, t := range tokens {
         toks[i] = int(t)
     }

     return toks, nil
-    }
-    // not reached
-    return nil, fmt.Errorf("no tokenizer configured")
 }

-type DetokenizeRequest struct {
-    Tokens []int `json:"tokens"`
-}
-
-type DetokenizeResponse struct {
-    Content string `json:"content"`
-}
-
-func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
+func (s *llamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
     s.llamaModelLock.Lock()
     defer s.llamaModelLock.Unlock()

-    if s.llamaModel != nil {
+    if s.llamaModel == nil {
+        return "", fmt.Errorf("no tokenizer configured")
+    }
+
     var resp string
     for _, token := range tokens {
         resp += s.llamaModel.TokenToPiece(token)
     }

     return resp, nil
 }
-    if s.textProcessor != nil {
+
+func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
     toks := make([]int32, len(tokens))
     for i, t := range tokens {
         toks[i] = int32(t)
     }

     content, err := s.textProcessor.Decode(toks)
     if err != nil {
         return "", err
     }

     return content, nil
-    }
-    // not reached
-    return "", fmt.Errorf("no tokenizer configured")
 }

 func (s *llmServer) Close() error {
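The split removes the runtime nil-checks: each concrete server type now owns exactly one tokenizer path. Both still satisfy the same contract, sketched here against an assumed local interface (not a name from the diff):

```go
// Tokenizer is an assumed interface covering the two methods that
// llamaServer and ollamaServer each implement after this change.
type Tokenizer interface {
	Tokenize(ctx context.Context, content string) ([]int, error)
	Detokenize(ctx context.Context, tokens []int) (string, error)
}

// roundTrip exercises the shared contract: encoding then decoding should
// reproduce the input, modulo tokenizer normalization.
func roundTrip(ctx context.Context, s Tokenizer, content string) (string, error) {
	tokens, err := s.Tokenize(ctx, content)
	if err != nil {
		return "", err
	}
	return s.Detokenize(ctx, tokens)
}
```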
@@ -1712,31 +1768,12 @@ func (s *llmServer) Close() error {
     return nil
 }

-func (s *llamaServer) VRAMSize() uint64 {
-    return s.estimate.VRAMSize
-}
-
-func (s *llamaServer) TotalSize() uint64 {
-    return s.estimate.TotalSize
-}
-
-func (s *llamaServer) VRAMByGPU(id ml.DeviceID) uint64 {
-    for i, gpu := range s.gpus {
-        if gpu.DeviceID == id {
-            if i < len(s.estimate.GPUSizes) {
-                return s.estimate.GPUSizes[i]
-            }
-        }
-    }
-    return 0
-}
-
 func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
     slog.Debug("llamarunner free vram reporting not supported")
     return nil
 }

-func (s *ollamaServer) VRAMSize() uint64 {
+func (s *llmServer) VRAMSize() uint64 {
     if s.mem == nil {
         return 0
     }
@@ -1764,7 +1801,7 @@ func (s *ollamaServer) VRAMSize() uint64 {
     return mem
 }

-func (s *ollamaServer) TotalSize() uint64 {
+func (s *llmServer) TotalSize() uint64 {
     if s.mem == nil {
         return 0
     }
@@ -1778,7 +1815,7 @@ func (s *ollamaServer) TotalSize() uint64 {
     return mem
 }

-func (s *ollamaServer) VRAMByGPU(id ml.DeviceID) uint64 {
+func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
     if s.mem == nil {
         return 0
     }
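With the estimate-backed llamaServer methods gone, both engines report sizes from the shared s.mem. The hunks above elide the method bodies; a rough sketch of the summation they perform over ml.BackendMemory (field names from this diff, the body is ours, not the exact implementation):

```go
// vramSize sums per-GPU weight, cache, and graph allocations; a rough
// sketch of what the consolidated accessors compute.
func vramSize(mem *ml.BackendMemory) uint64 {
	if mem == nil {
		return 0 // nothing allocated yet
	}
	var total uint64
	for _, gpu := range mem.GPUs {
		for i := range gpu.Weights {
			total += gpu.Weights[i] + gpu.Cache[i]
		}
		total += gpu.Graph
	}
	return total
}
```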
@@ -14,16 +14,11 @@ import (
 )

 func TestLLMServerFitGPU(t *testing.T) {
-    type gpu struct {
-        id   ml.DeviceID
-        free int
-    }
-
     minMemory := 457 * format.MebiByte

     tests := []struct {
         name        string
-        gpus        []gpu
+        gpus        []ml.DeviceInfo
         layers      []int
         numGPU      int
         requireFull bool
@@ -38,91 +33,91 @@ func TestLLMServerFitGPU(t *testing.T) {
         },
         {
             name: "Full single GPU",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
             numGPU: -1,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}},
         },
         {
             name: "Partial single GPU",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
             numGPU: -1,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
         },
         {
             name: "Single GPU with numGPU 1",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
             numGPU: 1,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}},
         },
         {
             name: "Single GPU with numGPU 0",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
             numGPU: 0,
             expected: ml.GPULayersList{},
         },
         {
             name: "Single GPU with numGPU 999",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
             numGPU: 999,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2, 3}}},
         },
         {
             name: "Multi GPU fits on one",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
             numGPU: -1,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1, 2}}},
         },
         {
             name: "Multi GPU split",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
             numGPU: -1,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
         },
         {
             name: "Multi GPU partial",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
             numGPU: -1,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}},
         },
         {
             name: "Multi GPU numGPU 1",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
             numGPU: 1,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}},
         },
         {
             name: "Multi GPU numGPU 2",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
             numGPU: 2,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}},
         },
         {
             name: "Multi GPU numGPU 999",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
             numGPU: 999,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
         },
         {
             name: "Multi GPU different libraries",
-            gpus: []gpu{{id: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{128 * format.MebiByte, 128 * format.MebiByte, 50 * format.MebiByte},
             numGPU: -1,
             expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1", Library: "ROCm"}, Layers: []int{0, 1}}},
         },
         {
             name: "requireFull",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
             layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
             numGPU: -1,
             requireFull: true,
@@ -130,12 +125,54 @@ func TestLLMServerFitGPU(t *testing.T) {
         },
         {
             name: "requireFull numGPU",
-            gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256 * format.MebiByte)}},
             layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
             numGPU: 4,
             requireFull: true,
             expectedErr: ErrLoadRequiredFull,
         },
+        {
+            name: "iGPU",
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
+            layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
+            numGPU: -1,
+            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}},
+        },
+        {
+            name: "iGPU + dGPU",
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
+            layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
+            numGPU: -1,
+            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
+        },
+        {
+            name: "iGPU + dGPU fits on one",
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
+            layers: []int{50 * format.MebiByte, 50 * format.MebiByte},
+            numGPU: -1,
+            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1}}},
+        },
+        {
+            name: "iGPU + dGPU partial",
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
+            layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
+            numGPU: -1,
+            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
+        },
+        {
+            name: "iGPU + dGPU numGPU 1",
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
+            layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
+            numGPU: 1,
+            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
+        },
+        {
+            name: "iGPU + dGPU numGPU 999",
+            gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
+            layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
+            numGPU: 999,
+            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1, 2, 3}}},
+        },
     }

     for _, tt := range tests {
@@ -145,12 +182,6 @@ func TestLLMServerFitGPU(t *testing.T) {
         systemInfo.FreeMemory = 512 * format.MebiByte
         systemInfo.FreeSwap = 256 * format.MebiByte

-        gpus := make([]ml.DeviceInfo, len(tt.gpus))
-        for i := range tt.gpus {
-            gpus[i].DeviceID = tt.gpus[i].id
-            gpus[i].FreeMemory = uint64(tt.gpus[i].free)
-        }
-
         s := &ollamaServer{
             llmServer: llmServer{
                 totalLayers: uint64(len(tt.layers)),
@@ -165,19 +196,19 @@ func TestLLMServerFitGPU(t *testing.T) {
         s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{
             Weights: make([]uint64, s.totalLayers),
             Cache:   make([]uint64, s.totalLayers),
-        }, GPUs: make([]ml.DeviceMemory, len(gpus))}
+        }, GPUs: make([]ml.DeviceMemory, len(tt.gpus))}

         for i := range tt.layers {
             s.mem.CPU.Weights[i] = uint64(tt.layers[i])
         }

         for i := range s.mem.GPUs {
-            s.mem.GPUs[i].DeviceID = gpus[i].DeviceID
+            s.mem.GPUs[i].DeviceID = tt.gpus[i].DeviceID
             s.mem.GPUs[i].Weights = make([]uint64, s.totalLayers)
             s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
         }

-        gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, tt.requireFull, 0)
+        gpuLayers, err := s.createLayout(systemInfo, tt.gpus, s.mem, tt.requireFull, 0)
         if err != tt.expectedErr {
             t.Fatalf("fitGPU returned error: %v", err)
         }
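The fixtures now construct ml.DeviceInfo directly, with FreeMemory in bytes including the scheduler's 457 MiB minimum reserve. A hypothetical helper (not in the diff) condenses the pattern the table repeats:

```go
// testGPU is a hypothetical helper illustrating the fixture pattern:
// usable MiB plus the 457 MiB minimum reserve, expressed in bytes.
func testGPU(id string, freeMiB int, integrated bool) ml.DeviceInfo {
	return ml.DeviceInfo{
		DeviceID:   ml.DeviceID{ID: id},
		Integrated: integrated,
		FreeMemory: uint64(freeMiB+457) * uint64(format.MebiByte),
	}
}
```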
@@ -1,16 +0,0 @@
-{
-    "env": {
-        "browser": true,
-        "es6": true,
-        "node": true
-    },
-    "extends": [
-        "eslint:recommended",
-        "plugin:@typescript-eslint/eslint-recommended",
-        "plugin:@typescript-eslint/recommended",
-        "plugin:import/recommended",
-        "plugin:import/electron",
-        "plugin:import/typescript"
-    ],
-    "parser": "@typescript-eslint/parser"
-}
92
macapp/.gitignore
vendored
@@ -1,92 +0,0 @@
-# Logs
-logs
-*.log
-npm-debug.log*
-yarn-debug.log*
-yarn-error.log*
-lerna-debug.log*
-
-# Diagnostic reports (https://nodejs.org/api/report.html)
-report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
-
-# Runtime data
-pids
-*.pid
-*.seed
-*.pid.lock
-.DS_Store
-
-# Directory for instrumented libs generated by jscoverage/JSCover
-lib-cov
-
-# Coverage directory used by tools like istanbul
-coverage
-*.lcov
-
-# nyc test coverage
-.nyc_output
-
-# node-waf configuration
-.lock-wscript
-
-# Compiled binary addons (https://nodejs.org/api/addons.html)
-build/Release
-
-# Dependency directories
-node_modules/
-jspm_packages/
-
-# TypeScript v1 declaration files
-typings/
-
-# TypeScript cache
-*.tsbuildinfo
-
-# Optional npm cache directory
-.npm
-
-# Optional eslint cache
-.eslintcache
-
-# Optional REPL history
-.node_repl_history
-
-# Output of 'npm pack'
-*.tgz
-
-# Yarn Integrity file
-.yarn-integrity
-
-# dotenv environment variables file
-.env
-.env.test
-
-# parcel-bundler cache (https://parceljs.org/)
-.cache
-
-# next.js build output
-.next
-
-# nuxt.js build output
-.nuxt
-
-# vuepress build output
-.vuepress/dist
-
-# Serverless directories
-.serverless/
-
-# FuseBox cache
-.fusebox/
-
-# DynamoDB Local files
-.dynamodb/
-
-# Webpack
-.webpack/
-
-# Vite
-.vite/
-
-# Electron-Forge
-out/
@@ -1,21 +0,0 @@
-# Desktop
-
-This app builds upon Ollama to provide a desktop experience for running models.
-
-## Developing
-
-First, build the `ollama` binary:
-
-```shell
-cd ..
-go build .
-```
-
-Then run the desktop app with `npm start`:
-
-```shell
-cd macapp
-npm install
-npm start
-```
-
[8 binary image assets deleted; before sizes: 402 B, 741 B, 440 B, 763 B, 447 B, 891 B, 443 B, 844 B; no replacements]
@@ -1,79 +0,0 @@
-import type { ForgeConfig } from '@electron-forge/shared-types'
-import { MakerSquirrel } from '@electron-forge/maker-squirrel'
-import { MakerZIP } from '@electron-forge/maker-zip'
-import { PublisherGithub } from '@electron-forge/publisher-github'
-import { AutoUnpackNativesPlugin } from '@electron-forge/plugin-auto-unpack-natives'
-import { WebpackPlugin } from '@electron-forge/plugin-webpack'
-import * as path from 'path'
-import * as fs from 'fs'
-
-import { mainConfig } from './webpack.main.config'
-import { rendererConfig } from './webpack.renderer.config'
-
-const packageJson = JSON.parse(fs.readFileSync(path.resolve(__dirname, './package.json'), 'utf8'))
-
-const config: ForgeConfig = {
-  packagerConfig: {
-    appVersion: process.env.VERSION || packageJson.version,
-    asar: true,
-    icon: './assets/icon.icns',
-    extraResource: [
-      path.join(__dirname, '../dist/darwin/ollama'),
-      ...fs.readdirSync(path.join(__dirname, '../dist/darwin-amd64/lib/ollama')).map(f => path.join(__dirname, '../dist/darwin-amd64/lib/ollama', f)),
-      path.join(__dirname, './assets/iconTemplate.png'),
-      path.join(__dirname, './assets/iconTemplate@2x.png'),
-      path.join(__dirname, './assets/iconUpdateTemplate.png'),
-      path.join(__dirname, './assets/iconUpdateTemplate@2x.png'),
-      path.join(__dirname, './assets/iconDarkTemplate.png'),
-      path.join(__dirname, './assets/iconDarkTemplate@2x.png'),
-      path.join(__dirname, './assets/iconDarkUpdateTemplate.png'),
-      path.join(__dirname, './assets/iconDarkUpdateTemplate@2x.png'),
-    ],
-    ...(process.env.SIGN
-      ? {
-          osxSign: {
-            identity: process.env.APPLE_IDENTITY,
-          },
-          osxNotarize: {
-            tool: 'notarytool',
-            appleId: process.env.APPLE_ID || '',
-            appleIdPassword: process.env.APPLE_PASSWORD || '',
-            teamId: process.env.APPLE_TEAM_ID || '',
-          },
-        }
-      : {}),
-    osxUniversal: {
-      x64ArchFiles: '*',
-    },
-  },
-  rebuildConfig: {},
-  makers: [new MakerSquirrel({}), new MakerZIP({}, ['darwin'])],
-  hooks: {
-    readPackageJson: async (_, packageJson) => {
-      return { ...packageJson, version: process.env.VERSION || packageJson.version }
-    },
-  },
-  plugins: [
-    new AutoUnpackNativesPlugin({}),
-    new WebpackPlugin({
-      mainConfig,
-      devContentSecurityPolicy: `default-src * 'unsafe-eval' 'unsafe-inline'; img-src data: 'self'`,
-      renderer: {
-        config: rendererConfig,
-        nodeIntegration: true,
-        entryPoints: [
-          {
-            html: './src/index.html',
-            js: './src/renderer.tsx',
-            name: 'main_window',
-            preload: {
-              js: './src/preload.ts',
-            },
-          },
-        ],
-      },
-    }),
-  ],
-}
-
-export default config
macapp/package-lock.json — 16,604 lines (generated; content collapsed)
@@ -1,80 +0,0 @@
-{
-  "name": "ollama",
-  "productName": "Ollama",
-  "version": "0.0.0",
-  "description": "ollama",
-  "main": ".webpack/main",
-  "scripts": {
-    "start": "electron-forge start",
-    "package": "electron-forge package --arch universal",
-    "package:sign": "SIGN=1 electron-forge package --arch universal",
-    "make": "electron-forge make --arch universal",
-    "make:sign": "SIGN=1 electron-forge make --arch universal",
-    "publish": "SIGN=1 electron-forge publish",
-    "lint": "eslint --ext .ts,.tsx ."
-  },
-  "keywords": [],
-  "author": {
-    "name": "Jeffrey Morgan",
-    "email": "jmorganca@gmail.com"
-  },
-  "license": "MIT",
-  "devDependencies": {
-    "@babel/core": "^7.22.5",
-    "@babel/preset-react": "^7.22.5",
-    "@electron-forge/cli": "^6.2.1",
-    "@electron-forge/maker-deb": "^6.2.1",
-    "@electron-forge/maker-rpm": "^6.2.1",
-    "@electron-forge/maker-squirrel": "^6.2.1",
-    "@electron-forge/maker-zip": "^6.2.1",
-    "@electron-forge/plugin-auto-unpack-natives": "^6.2.1",
-    "@electron-forge/plugin-webpack": "^6.2.1",
-    "@electron-forge/publisher-github": "^6.2.1",
-    "@electron/universal": "^1.4.1",
-    "@svgr/webpack": "^8.0.1",
-    "@types/chmodr": "^1.0.0",
-    "@types/node": "^20.4.0",
-    "@types/react": "^18.2.14",
-    "@types/react-dom": "^18.2.6",
-    "@types/uuid": "^9.0.2",
-    "@typescript-eslint/eslint-plugin": "^5.60.0",
-    "@typescript-eslint/parser": "^5.60.0",
-    "@vercel/webpack-asset-relocator-loader": "^1.7.3",
-    "babel-loader": "^9.1.2",
-    "chmodr": "^1.2.0",
-    "copy-webpack-plugin": "^11.0.0",
-    "css-loader": "^6.8.1",
-    "electron": "25.9.2",
-    "eslint": "^8.43.0",
-    "eslint-plugin-import": "^2.27.5",
-    "fork-ts-checker-webpack-plugin": "^7.3.0",
-    "node-loader": "^2.0.0",
-    "postcss": "^8.4.24",
-    "postcss-import": "^15.1.0",
-    "postcss-loader": "^7.3.3",
-    "postcss-preset-env": "^8.5.1",
-    "style-loader": "^3.3.3",
-    "svg-inline-loader": "^0.8.2",
-    "tailwindcss": "^3.3.2",
-    "ts-loader": "^9.4.3",
-    "ts-node": "^10.9.1",
-    "typescript": "~4.5.4",
-    "url-loader": "^4.1.1",
-    "webpack": "^5.88.0",
-    "webpack-cli": "^5.1.4",
-    "webpack-dev-server": "^4.15.1"
-  },
-  "dependencies": {
-    "@electron/remote": "^2.0.10",
-    "@heroicons/react": "^2.0.18",
-    "@segment/analytics-node": "^1.0.0",
-    "copy-to-clipboard": "^3.3.3",
-    "electron-squirrel-startup": "^1.0.0",
-    "electron-store": "^8.1.0",
-    "react": "^18.2.0",
-    "react-dom": "^18.2.0",
-    "uuid": "^9.0.0",
-    "winston": "^3.10.0",
-    "winston-daily-rotate-file": "^4.7.1"
-  }
-}
@@ -1,7 +0,0 @@
-module.exports = {
-  plugins: {
-    'postcss-import': {},
-    tailwindcss: {},
-    autoprefixer: {},
-  },
-}
@@ -1,34 +0,0 @@
-@tailwind base;
-@tailwind components;
-@tailwind utilities;
-
-html,
-body {
-  background: transparent;
-}
-
-.drag {
-  -webkit-app-region: drag;
-}
-
-.no-drag {
-  -webkit-app-region: no-drag;
-}
-
-.blink {
-  -webkit-animation: 1s blink step-end infinite;
-  -moz-animation: 1s blink step-end infinite;
-  -ms-animation: 1s blink step-end infinite;
-  -o-animation: 1s blink step-end infinite;
-  animation: 1s blink step-end infinite;
-}
-
-@keyframes blink {
-  from,
-  to {
-    color: transparent;
-  }
-  50% {
-    color: black;
-  }
-}
@@ -1,122 +0,0 @@
-import { useState } from 'react'
-import copy from 'copy-to-clipboard'
-import { CheckIcon, DocumentDuplicateIcon } from '@heroicons/react/24/outline'
-import Store from 'electron-store'
-import { getCurrentWindow, app } from '@electron/remote'
-
-import { install } from './install'
-import OllamaIcon from './ollama.svg'
-
-const store = new Store()
-
-enum Step {
-  WELCOME = 0,
-  CLI,
-  FINISH,
-}
-
-export default function () {
-  const [step, setStep] = useState<Step>(Step.WELCOME)
-  const [commandCopied, setCommandCopied] = useState<boolean>(false)
-
-  const command = 'ollama run llama3.2'
-
-  return (
-    <div className='drag'>
-      <div className='mx-auto flex min-h-screen w-full flex-col justify-between bg-white px-4 pt-16'>
-        {step === Step.WELCOME && (
-          <>
-            <div className='mx-auto text-center'>
-              <h1 className='mb-6 mt-4 text-2xl tracking-tight text-gray-900'>Welcome to Ollama</h1>
-              <p className='mx-auto w-[65%] text-sm text-gray-400'>
-                Let's get you up and running with your own large language models.
-              </p>
-              <button
-                onClick={() => setStep(Step.CLI)}
-                className='no-drag rounded-dm mx-auto my-8 w-[40%] rounded-md bg-black px-4 py-2 text-sm text-white hover:brightness-110'
-              >
-                Next
-              </button>
-            </div>
-            <div className='mx-auto'>
-              <OllamaIcon />
-            </div>
-          </>
-        )}
-        {step === Step.CLI && (
-          <>
-            <div className='mx-auto flex flex-col space-y-28 text-center'>
-              <h1 className='mt-4 text-2xl tracking-tight text-gray-900'>Install the command line</h1>
-              <pre className='mx-auto text-4xl text-gray-400'>> ollama</pre>
-              <div className='mx-auto'>
-                <button
-                  onClick={async () => {
-                    try {
-                      await install()
-                      setStep(Step.FINISH)
-                    } catch (e) {
-                      console.error('could not install: ', e)
-                    } finally {
-                      getCurrentWindow().show()
-                      getCurrentWindow().focus()
-                    }
-                  }}
-                  className='no-drag rounded-dm mx-auto w-[60%] rounded-md bg-black px-4 py-2 text-sm text-white hover:brightness-110'
-                >
-                  Install
-                </button>
-                <p className='mx-auto my-4 w-[70%] text-xs text-gray-400'>
-                  You will be prompted for administrator access
-                </p>
-              </div>
-            </div>
-          </>
-        )}
-        {step === Step.FINISH && (
-          <>
-            <div className='mx-auto flex flex-col space-y-20 text-center'>
-              <h1 className='mt-4 text-2xl tracking-tight text-gray-900'>Run your first model</h1>
-              <div className='flex flex-col'>
-                <div className='group relative flex items-center'>
-                  <pre className='language-none text-2xs w-full rounded-md bg-gray-100 px-4 py-3 text-start leading-normal'>
-                    {command}
-                  </pre>
-                  <button
-                    className={`no-drag absolute right-[5px] px-2 py-2 ${
-                      commandCopied
-                        ? 'text-gray-900 opacity-100 hover:cursor-auto'
-                        : 'text-gray-200 opacity-50 hover:cursor-pointer'
-                    } hover:font-bold hover:text-gray-900 group-hover:opacity-100`}
-                    onClick={() => {
-                      copy(command)
-                      setCommandCopied(true)
-                      setTimeout(() => setCommandCopied(false), 3000)
-                    }}
-                  >
-                    {commandCopied ? (
-                      <CheckIcon className='h-4 w-4 font-bold text-gray-500' />
-                    ) : (
-                      <DocumentDuplicateIcon className='h-4 w-4 text-gray-500' />
-                    )}
-                  </button>
-                </div>
-                <p className='mx-auto my-4 w-[70%] text-xs text-gray-400'>
-                  Run this command in your favorite terminal.
-                </p>
-              </div>
-              <button
-                onClick={() => {
-                  store.set('first-time-run', true)
-                  window.close()
-                }}
-                className='no-drag rounded-dm mx-auto w-[60%] rounded-md bg-black px-4 py-2 text-sm text-white hover:brightness-110'
-              >
-                Finish
-              </button>
-            </div>
-          </>
-        )}
-      </div>
-    </div>
-  )
-}
macapp/src/declarations.d.ts (vendored) — 4 lines
@@ -1,4 +0,0 @@
-declare module '*.svg' {
-  const content: string
-  export default content
-}
@@ -1,9 +0,0 @@
-<!DOCTYPE html>
-<html>
-  <head>
-    <meta charset="UTF-8" />
-  </head>
-  <body>
-    <div id="app"></div>
-  </body>
-</html>
@@ -1,302 +0,0 @@
-import { spawn, ChildProcess } from 'child_process'
-import { app, autoUpdater, dialog, Tray, Menu, BrowserWindow, MenuItemConstructorOptions, nativeTheme } from 'electron'
-import Store from 'electron-store'
-import winston from 'winston'
-import 'winston-daily-rotate-file'
-import * as path from 'path'
-
-import { v4 as uuidv4 } from 'uuid'
-import { installed } from './install'
-
-require('@electron/remote/main').initialize()
-
-if (require('electron-squirrel-startup')) {
-  app.quit()
-}
-
-const store = new Store()
-
-let welcomeWindow: BrowserWindow | null = null
-
-declare const MAIN_WINDOW_WEBPACK_ENTRY: string
-
-const logger = winston.createLogger({
-  transports: [
-    new winston.transports.Console(),
-    new winston.transports.File({
-      filename: path.join(app.getPath('home'), '.ollama', 'logs', 'server.log'),
-      maxsize: 1024 * 1024 * 20,
-      maxFiles: 5,
-    }),
-  ],
-  format: winston.format.printf(info => info.message),
-})
-
-app.on('ready', () => {
-  const gotTheLock = app.requestSingleInstanceLock()
-  if (!gotTheLock) {
-    app.exit(0)
-    return
-  }
-
-  app.on('second-instance', () => {
-    if (app.hasSingleInstanceLock()) {
-      app.releaseSingleInstanceLock()
-    }
-
-    if (proc) {
-      proc.off('exit', restart)
-      proc.kill()
-    }
-
-    app.exit(0)
-  })
-
-  app.focus({ steal: true })
-
-  init()
-})
-
-function firstRunWindow() {
-  // Create the browser window.
-  welcomeWindow = new BrowserWindow({
-    width: 400,
-    height: 500,
-    frame: false,
-    fullscreenable: false,
-    resizable: false,
-    movable: true,
-    show: false,
-    webPreferences: {
-      nodeIntegration: true,
-      contextIsolation: false,
-    },
-  })
-
-  require('@electron/remote/main').enable(welcomeWindow.webContents)
-
-  welcomeWindow.loadURL(MAIN_WINDOW_WEBPACK_ENTRY)
-  welcomeWindow.on('ready-to-show', () => welcomeWindow.show())
-  welcomeWindow.on('closed', () => {
-    if (process.platform === 'darwin') {
-      app.dock.hide()
-    }
-  })
-}
-
-let tray: Tray | null = null
-let updateAvailable = false
-const assetPath = app.isPackaged ? process.resourcesPath : path.join(__dirname, '..', '..', 'assets')
-
-function trayIconPath() {
-  return nativeTheme.shouldUseDarkColors
-    ? updateAvailable
-      ? path.join(assetPath, 'iconDarkUpdateTemplate.png')
-      : path.join(assetPath, 'iconDarkTemplate.png')
-    : updateAvailable
-    ? path.join(assetPath, 'iconUpdateTemplate.png')
-    : path.join(assetPath, 'iconTemplate.png')
-}
-
-function updateTrayIcon() {
-  if (tray) {
-    tray.setImage(trayIconPath())
-  }
-}
-
-function updateTray() {
-  const updateItems: MenuItemConstructorOptions[] = [
-    { label: 'An update is available', enabled: false },
-    {
-      label: 'Restart to update',
-      click: () => autoUpdater.quitAndInstall(),
-    },
-    { type: 'separator' },
-  ]
-
-  const menu = Menu.buildFromTemplate([
-    ...(updateAvailable ? updateItems : []),
-    { role: 'quit', label: 'Quit Ollama', accelerator: 'Command+Q' },
-  ])
-
-  if (!tray) {
-    tray = new Tray(trayIconPath())
-  }
-
-  tray.setToolTip(updateAvailable ? 'An update is available' : 'Ollama')
-  tray.setContextMenu(menu)
-  tray.setImage(trayIconPath())
-
-  nativeTheme.off('updated', updateTrayIcon)
-  nativeTheme.on('updated', updateTrayIcon)
-}
-
-let proc: ChildProcess = null
-
-function server() {
-  const binary = app.isPackaged
-    ? path.join(process.resourcesPath, 'ollama')
-    : path.resolve(process.cwd(), '..', 'ollama')
-
-  proc = spawn(binary, ['serve'])
-
-  proc.stdout.on('data', data => {
-    logger.info(data.toString().trim())
-  })
-
-  proc.stderr.on('data', data => {
-    logger.error(data.toString().trim())
-  })
-
-  proc.on('exit', restart)
-}
-
-function restart() {
-  setTimeout(server, 1000)
-}
-
-app.on('before-quit', () => {
-  if (proc) {
-    proc.off('exit', restart)
-    proc.kill('SIGINT') // send SIGINT signal to the server, which also stops any loaded llms
-  }
-})
-
-const updateURL = `https://ollama.com/api/update?os=${process.platform}&arch=${
-  process.arch
-}&version=${app.getVersion()}&id=${id()}`
-
-let latest = ''
-async function isNewReleaseAvailable() {
-  try {
-    const response = await fetch(updateURL)
-
-    if (!response.ok) {
-      return false
-    }
-
-    if (response.status === 204) {
-      return false
-    }
-
-    const data = await response.json()
-
-    const url = data?.url
-    if (!url) {
-      return false
-    }
-
-    if (latest === url) {
-      return false
-    }
-
-    latest = url
-
-    return true
-  } catch (error) {
-    logger.error(`update check failed - ${error}`)
-    return false
-  }
-}
-
-async function checkUpdate() {
-  const available = await isNewReleaseAvailable()
-  if (available) {
-    logger.info('checking for update')
-    autoUpdater.checkForUpdates()
-  }
-}
-
-function init() {
-  if (app.isPackaged) {
-    checkUpdate()
-    setInterval(() => {
-      checkUpdate()
-    }, 60 * 60 * 1000)
-  }
-
-  updateTray()
-
-  if (process.platform === 'darwin') {
-    if (app.isPackaged) {
-      if (!app.isInApplicationsFolder()) {
-        const chosen = dialog.showMessageBoxSync({
-          type: 'question',
-          buttons: ['Move to Applications', 'Do Not Move'],
-          message: 'Ollama works best when run from the Applications directory.',
-          defaultId: 0,
-          cancelId: 1,
-        })
-
-        if (chosen === 0) {
-          try {
-            app.moveToApplicationsFolder({
-              conflictHandler: conflictType => {
-                if (conflictType === 'existsAndRunning') {
-                  dialog.showMessageBoxSync({
-                    type: 'info',
-                    message: 'Cannot move to Applications directory',
-                    detail:
-                      'Another version of Ollama is currently running from your Applications directory. Close it first and try again.',
-                  })
-                }
-                return true
-              },
-            })
-            return
-          } catch (e) {
-            logger.error(`[Move to Applications] Failed to move to applications folder - ${e.message}}`)
-          }
-        }
-      }
-    }
-  }
-
-  server()
-
-  if (store.get('first-time-run') && installed()) {
-    if (process.platform === 'darwin') {
-      app.dock.hide()
-    }
-
-    app.setLoginItemSettings({ openAtLogin: app.getLoginItemSettings().openAtLogin })
-    return
-  }
-
-  // This is the first run or the CLI is no longer installed
-  app.setLoginItemSettings({ openAtLogin: true })
-  firstRunWindow()
-}
-
-// Quit when all windows are closed, except on macOS. There, it's common
-// for applications and their menu bar to stay active until the user quits
-// explicitly with Cmd + Q.
-app.on('window-all-closed', () => {
-  if (process.platform !== 'darwin') {
-    app.quit()
-  }
-})
-
-function id(): string {
-  const id = store.get('id') as string
-
-  if (id) {
-    return id
-  }
-
-  const uuid = uuidv4()
-  store.set('id', uuid)
-  return uuid
-}
-
-autoUpdater.setFeedURL({ url: updateURL })
-
-autoUpdater.on('error', e => {
-  logger.error(`update check failed - ${e.message}`)
-  console.error(`update check failed - ${e.message}`)
-})
-
-autoUpdater.on('update-downloaded', () => {
-  updateAvailable = true
-  updateTray()
-})
@@ -1,21 +0,0 @@
-import * as fs from 'fs'
-import { exec as cbExec } from 'child_process'
-import * as path from 'path'
-import { promisify } from 'util'
-
-const app = process && process.type === 'renderer' ? require('@electron/remote').app : require('electron').app
-const ollama = app.isPackaged ? path.join(process.resourcesPath, 'ollama') : path.resolve(process.cwd(), '..', 'ollama')
-const exec = promisify(cbExec)
-const symlinkPath = '/usr/local/bin/ollama'
-
-export function installed() {
-  return fs.existsSync(symlinkPath) && fs.readlinkSync(symlinkPath) === ollama
-}
-
-export async function install() {
-  const command = `do shell script "mkdir -p ${path.dirname(
-    symlinkPath
-  )} && ln -F -s \\"${ollama}\\" \\"${symlinkPath}\\"" with administrator privileges`
-
-  await exec(`osascript -e '${command}'`)
-}
(1 deleted binary asset, 17 KiB)
@@ -1,7 +0,0 @@
-import App from './app'
-import './app.css'
-import { createRoot } from 'react-dom/client'
-
-const container = document.getElementById('app')
-const root = createRoot(container)
-root.render(<App />)
@@ -1,6 +0,0 @@
-/** @type {import('tailwindcss').Config} */
-module.exports = {
-  content: ['./src/**/*.{js,ts,jsx,tsx,mdx}'],
-  theme: {},
-  plugins: [],
-}
@@ -1,20 +0,0 @@
-{
-  "compilerOptions": {
-    "target": "ES6",
-    "allowJs": true,
-    "module": "commonjs",
-    "skipLibCheck": true,
-    "esModuleInterop": true,
-    "noImplicitAny": true,
-    "sourceMap": true,
-    "baseUrl": ".",
-    "outDir": "dist",
-    "moduleResolution": "node",
-    "resolveJsonModule": true,
-    "paths": {
-      "*": ["node_modules/*"]
-    },
-    "jsx": "react-jsx"
-  },
-  "include": ["src/**/*"]
-}
@@ -1,20 +0,0 @@
-import type { Configuration } from 'webpack'
-
-import { rules } from './webpack.rules'
-import { plugins } from './webpack.plugins'
-
-export const mainConfig: Configuration = {
-  /**
-   * This is the main entry point for your application, it's the first file
-   * that runs in the main process.
-   */
-  entry: './src/index.ts',
-  // Put your normal webpack config below here
-  module: {
-    rules,
-  },
-  plugins,
-  resolve: {
-    extensions: ['.js', '.ts', '.jsx', '.tsx', '.css', '.json'],
-  },
-}
@@ -1,14 +0,0 @@
-import type IForkTsCheckerWebpackPlugin from 'fork-ts-checker-webpack-plugin'
-import { DefinePlugin } from 'webpack'
-
-// eslint-disable-next-line @typescript-eslint/no-var-requires
-const ForkTsCheckerWebpackPlugin: typeof IForkTsCheckerWebpackPlugin = require('fork-ts-checker-webpack-plugin')
-
-export const plugins = [
-  new ForkTsCheckerWebpackPlugin({
-    logger: 'webpack-infrastructure',
-  }),
-  new DefinePlugin({
-    'process.env.TELEMETRY_WRITE_KEY': JSON.stringify(process.env.TELEMETRY_WRITE_KEY),
-  }),
-]
@@ -1,19 +0,0 @@
-import type { Configuration } from 'webpack'
-
-import { rules } from './webpack.rules'
-import { plugins } from './webpack.plugins'
-
-rules.push({
-  test: /\.css$/,
-  use: [{ loader: 'style-loader' }, { loader: 'css-loader' }, { loader: 'postcss-loader' }],
-})
-
-export const rendererConfig: Configuration = {
-  module: {
-    rules,
-  },
-  plugins,
-  resolve: {
-    extensions: ['.js', '.ts', '.jsx', '.tsx', '.css'],
-  },
-}
@@ -1,35 +0,0 @@
-import type { ModuleOptions } from 'webpack'
-
-export const rules: Required<ModuleOptions>['rules'] = [
-  // Add support for native node modules
-  {
-    // We're specifying native_modules in the test because the asset relocator loader generates a
-    // "fake" .node file which is really a cjs file.
-    test: /native_modules[/\\].+\.node$/,
-    use: 'node-loader',
-  },
-  {
-    test: /[/\\]node_modules[/\\].+\.(m?js|node)$/,
-    parser: { amd: false },
-    use: {
-      loader: '@vercel/webpack-asset-relocator-loader',
-      options: {
-        outputAssetBase: 'native_modules',
-      },
-    },
-  },
-  {
-    test: /\.tsx?$/,
-    exclude: /(node_modules|\.webpack)/,
-    use: {
-      loader: 'ts-loader',
-      options: {
-        transpileOnly: true,
-      },
-    },
-  },
-  {
-    test: /\.svg$/,
-    use: ['@svgr/webpack'],
-  },
-]
@@ -146,7 +146,6 @@ type Tensor interface {
 	FromFloats([]float32)
 	FromInts([]int32)
 
-	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
 	Sub(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
@@ -185,7 +184,6 @@ type Tensor interface {
 	View(ctx Context, offset int, shape ...int) Tensor
 	Permute(ctx Context, shape ...int) Tensor
 	Contiguous(ctx Context, shape ...int) Tensor
-	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
 
 	Pad(ctx Context, shape ...int) Tensor
@@ -198,6 +196,10 @@ type Tensor interface {
 	Copy(ctx Context, t2 Tensor) Tensor
 	Duplicate(ctx Context) Tensor
 
+	Slice(ctx Context, dim, low, high, step int) Tensor
+	Chunk(ctx Context, dim int, size int) []Tensor
+	ChunkSections(ctx Context, dim int, sections ...int) []Tensor
+
 	TopK(ctx Context, k int) Tensor
 	Argsort(ctx Context) Tensor
 	Mean(ctx Context) Tensor
@@ -205,7 +207,6 @@ type Tensor interface {
 	Stddev(ctx Context) Tensor
 	Sqr(ctx Context) Tensor
 	Sqrt(ctx Context) Tensor
-	Clamp(ctx Context, min, max float32) Tensor
 }
 
 // ScaledDotProductAttention implements a fused attention
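The three methods added to the interface above expose stride-based views without copying. As a minimal orientation sketch (not part of this diff; `splitWindow` is a hypothetical caller, assuming the `ml` package from this repository and a tensor with at least 10 rows in dim 0):

```go
// splitWindow takes rows 2..10 of dim 0 as a view, then splits that
// view into two 4-row sub-views. Chunk(ctx, 0, 4) would produce the
// same result as ChunkSections(ctx, 0, 4, 4) here.
func splitWindow(ctx ml.Context, t ml.Tensor) []ml.Tensor {
	window := t.Slice(ctx, 0, 2, 10, 1)       // rows [2, 10), step 1
	return window.ChunkSections(ctx, 0, 4, 4) // two 4-row views
}
```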
@@ -1137,13 +1137,6 @@ func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_neg(ctx.(*Context).ctx, t.t),
-	}
-}
-
 func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -1632,20 +1625,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
-	var tt *C.struct_ggml_tensor
-	switch len(strides) {
-	case 0:
-		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
-	case 1:
-		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
-	default:
-		panic("unsupported number of dimensions")
-	}
-
-	return &Tensor{b: t.b, t: tt}
-}
-
 func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {
@@ -1732,9 +1711,65 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
+// Slice returns a view of the tensor sliced along dim from low to high in steps of step.
+// Slice panics if the dimension is invalid or the slice parameters are out of range.
+// If dim=0 and step>1, the returned tensor is a copy rather than a view, to ensure the proper shape.
+func (t *Tensor) Slice(ctx ml.Context, dim int, low, high, step int) ml.Tensor {
+	if dim < 0 || dim >= C.GGML_MAX_DIMS {
+		panic("invalid dimension")
+	} else if low < 0 || high > t.Dim(dim) || low >= high || step < 1 {
+		panic("invalid slice parameters")
+	}
+
+	if dim == 0 && step > 1 {
+		// dim=0,step>1 is a special case so handle it here first
+		return t.View(ctx,
+			low*t.Stride(0), 1,
+			step*t.Stride(0), (high-low+1)/step,
+			t.Stride(1), t.Dim(1),
+			// preserve dim 3 by merging it into dim 2
+			t.Stride(2), t.Dim(2)*t.Dim(3),
+		).Contiguous(ctx, (high-low+1)/step, t.Dim(1), t.Dim(2), t.Dim(3))
+	}
+
+	args := []int{
+		low * t.Stride(dim), t.Dim(0),
+		t.Stride(1), t.Dim(1),
+		t.Stride(2), t.Dim(2),
+		t.Stride(3), t.Dim(3),
+	}
+
+	if step == 1 {
+		args[dim*2+1] = high - low
+		return t.View(ctx, args[0], args[1:]...)
+	} else {
+		args[dim*2] = step * t.Stride(dim)
+		args[dim*2+1] = (high - low + 1) / step
+		return t.View(ctx, args[0], args[1:]...)
 	}
 }
 
+// Chunk splits the tensor into chunk-sized tensors along dim. Each sub-tensor is a view of
+// the original.
+func (t *Tensor) Chunk(ctx ml.Context, dim, chunk int) []ml.Tensor {
+	sections := make([]int, 0, t.Dim(dim)/chunk+1)
+	for rest := t.Dim(dim); rest > 0; rest -= chunk {
+		sections = append(sections, min(chunk, rest))
+	}
+	return t.ChunkSections(ctx, dim, sections...)
+}
+
+// ChunkSections splits the tensor into section-sized tensors along dim. Each sub-tensor is a
+// view of the original. The size of dim must equal the sum of sections.
+func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Tensor {
+	var offset int
+	s := make([]ml.Tensor, len(sections))
+	for i, section := range sections {
+		s[i] = t.Slice(ctx, dim, offset, offset+section, 1)
+		offset += section
+	}
+	if offset != t.Dim(dim) {
+		panic("sections do not sum to tensor dimension")
+	}
+	return s
+}
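To make the section arithmetic in `Chunk` concrete, the same loop as a standalone, runnable program (illustrative only; `chunkSections` is not a name in this diff, and the `min` builtin needs Go 1.21+):

```go
package main

import "fmt"

// chunkSections mirrors the loop in Chunk above: split a dimension of
// size dim into chunk-sized sections, with a short final section when
// chunk does not divide dim evenly.
func chunkSections(dim, chunk int) []int {
	sections := make([]int, 0, dim/chunk+1)
	for rest := dim; rest > 0; rest -= chunk {
		sections = append(sections, min(chunk, rest))
	}
	return sections
}

func main() {
	fmt.Println(chunkSections(10, 4)) // [4 4 2]
	fmt.Println(chunkSections(8, 4))  // [4 4]
}
```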
ml/backend/ggml/ggml/src/ggml-impl.h (vendored) — 16 lines
@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release();
 #endif
 
 #ifdef __cplusplus
+#include <array>
 #include <initializer_list>
 #include <vector>
 
@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
     return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
 }
 
+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
+                             int start_idx,
+                             std::initializer_list<std::array<int, 3>> edges) {
+    for (const auto & edge : edges) {
+        int dst_node = edge[0];
+        int src_idx  = edge[1];
+        int src_node = edge[2];
+        if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
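`ggml_check_edges` verifies that a candidate run of graph nodes is wired the way a fused kernel expects: each edge `{dst, srcIdx, srcNode}` asserts that node `start+dst` reads operand `srcIdx` from node `start+srcNode`. A Go rendering of the same check, with a toy node type standing in for `ggml_tensor` (illustrative only, not part of the diff):

```go
package main

import "fmt"

// node is a toy stand-in for a ggml_tensor: each graph node records its
// source operands, mirroring cgraph->nodes[i]->src[j].
type node struct {
	src [3]*node
}

// checkEdges mirrors ggml_check_edges above.
func checkEdges(nodes []*node, start int, edges [][3]int) bool {
	for _, e := range edges {
		dst, srcIdx, srcNode := e[0], e[1], e[2]
		if nodes[start+dst].src[srcIdx] != nodes[start+srcNode] {
			return false
		}
	}
	return true
}

func main() {
	a, b := &node{}, &node{}
	b.src[0] = a // node 1 consumes node 0
	fmt.Println(checkEdges([]*node{a, b}, 0, [][3]int{{1, 0, 0}})) // true
}
```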
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
     char name[256];
 
     snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_ne02=%d", base, ne02);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp (vendored) — 814 lines
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
 
 layout (push_constant) uniform parameter {
     uint ncols;
+    uint nrows;
     uint order;
 } p;
 
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
     dst_row[idx1] = tmp;
 }
 
-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
     // bitonic sort
     const int col = int(gl_LocalInvocationID.x);
-    const uint row = gl_WorkGroupID.y;
 
     const uint row_offset = row * p.ncols;
 
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
 
 void main() {
     if (p.ncols == BLOCK_SIZE) {
-        argsort(false);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(false, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     } else {
-        argsort(true);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(true, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     }
 }
@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 #if defined(DATA_A_MXFP4)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
 }
 vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
     vec2 v0 = dequantize(ib, iqs, a_offset);
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
 
     const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
     const uint scales = data_a[a_offset + ib].scales[scalesi];
-    const vec2 d = vec2(data_a[a_offset + ib].d);
+    const vec2 dm = vec2(data_a[a_offset + ib].dm);
 
-    return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+    return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
 }
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(1, 0);
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint is = 2 * n + b;                // 0..7
     const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
 
-    const vec2 loadd = vec2(data_a[a_offset + ib].d);
+    const vec2 loadd = vec2(data_a[a_offset + ib].dm);
 
     const uint scidx0 = (is < 4) ? is : (is + 4);
     const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
 
     const uint8_t hm = uint8_t(1 << (iqs / 16));
 
-    const vec2 loadd = vec2(data_a[a_offset + ib].d);
+    const vec2 loadd = vec2(data_a[a_offset + ib].dm);
 
     const uint scidx0 = (is < 4) ? is : (is + 4);
     const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
 float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 {
     decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
-    const f16vec2 d = bl.block.d;
+    const f16vec2 dm = bl.block.dm;
     const uint idx = coordInBlock[1];
 
     const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
     qs = unpack8(qs)[idx & 1];
 
     const uint scales = bl.block.scales[scalesi];
-    float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
+    float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
     return ret;
 }
 
@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
     uint32_t qs = bl.block.qs[iqs];
     qs >>= shift;
     qs &= 0xF;
-    float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+    float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
     return ret;
 }
 #endif
@@ -26,7 +26,7 @@ void main() {
     const float d = e8m0_to_fp32(data_a[ib].e);
 
     [[unroll]] for (uint l = 0; l < 8; ++l) {
-        data_b[b_idx + l +  0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
-        data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >>  4]);
+        data_b[b_idx + l +  0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
+        data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >>  4]));
     }
 }
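The recurring `* 0.5` factors suggest the shared `kvalues_mxfp4` lookup table now stores the FP4 (e2m1) values doubled so they stay integral, with shaders halving them at the point of use. A sketch of the resulting scalar dequantization — the table contents below are an assumption based on the standard e2m1 value set, not something shown in this diff:

```go
package main

import "fmt"

// Doubled FP4 (e2m1) magnitudes with sign in the high nibble half.
// Assumed layout; the diff only shows the 0.5 compensation factor.
var kvaluesMXFP4 = [16]int8{0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12}

// dequantMXFP4 decodes one 4-bit code with block scale d, halving the
// doubled table entry as the shaders above now do.
func dequantMXFP4(code uint8, d float32) float32 {
	return d * 0.5 * float32(kvaluesMXFP4[code&0xF])
}

func main() {
	// Code 5 stores 6 in the table, i.e. true value 3; with d = 2 this is 6.
	fmt.Println(dequantMXFP4(5, 2)) // 6
}
```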
@@ -24,8 +24,8 @@ void main() {
     const uint ql_idx = 32 * ip + il;
     const uint8_t qs = data_a[i].qs[32 * ip + il];
 
-    FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
-    FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+    FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
+    FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
     data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
     data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
     data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
@@ -20,8 +20,8 @@ void main() {
     const uint is = 2 * il;
     const uint n = 4;
 
-    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
-    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
 
     const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
     const uint qs_idx = 32*il + n * ir;
@@ -19,8 +19,8 @@ void main() {
     const uint ir = tid % 16;
     const uint is = 2 * il;
 
-    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
-    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
 
     const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
     const uint qs_idx = 32*il + 2 * ir;
@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
     const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
 
-    vec2 d = vec2(data_a[ib0 + i].d);
-    const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-    const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+    const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
 
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
         vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
                 fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
                 fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
             }
-            temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
+            temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
         }
     }
 }
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-        vec2 d = vec2(data_a[ib0 + i].d);
-        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
 
         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
         const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
             fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
             fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
             fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
-            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+            temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
         }
     }
 }
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-        vec2 d = vec2(data_a[ib0 + i].d);
-        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
 
         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
         const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
             fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
             fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
             (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
-            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+            temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
         }
     }
 }
@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
 
 #define NUM_WARPS (BLOCK_SIZE / WARP)
 
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[BN];
-uint _ne1;
-
-#ifdef MUL_MAT_ID_USE_SUBGROUPS
-shared uvec4 ballots_sh[NUM_WARPS];
-
-void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
-    _ne1 = 0;
-    uint num_elements = p.nei1 * p.nei0;
-    uint nei0shift = findLSB(p.nei0);
-
-    uint ids[16];
-    uint iter = 0;
-
-    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
-        // prefetch up to 16 elements
-        if (iter == 0) {
-            [[unroll]] for (uint k = 0; k < 16; ++k) {
-                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
-                bool in_range = i < num_elements;
-                uint ii1;
-                if (nei0_is_pow2) {
-                    ii1 = i >> nei0shift;
-                } else {
-                    ii1 = i / p.nei0;
-                }
-                uint ii0 = i - ii1 * p.nei0;
-                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
-            }
-        }
-        uint i = j + gl_LocalInvocationIndex;
-        bool in_range = i < num_elements;
-        uint ii1;
-        if (nei0_is_pow2) {
-            ii1 = i >> nei0shift;
-        } else {
-            ii1 = i / p.nei0;
-        }
-        uint ii0 = i - ii1 * p.nei0;
-        uint id = ids[iter++];
-        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
-
-        ballots_sh[gl_SubgroupID] = ballot;
-        barrier();
-
-        uint subgroup_base = 0;
-        uint total = 0;
-        for (uint k = 0; k < gl_NumSubgroups; ++k) {
-            if (k == gl_SubgroupID) {
-                subgroup_base = total;
-            }
-            total += subgroupBallotBitCount(ballots_sh[k]);
-        }
-        barrier();
-
-        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
-        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
-            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
-        }
-        _ne1 += total;
-        iter &= 15;
-        if (_ne1 >= (ic + 1) * BN) {
-            break;
-        }
-    }
-    barrier();
-}
-#endif // MUL_MAT_ID_USE_SUBGROUPS
-#endif // MUL_MAT_ID
-
 #ifdef COOPMAT
 shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
 #endif
 
+#include "mul_mm_id_funcs.glsl"
 #include "mul_mm_funcs.glsl"
 
 void main() {
@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
     const uint ib = idx / 128;  // 2 values per idx
     const uint iqs = idx % 128; // 0..127
 
-    const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
+    const uint qsi = (iqs / 64) * 16 + (iqs % 16);     // 0..15
     const uint scalesi = iqs / 8;                      // 0..15
     const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6
 
-    const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+    const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
     const uint scales = data_a[ib].scales[scalesi];
-    const vec2 d = vec2(data_a[ib].d);
+    const vec2 dm = vec2(data_a[ib].dm);
 
-    const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+    const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
 
     buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
 #elif defined(DATA_A_Q3_K)
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
     const uint is = 2 * n + b;                // 0..7
     const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
 
-    const vec2 loadd = vec2(data_a[ib].d);
+    const vec2 loadd = vec2(data_a[ib].dm);
 
     const uint scidx0 = (is < 4) ? is : (is + 4);
     const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
 
     const uint8_t hm = uint8_t(1 << (iqs / 16));
 
-    const vec2 loadd = vec2(data_a[ib].d);
+    const vec2 loadd = vec2(data_a[ib].dm);
 
     const uint scidx0 = (is < 4) ? is : (is + 4);
     const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
     const uint ib = idx / 8;
     const uint iqs = (idx & 0x07) * 2;
 
-    const float d = e8m0_to_fp32(data_a[ib].e);
+    const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
     const uint vui = uint(data_a[ib].qs[iqs]);
     const uint vui2 = uint(data_a[ib].qs[iqs+1]);
 
@@ -0,0 +1,70 @@
|
|||||||
|
#ifdef MUL_MAT_ID
|
||||||
|
shared u16vec2 row_ids[BN];
|
||||||
|
uint _ne1;
|
||||||
|
|
||||||
|
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||||
|
shared uvec4 ballots_sh[NUM_WARPS];
|
||||||
|
|
||||||
|
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||||
|
_ne1 = 0;
|
||||||
|
uint num_elements = p.nei1 * p.nei0;
|
||||||
|
uint nei0shift = findLSB(p.nei0);
|
||||||
|
|
||||||
|
uint ids[16];
|
||||||
|
uint iter = 0;
|
||||||
|
|
||||||
|
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||||
|
// prefetch up to 16 elements
|
||||||
|
if (iter == 0) {
|
||||||
|
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||||
|
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||||
|
bool in_range = i < num_elements;
|
||||||
|
uint ii1;
|
||||||
|
if (nei0_is_pow2) {
|
||||||
|
ii1 = i >> nei0shift;
|
||||||
|
} else {
|
||||||
|
ii1 = i / p.nei0;
|
||||||
|
}
|
||||||
|
uint ii0 = i - ii1 * p.nei0;
|
||||||
|
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
uint i = j + gl_LocalInvocationIndex;
|
||||||
|
bool in_range = i < num_elements;
|
||||||
|
uint ii1;
|
||||||
|
if (nei0_is_pow2) {
|
||||||
|
ii1 = i >> nei0shift;
|
||||||
|
} else {
|
||||||
|
ii1 = i / p.nei0;
|
||||||
|
}
|
||||||
|
uint ii0 = i - ii1 * p.nei0;
|
||||||
|
uint id = ids[iter++];
|
||||||
|
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||||
|
|
||||||
|
ballots_sh[gl_SubgroupID] = ballot;
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
uint subgroup_base = 0;
|
||||||
|
uint total = 0;
|
||||||
|
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||||
|
if (k == gl_SubgroupID) {
|
||||||
|
subgroup_base = total;
|
||||||
|
}
|
||||||
|
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||||
|
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||||
|
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
|
||||||
|
}
|
||||||
|
_ne1 += total;
|
||||||
|
iter &= 15;
|
||||||
|
if (_ne1 >= (ic + 1) * BN) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
#endif // MUL_MAT_ID_USE_SUBGROUPS
|
||||||
|
#endif // MUL_MAT_ID
|
||||||
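
Note on the ballot-based compaction above: each thread tests whether its element matches `expert_idx`, `subgroupBallotExclusiveBitCount` gives the thread's slot among the matches within its own subgroup, and the shared-memory scan over `ballots_sh` supplies the base offset contributed by earlier subgroups. A minimal host-side C sketch of the same slot computation, modeling a 32-lane subgroup as a bitmask (all names here are illustrative, not from the shader):

    #include <stdint.h>
    #include <stdio.h>

    /* Set bits below `lane`: the C analogue of subgroupBallotExclusiveBitCount. */
    static int exclusive_bit_count(uint32_t ballot, int lane) {
        return __builtin_popcount(ballot & ((1u << lane) - 1u));
    }

    int main(void) {
        /* Two 32-lane subgroups; bit i set means lane i matched expert_idx. */
        uint32_t ballots[2] = {0x00000012u, 0x80000001u};

        int base = 0;
        for (int sg = 0; sg < 2; sg++) {
            for (int lane = 0; lane < 32; lane++) {
                if (ballots[sg] & (1u << lane)) {
                    /* Slot = matches in earlier subgroups + matches below this lane. */
                    printf("subgroup %d lane %2d -> slot %d\n",
                           sg, lane, base + exclusive_bit_count(ballots[sg], lane));
                }
            }
            base += __builtin_popcount(ballots[sg]); /* subgroupBallotBitCount */
        }
        return 0;
    }

Matching threads land in consecutive `row_ids` slots without any atomics, which is the point of the subgroup path.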
@@ -10,10 +10,9 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #endif

-#ifdef COOPMAT
-#extension GL_KHR_cooperative_matrix : enable
-#extension GL_KHR_memory_scope_semantics : enable
+#if defined(MUL_MAT_ID_USE_SUBGROUPS)
 #extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
 #endif

 #ifdef MUL_MAT_ID
@@ -24,7 +23,10 @@

 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

-layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
 #if defined(A_TYPE_PACKED32)
 layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
 #endif
@@ -76,40 +78,31 @@ layout (constant_id = 10) const uint WARP = 32;

 #define BK 32

-#ifdef COOPMAT
-#define SHMEM_STRIDE (BK / 4 + 4)
-#else
-#define SHMEM_STRIDE (BK / 4 + 1)
-#endif
+#define MMQ_SHMEM

-shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
+#include "mul_mmq_shmem_types.glsl"

-#ifndef COOPMAT
-#if QUANT_AUXF == 1
-shared FLOAT_TYPE buf_a_dm[BM];
-#else
-shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
-#endif
-#endif
-
-shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
-#ifndef COOPMAT
-shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
-#endif
-
-#define LOAD_VEC_A (4 * QUANT_R)
-#define LOAD_VEC_B 16
-
 #ifdef MUL_MAT_ID
-shared u16vec2 row_ids[4096];
-#endif // MUL_MAT_ID
+#define BK_STEP 1
+#else
+#ifndef BK_STEP
+#define BK_STEP 4
+#endif
+#endif
+
+// Shared memory cache
+shared block_a_cache buf_a[BM * BK_STEP];
+shared block_b_cache buf_b[BN * BK_STEP];
+// Register cache
+block_a_cache cache_a[WMITER * TM];
+block_b_cache cache_b;
+
+#define LOAD_VEC_A (4 * QUANT_R_MMQ)
+#define LOAD_VEC_B 16

 #define NUM_WARPS (BLOCK_SIZE / WARP)

-#ifdef COOPMAT
-shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
-#endif
+#include "mul_mm_id_funcs.glsl"

 #include "mul_mmq_funcs.glsl"

 void main() {
@@ -139,26 +132,12 @@ void main() {
     const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
     const uint WSUBM = WM / WMITER;
     const uint WSUBN = WN / WNITER;

-#ifdef COOPMAT
-    const uint warp_i = gl_SubgroupID;
-
-    const uint tiw = gl_SubgroupInvocationID;
-
-    const uint cms_per_row = WM / TM;
-    const uint cms_per_col = WN / TN;
-
-    const uint storestride = WARP / TM;
-    const uint store_r = tiw % TM;
-    const uint store_c = tiw / TM;
-#else
     const uint warp_i = gl_LocalInvocationID.x / WARP;

     const uint tiw = gl_LocalInvocationID.x % WARP;

     const uint tiwr = tiw % (WSUBM / TM);
     const uint tiwc = tiw / (WSUBM / TM);
-#endif

     const uint warp_r = warp_i % (BM / WM);
     const uint warp_c = warp_i / (BM / WM);
@@ -172,17 +151,27 @@ void main() {
     const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;

 #ifdef MUL_MAT_ID
-    uint _ne1 = 0;
-    for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
-        for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
+    if (bitCount(p.nei0) == 1) {
+        load_row_ids(expert_idx, true, ic);
+    } else {
+        load_row_ids(expert_idx, false, ic);
+    }
+#else
+    _ne1 = 0;
+    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
+        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
             if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
-                row_ids[_ne1] = u16vec2(ii0, ii1);
+                if (_ne1 >= ic * BN) {
+                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
+                }
                 _ne1++;
             }
         }
     }

     barrier();
+#endif

     // Workgroup has no work
     if (ic * BN >= _ne1) return;
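
The fallback loop above differs from the old scalar scan in one important way: each workgroup now keeps only the matches that fall into its own output window `[ic*BN, (ic+1)*BN)`, which is what lets `row_ids` shrink from 4096 entries to `BN`. A C sketch of the windowed scan under toy inputs (`data_ids`, `nei0`, `nei1` stand in for the shader's buffer and push constants; all values are illustrative):

    #include <stdint.h>
    #include <stdio.h>

    #define BN 64

    int main(void) {
        enum { nei0 = 4, nei1 = 8 };
        int data_ids[nei1][nei0];
        for (int i1 = 0; i1 < nei1; i1++)
            for (int i0 = 0; i0 < nei0; i0++)
                data_ids[i1][i0] = (i0 + i1) % 3;          /* toy expert assignment */

        const int expert_idx = 1, ic = 0;                  /* this workgroup's expert and tile */
        uint16_t row_ids[BN][2];
        unsigned ne1 = 0;                                  /* global match counter */

        /* Stop once matches beyond this tile's window can no longer matter. */
        for (int i1 = 0; i1 < nei1 && ne1 < (unsigned)((ic + 1) * BN); i1++)
            for (int i0 = 0; i0 < nei0 && ne1 < (unsigned)((ic + 1) * BN); i0++)
                if (data_ids[i1][i0] == expert_idx) {
                    if (ne1 >= (unsigned)(ic * BN)) {      /* inside our window: keep it */
                        row_ids[ne1 - ic * BN][0] = (uint16_t)i0;
                        row_ids[ne1 - ic * BN][1] = (uint16_t)i1;
                    }
                    ne1++;
                }

        printf("matches seen: %u\n", ne1);
        return 0;
    }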
@@ -209,159 +198,70 @@ void main() {
     uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
 #endif

-#ifdef COOPMAT
-    coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
-    coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
-    coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
-
-    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
-
-    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
-
-    [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
-        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
-    }
-#else
-    int32_t cache_a_qs[WMITER * TM * BK / 4];
-
-    int32_t cache_b_qs[TN * BK / 4];
-
     ACC_TYPE sums[WMITER * TM * WNITER * TN];

     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
         sums[i] = ACC_TYPE(0.0f);
     }
-#endif

-#if QUANT_AUXF == 1
-    FLOAT_TYPE cache_a_dm[WMITER * TM];
-#else
-    FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
-#endif
-
-    FLOAT_TYPE_VEC2 cache_b_ds[TN];
-
-    for (uint block = start_k; block < end_k; block += BK) {
+    for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
         [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
-            const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
-            const uint iqs = loadr_a;
             const uint buf_ib = loadc_a + l;
+            const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
+            const uint iqs = loadr_a;

-            if (iqs == 0) {
-#if QUANT_AUXF == 1
-                buf_a_dm[buf_ib] = get_d(ib);
-#else
-                buf_a_dm[buf_ib] = get_dm(ib);
-#endif
+            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+                block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
             }
-#if QUANT_R == 1
-            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
-#else
-            const i32vec2 vals = repack(ib, iqs);
-            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs    ] = vals.x;
-            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
-#endif
         }
         [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
-#ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
-            const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
-            const uint ib = idx / 8;
-            const uint iqs = idx & 0x7;
-#else
-            const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
-            const uint ib_outer = ib / 4;
-            const uint ib_inner = ib % 4;
-
-            const uint iqs = loadr_b;
-#endif
-
             const uint buf_ib = loadc_b + l;

-            if (iqs == 0) {
-                buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+#ifdef MUL_MAT_ID
+            const u16vec2 row_idx = row_ids[buf_ib];
+            const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
+#else
+            const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
+#endif
+            const uint iqs = loadr_b;
+
+            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+                block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
             }
-            const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4    ] = values.x;
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
         }

         barrier();

-        pos_a_ib += 1;
-        pos_b_ib += 1;
+        pos_a_ib += BK_STEP;
+        pos_b_ib += BK_STEP;

-#ifdef COOPMAT
-        [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
-            const uint ib_a = warp_r * WM + cm_row * TM;
-            // Load from shared into cache
-            coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
-
-            // TODO: only cache values that are actually needed
-            [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
-                cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
-            }
-
-            [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
-                const uint ib_b = warp_c * WN + cm_col * TN;
-                coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
-
-                // TODO: only cache values that are actually needed
-                [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
-                    cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
-                }
-
-                cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
-                cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
-
-                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
-                    coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
-                }
-
-                coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-                sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
-            }
-        }
-#else
-        // Load from shared into cache
-        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-            [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
-                cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
-                [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
-                    cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
+        for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+            // Load from shared into cache
+            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                    const uint reg_ib = wsir * TM + cr;
+                    const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
+                    block_a_to_registers(reg_ib, k_step * BM + buf_ib);
                 }
             }
-        }

-        [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
-                cache_b_ds[cc] = buf_b_ds[ib];
-                [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
-                    cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
-                }
-            }
+            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+                    const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
+                    block_b_to_registers(ib);

-            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                    [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                        const uint cache_a_idx = wsir * TM + cr;
-                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                        int32_t q_sum = 0;
-                        [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
-                            q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
-                                                     cache_b_qs[cc * (BK / 4) + idx_k]);
-                        }
-
-                        sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
+                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                            const uint cache_a_idx = wsir * TM + cr;
+                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+                            sums[sums_idx] += mmq_dot_product(cache_a_idx);
+                        }
                     }
                 }
             }
         }
-#endif

         barrier();
     }
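
All of the per-format `mmq_dot_product` implementations used above bottom out in `dotPacked4x8EXT`, which multiplies four signed bytes pairwise and sums the products into a 32-bit accumulator. A plain-C equivalent of one packed dot product, for reference (values are illustrative):

    #include <stdint.h>
    #include <stdio.h>

    /* C equivalent of GLSL dotPacked4x8EXT: dot product of 4 signed bytes. */
    static int32_t dot_packed_4x8(uint32_t a, uint32_t b) {
        int32_t sum = 0;
        for (int i = 0; i < 4; i++) {
            int8_t ai = (int8_t)(a >> (8 * i));
            int8_t bi = (int8_t)(b >> (8 * i));
            sum += (int32_t)ai * bi;
        }
        return sum;
    }

    int main(void) {
        uint32_t qs_a = 0x0403FE01u; /* bytes {1, -2, 3, 4}  */
        uint32_t qs_b = 0x08F90605u; /* bytes {5, 6, -7, 8}  */
        /* 1*5 + (-2)*6 + 3*(-7) + 4*8 = 4 */
        printf("%d\n", dot_packed_4x8(qs_a, qs_b));
        return 0;
    }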
@@ -373,54 +273,6 @@ void main() {
     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif

-#ifdef COOPMAT
-#ifdef MUL_MAT_ID
-    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
-        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
-            coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
-            [[unroll]] for (uint col = 0; col < BN; col += storestride) {
-                const uint row_i = dc + cm_col * TN + col + store_c;
-                if (row_i >= _ne1) break;
-
-                const u16vec2 row_idx = row_ids[row_i];
-
-                data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
-            }
-        }
-    }
-#else
-    const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
-
-    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
-        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
-            const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
-
-            if (is_aligned && is_in_bounds) {
-                // Full coopMat is within bounds and stride_d is aligned with 16B
-                coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
-                coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
-            } else if (is_in_bounds) {
-                // Full coopMat is within bounds, but stride_d is not aligned
-                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
-                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
-                    data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
-                }
-            } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
-                // Partial coopMat is within bounds
-                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
-                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
-                    if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
-                        data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
-                    }
-                }
-            }
-        }
-    }
-#endif // MUL_MAT_ID
-#else
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

@@ -431,19 +283,21 @@ void main() {
                 const uint row_i = dc_warp + cc;
                 if (row_i >= _ne1) break;

-                const u16vec2 row_idx = row_ids[row_i];
+                const u16vec2 row_idx = row_ids[row_i - ic * BN];
 #endif // MUL_MAT_ID
                 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                    const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
 #ifdef MUL_MAT_ID
-                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                    if (dr_warp + cr < p.M) {
+                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
+                    }
 #else
                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
                     }
 #endif // MUL_MAT_ID
                 }
             }
         }
     }
-#endif // COOPMAT
 }
@@ -6,41 +6,89 @@

 // Each iqs value maps to a 32-bit integer

-#if defined(DATA_A_Q4_0)
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
+// 2-byte loads for Q4_0 blocks (18 bytes)
+// 4-byte loads for Q4_1 blocks (20 bytes)
 i32vec2 repack(uint ib, uint iqs) {
-    // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
-    const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2    ],
-                                   data_a[ib].qs[iqs * 2 + 1]);
+#ifdef DATA_A_Q4_0
+    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
+                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
     const uint32_t vui = pack32(quants);
     return i32vec2( vui        & 0x0F0F0F0F,
                    (vui >>  4) & 0x0F0F0F0F);
+#else // DATA_A_Q4_1
+    const uint32_t vui = data_a_packed32[ib].qs[iqs];
+    return i32vec2( vui        & 0x0F0F0F0F,
+                   (vui >>  4) & 0x0F0F0F0F);
+#endif
 }

+#ifdef DATA_A_Q4_0
 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
 }
-#endif
-
-#if defined(DATA_A_Q4_1)
-i32vec2 repack(uint ib, uint iqs) {
-    // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
-    const uint32_t vui = data_a_packed32[ib].qs[iqs];
-    return i32vec2( vui        & 0x0F0F0F0F,
-                   (vui >>  4) & 0x0F0F0F0F);
-}
-
+#else // DATA_A_Q4_1
 ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
 }
 #endif

-#if defined(DATA_A_Q5_0)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q4_0
+    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+    }
+#else // DATA_A_Q4_1
+    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+    }
+#endif
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        const uint32_t vui = cache_a[ib_a].qs[iqs];
+        const i32vec2 qs_a = i32vec2( vui        & 0x0F0F0F0F,
+                                     (vui >>  4) & 0x0F0F0F0F);
+
+        const int32_t qs_b0 = cache_b.qs[iqs];
+        const int32_t qs_b1 = cache_b.qs[iqs + 4];
+
+        q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
+        q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+
+#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+// 2-byte loads for Q5_0 blocks (22 bytes)
+// 4-byte loads for Q5_1 blocks (24 bytes)
 i32vec2 repack(uint ib, uint iqs) {
-    // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
-    const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2    ],
-                                   data_a[ib].qs[iqs * 2 + 1]);
+    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
+                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
     const uint32_t vui = pack32(quants);
-    const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
+#ifdef DATA_A_Q5_0
+    const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
+#else // DATA_A_Q5_1
+    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
+#endif
     const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
                        | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)

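The `* 0x02040810` above is a bit-scatter: multiplying the four low `qh` bits by a constant whose set bits sit at offsets 4, 11, 18 and 25 moves bits 0..3 to positions 4, 12, 20 and 28 (bit 4 of each packed byte), and the `& 0x10101010` mask discards the stray cross terms. A C check of the mapping:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        /* Bits 0..3 of qh must land on bits 4, 12, 20, 28 (bit 4 of each byte). */
        for (int bit = 0; bit < 4; bit++) {
            uint32_t qh = 1u << bit;
            uint32_t scattered = (qh * 0x02040810u) & 0x10101010u;
            printf("qh bit %d -> 0x%08X\n", bit, scattered);
        }
        /* Prints 0x00000010, 0x00001000, 0x00100000, 0x10000000. */
        return 0;
    }

Each nibble thus picks up its fifth quant bit in a single multiply-and-mask instead of four separate shifts.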
@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) {
     return i32vec2(v0, v1);
 }

+#ifdef DATA_A_Q5_0
 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
 }
-#endif
-
-#if defined(DATA_A_Q5_1)
-i32vec2 repack(uint ib, uint iqs) {
-    // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
-    const uint32_t vui = data_a_packed32[ib].qs[iqs];
-    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
-    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
-                       | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
-
-    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
-                       | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
-
-    return i32vec2(v0, v1);
-}
-
+#else // DATA_A_Q5_1
 ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
 }
 #endif

+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q5_0
+    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+        buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
+    }
+#else // DATA_A_Q5_1
+    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+        buf_a[buf_ib].qh = data_a_packed32[ib].qh;
+    }
+#endif
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+    cache_a[reg_ib].qh = buf_a[buf_ib].qh;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        const uint32_t vui = cache_a[ib_a].qs[iqs];
+        const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
+        const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
+                              | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+        const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+                              | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+        const int32_t qs_b0 = cache_b.qs[iqs];
+        const int32_t qs_b1 = cache_b.qs[iqs + 4];
+
+        q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
+        q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif

 #if defined(DATA_A_Q8_0)
+// 2-byte loads for Q8_0 blocks (34 bytes)
 int32_t repack(uint ib, uint iqs) {
-    // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
-    return pack32(i16vec2(data_a[ib].qs[iqs * 2    ],
-                          data_a[ib].qs[iqs * 2 + 1]));
+    return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2    ],
+                          data_a_packed16[ib].qs[iqs * 2 + 1]));
 }

 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(float(q_sum) * da * dsb.x);
 }

+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
+                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+        const int32_t qs_b = cache_b.qs[iqs];
+
+        q_sum += dotPacked4x8EXT(qs_a, qs_b);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
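
A q8_0 block is 34 bytes (an fp16 scale plus 32 int8 quants), so consecutive blocks in a tightly packed buffer are only 2-byte aligned; that is why this path reads through `data_a_packed16` and fuses pairs of 16-bit loads with `pack32` rather than doing 4-byte loads. A C check of the size arithmetic, mirroring ggml's `block_q8_0` layout:

    #include <stdint.h>
    #include <stdio.h>

    /* fp16 scale bits + 32 int8 quants, tightly packed. */
    #pragma pack(push, 1)
    typedef struct {
        uint16_t d;
        int8_t   qs[32];
    } block_q8_0;
    #pragma pack(pop)

    int main(void) {
        printf("sizeof(block_q8_0) = %zu\n", sizeof(block_q8_0)); /* 34 */
        /* 34 % 4 != 0: aligned 32-bit loads cannot walk the block array,
           but 34 % 2 == 0, so 16-bit loads can. */
        printf("34 %% 4 = %d, 34 %% 2 = %d\n", 34 % 4, 34 % 2);
        return 0;
    }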

+#if defined(DATA_A_MXFP4)
+// 1-byte loads for mxfp4 blocks (17 bytes)
+i32vec2 repack(uint ib, uint iqs) {
+    const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],
+                                          data_a[ib].qs[iqs * 4 + 1],
+                                          data_a[ib].qs[iqs * 4 + 2],
+                                          data_a[ib].qs[iqs * 4 + 3]));
+
+    return i32vec2( quants       & 0x0F0F0F0F,
+                   (quants >> 4) & 0x0F0F0F0F);
+}
+
+ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
+    return ACC_TYPE(da * dsb.x * float(q_sum));
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],
+                                      data_a[ib].qs[iqs * 4 + 1],
+                                      data_a[ib].qs[iqs * 4 + 2],
+                                      data_a[ib].qs[iqs * 4 + 3]));
+
+    const u8vec4 i_a0 = unpack8( qs       & 0x0F0F0F0F);
+    const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
+
+    buf_a[buf_ib].qs[iqs    ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
+    buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
+
+    if (iqs == 0) {
+        buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].d = buf_a[buf_ib].d;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
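
Two details worth noting in the MXFP4 path: the e8m0 scale is an exponent-only byte (a pure power of two), and the `* 0.5` appears to compensate for `kvalues_mxfp4` storing the fp4 (e2m1) magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6} doubled so the lookup table stays integer. A C sketch under those assumptions (the table and the e8m0 convention, value = 2^(e-127), are stated here as assumptions, not quoted from the source):

    #include <math.h>
    #include <stdio.h>

    /* e8m0: 8 exponent bits, no sign, no mantissa. */
    static float e8m0_to_fp32(unsigned char e) {
        return ldexpf(1.0f, (int)e - 127);
    }

    /* Assumed fp4 (e2m1) magnitudes, doubled to stay integer. */
    static const int kvalues_mxfp4[16] = {
        0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12
    };

    int main(void) {
        unsigned char e = 127;  /* scale 2^0 = 1.0 */
        int idx = 5;            /* nibble value 5 -> doubled magnitude 6 */
        float v = e8m0_to_fp32(e) * 0.5f * kvalues_mxfp4[idx];
        printf("decoded value: %f\n", v); /* 3.0, the real fp4 value */
        return 0;
    }

Folding the 0.5 into `d` once per block keeps the inner dot product purely integer.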

+// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
+// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
+#if defined(DATA_A_Q2_K)
+// 4-byte loads for Q2_K blocks (84 bytes)
+int32_t repack(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+    return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
+}
+
+uint8_t get_scale(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    return data_a[ib_k].scales[iqs_k / 4];
+}
+
+ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+    return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
+
+    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+    // Repack 4x4 quants into one int
+    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
+    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
+    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
+    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
+
+    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+        buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+    cache_a[reg_ib].scales = buf_a[buf_ib].scales;
+
+    [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t sum_d = 0;
+    int32_t sum_m = 0;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
+        const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
+
+        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
+        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
+    }
+
+    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
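
Q2_K decodes as `v = d * (scale & 0xF) * q - m * (scale >> 4)`, so the dot product against a q8_1 row splits into two integer sums: `sum_d` (quants weighted by the 4-bit scale) and `sum_m` (the 4-bit min weighted against the B quants, built by broadcasting it across all four bytes with `* 0x01010101`). A C sketch of the split on one 4-value group, with illustrative values:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        int a_q[4]   = {0, 1, 2, 3};     /* A-side 2-bit quants */
        int b_q[4]   = {10, -3, 7, 1};   /* q8_1 quants */
        uint8_t scale = 0x52;            /* low nibble 2 = scale, high nibble 5 = min */

        int sum_d = 0, sum_m = 0;
        for (int i = 0; i < 4; i++) {
            sum_d += a_q[i] * b_q[i] * (scale & 0xF);
            sum_m += (scale >> 4) * b_q[i]; /* min is constant across the group */
        }

        float d = 0.25f, m = 0.125f;     /* block super-scales (dm.x, dm.y) */
        float db = 0.5f;                 /* q8_1 scale (ds.x) */
        printf("dot = %f\n", db * (d * sum_d - m * sum_m));
        return 0;
    }

This matches the `dsb.x * (dma.x * sum_d - dma.y * sum_m)` form of `mul_q8_1` above while keeping both accumulations in integer arithmetic.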

+#if defined(DATA_A_Q3_K)
+// 2-byte loads for Q3_K blocks (110 bytes)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint hm_idx = iqs * QUANT_R_MMQ;
+    const uint iqs_k = (ib % 8) * 8 + hm_idx;
+
+    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+    const uint hm_shift = iqs_k / 8;
+
+    // Repack 2x4 quants into one int
+    // Add the 3rd bit instead of subtracting it to allow packing the quants
+    const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2    ] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2    ] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
+    buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
+                            (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
+
+    if (iqs == 0) {
+        const uint is = iqs_k / 4;
+        const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
+                                             (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
+
+        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    float result = 0.0;
+    int32_t q_sum = 0;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        // Subtract 4 from the quants to correct the 3rd bit offset
+        const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+    result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
+    q_sum = 0;
+
+    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+        const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+    result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
+
+    return ACC_TYPE(cache_b.ds.x * result);
+}
+#endif // MMQ_SHMEM
+#endif

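A Q3_K quant is `(low 2 bits | high bit << 2) - 4`; the shmem path adds the high bit instead of subtracting it so two unsigned 3-bit values can be nibble-packed per byte, and the `- 4` offset is re-applied later in `mmq_dot_product` when the nibbles expand back to signed bytes. A C round-trip of the pack/unpack:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        /* Two 3-bit quants (0..7, i.e. value + 4) packed into one byte. */
        int q0 = 1, q1 = 6;                         /* encodes values -3 and 2 */
        uint8_t packed = (uint8_t)(q0 | (q1 << 4)); /* low nibble, high nibble */

        int v0 = (packed & 0x0F) - 4;               /* expand, re-apply offset */
        int v1 = ((packed >> 4) & 0x0F) - 4;
        printf("v0 = %d, v1 = %d\n", v0, v1);       /* -3, 2 */
        return 0;
    }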
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+    return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
+
+    const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 16) / 8) * 4;
+
+    // Repack 2x4 quants into one int
+#if defined(DATA_A_Q4_K)
+    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F;
+    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
+
+    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
+#else // defined(DATA_A_Q5_K)
+    const uint qh_idx = iqs * QUANT_R_MMQ;
+    const uint qh_shift = iqs_k / 8;
+
+    buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
+                                    (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
+#endif
+
+    if (iqs == 0) {
+        // Scale index
+        const uint is = iqs_k / 8;
+        u8vec2 scale_dm;
+        if (is < 4) {
+            scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
+        } else {
+            scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
+                              (data_a[ib_k].scales[is+4] >>  4) | ((data_a[ib_k].scales[is  ] & 0xC0) >> 2));
+        }
+
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t q_sum = 0;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+#if defined(DATA_A_Q4_K)
+        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
+#else // defined(DATA_A_Q5_K)
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+#endif
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif

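Q4_K and Q5_K store eight (scale, min) pairs as 6-bit values packed into 12 bytes: the first four pairs live in the low 6 bits of bytes 0..7, and the last four are reassembled from a low nibble in bytes 8..11 plus the top two bits of the earlier bytes, which is exactly what the `is < 4` branch above implements. A C sketch of the unpacking, mirroring the k-quant scale layout:

    #include <stdint.h>
    #include <stdio.h>

    /* Extract 6-bit (scale, min) pair `is` (0..7) from the 12 packed bytes. */
    static void get_scale_min(const uint8_t s[12], int is, int *sc, int *mn) {
        if (is < 4) {
            *sc = s[is] & 0x3F;
            *mn = s[is + 4] & 0x3F;
        } else {
            *sc = (s[is + 4] & 0x0F) | ((s[is - 4] & 0xC0) >> 2);
            *mn = (s[is + 4] >> 4)   | ((s[is]     & 0xC0) >> 2);
        }
    }

    int main(void) {
        uint8_t s[12] = {0x7F, 1, 2, 3, 0xBF, 5, 6, 7, 0xA9, 9, 10, 11};
        for (int is = 0; is < 8; is++) {
            int sc, mn;
            get_scale_min(s, is, &sc, &mn);
            printf("is=%d scale=%d min=%d\n", is, sc, mn);
        }
        return 0;
    }

Pre-multiplying `dm` by the extracted pair at shmem-store time means the hot loop never touches the packed scale bytes again.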
+#ifdef MMQ_SHMEM
+void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_outer = ib / 4;
+    const uint ib_inner = ib % 4;
+
+    if (iqs == 0) {
+        buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+    }
+
+    const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
+    buf_b[buf_ib].qs[iqs * 4    ] = values.x;
+    buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
+    buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
+    buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
+}
+
+void block_b_to_registers(const uint ib) {
+    cache_b.ds = buf_b[ib].ds;
+    [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
+        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+    }
+}
+#endif

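The B side is quantized to q8_1, whose `ds` pair carries both the scale `d` and `s = d * sum(q)`. Keeping the pre-multiplied sum is what lets the `mul_q8_1` variants fold the A-side min term in with a single multiply (`dma.y * dsb.y`) instead of re-summing the B quants per tile. A C sketch of producing `ds` for one 32-value block (toy input, symmetric round-to-nearest):

    #include <stdio.h>

    int main(void) {
        float x[32];
        for (int i = 0; i < 32; i++) x[i] = 0.1f * (i - 16); /* toy input row */

        /* Pick d so the max magnitude maps to 127. */
        float amax = 0.0f;
        for (int i = 0; i < 32; i++) {
            float a = x[i] < 0 ? -x[i] : x[i];
            if (a > amax) amax = a;
        }
        float d = amax / 127.0f;

        int sum = 0;
        for (int i = 0; i < 32; i++) {
            int q = (int)(x[i] / d + (x[i] >= 0 ? 0.5f : -0.5f));
            sum += q;
        }

        float ds[2] = {d, d * (float)sum}; /* {scale, scale * sum of quants} */
        printf("ds = {%f, %f}\n", ds[0], ds[1]);
        return 0;
    }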
+#if defined(DATA_A_Q6_K)
+// 2-byte loads for Q6_K blocks (210 bytes)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
+    const uint ql_shift = ((iqs_k % 32) / 16) * 4;
+
+    const uint qh_idx = (iqs_k / 32) * 8 + iqs;
+    const uint qh_shift = ((iqs_k % 32) / 8) * 2;
+
+    const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2    ] >> ql_shift) & uint16_t(0x0F0F))) |
+                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2    ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
+                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
+
+    if (iqs == 0) {
+        const uint is = iqs_k / 4;
+        const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
+
+        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    float result = 0.0;
+    int32_t q_sum = 0;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+    result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
+    q_sum = 0;
+
+    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+    result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
+
+    return ACC_TYPE(cache_b.ds.x * result);
+}
+#endif // MMQ_SHMEM
 #endif

 #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
     return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
 }
 #endif
+
+#if defined(DATA_A_Q2_K)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+    const uint ib_k = ib / 8;
+    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+}
+#endif
@@ -0,0 +1,78 @@
+#if defined(DATA_A_Q4_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q4_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    uint32_t qh;
+    FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q5_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    uint32_t qh;
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q8_0)
+#define QUANT_R_MMQ 1
+// AMD likes 4, Intel likes 1 and Nvidia likes 2
+// #define BK_STEP 1
+struct block_a_cache {
+    int32_t qs[32/4];
+    FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_MXFP4)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    int32_t qs[8];
+    FLOAT_TYPE d;
+};
+#elif defined(DATA_A_Q2_K)
+#define QUANT_R_MMQ 4
+struct block_a_cache {
+    uint32_t qs[2];
+    u8vec2 scales;
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q3_K)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[4];
+    FLOAT_TYPE_VEC2 d_scales;
+};
+#elif defined(DATA_A_Q4_K)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[4];
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_K)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+    int32_t qs[8];
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q6_K)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+    int32_t qs[8];
+    FLOAT_TYPE_VEC2 d_scales;
+};
+#endif
+
+struct block_b_cache
+{
+    int32_t qs[8];
+    FLOAT_TYPE_VEC2 ds;
+};
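
The cache structs above encode how tightly each format's 32-value BK strip packs after repacking: `QUANT_R_MMQ` is the number of quant values sharing a byte, so the quant payload is `32 / (4 * QUANT_R_MMQ)` 32-bit words plus format-specific metadata (`dm`, `qh`, `scales`). A small C check of that arithmetic (MXFP4 is excluded since its nibbles are expanded to full int8 through the lookup table at store time):

    #include <stdio.h>

    int main(void) {
        /* QUANT_R_MMQ = quant values per byte after repacking. */
        struct { const char *name; int quant_r_mmq; } fmts[] = {
            {"Q4_0/Q4_1/Q5_0/Q5_1", 2},
            {"Q8_0/Q5_K/Q6_K", 1},
            {"Q2_K", 4},
            {"Q3_K/Q4_K", 2},
        };
        for (int i = 0; i < 4; i++) {
            int words = 32 / (4 * fmts[i].quant_r_mmq); /* 32 values per BK strip */
            printf("%-20s qs[%d] -> %2d bytes per 32-value strip\n",
                   fmts[i].name, words, 4 * words);
        }
        return 0;
    }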
@@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer Y {int data_pos[];};
 layout (binding = 2) readonly buffer Z {float data_ff[];};
 layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows

 layout (push_constant) uniform parameter {
     uint ncols;
@@ -27,6 +28,7 @@ layout (push_constant) uniform parameter {
     uint s2;
     int sections[4];
     uint is_back;
+    uint set_rows_stride;
 } p;

 float rope_yarn_ramp(const float low, const float high, const uint i0) {
@@ -16,12 +16,19 @@ void main() {
     const uint row_x = row_dst % ne1;
     const uint channel_x = row_dst / ne1;

-    const uint idst = row_dst*ne0 + i0/2;
+    uint idst = row_dst*ne0 + i0/2;
     const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;

+    // Fusion optimization: ROPE + VIEW + SET_ROWS..
+    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    if (p.set_rows_stride != 0) {
+        idst  = row_x*ne0 + i0/2;
+        idst += data_i[channel_x].x * p.set_rows_stride;
+    }
+
     if (i0 >= p.n_dims) {
-        data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
-        data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
+        data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
+        data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);

         return;
     }
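
With the fused ROPE + VIEW + SET_ROWS path above, the destination index is no longer derived from the rope row alone: the payload offset within the row stays `row_x*ne0 + i0/2`, but the whole row is relocated to the destination row named by `data_i[channel_x]`, scaled by the view's row stride. A C sketch of the destination-index computation (all values illustrative):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint32_t ne0 = 128, ne1 = 4;        /* row length, rows per channel */
        const uint32_t set_rows_stride = 4096;    /* dst row stride from the VIEW */
        const uint64_t data_i[] = {7, 42};        /* dst row index per channel */

        uint32_t row_dst = 5, i0 = 10;            /* one invocation's coordinates */
        uint32_t row_x = row_dst % ne1;           /* 1 */
        uint32_t channel_x = row_dst / ne1;       /* 1 */

        uint32_t idst = row_dst * ne0 + i0 / 2;   /* unfused destination */
        if (set_rows_stride != 0) {               /* fused: scatter via data_i */
            idst  = row_x * ne0 + i0 / 2;
            idst += (uint32_t)data_i[channel_x] * set_rows_stride;
        }
        printf("idst = %u\n", idst);
        return 0;
    }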
@@ -16,12 +16,19 @@ void main() {
     const uint row_x = row_dst % ne1;
     const uint channel_x = row_dst / ne1;

-    const uint idst = row_dst*ne0 + i0;
+    uint idst = row_dst*ne0 + i0;
     const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;

+    // Fusion optimization: ROPE + VIEW + SET_ROWS..
+    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    if (p.set_rows_stride != 0) {
+        idst  = row_x*ne0 + i0;
+        idst += data_i[channel_x].x * p.set_rows_stride;
+    }
+
     if (i0 >= p.n_dims) {
-        data_d[idst + 0] = data_a[ix + 0];
-        data_d[idst + 1] = data_a[ix + 1];
+        data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
+        data_d[idst + 1] = D_TYPE(data_a[ix + 1]);

         return;
     }