Merge branch 'ollama:main' into main

This commit is contained in:
likelovewant
2025-11-14 16:46:05 +08:00
committed by GitHub
126 changed files with 10858 additions and 22108 deletions

View File

@@ -104,6 +104,13 @@ jobs:
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runner_dir: 'rocm'
- os: windows
arch: amd64
preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
flags: ''
runner_dir: 'vulkan'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -113,13 +120,14 @@ jobs:
run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ')
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
id: cache-install
uses: actions/cache/restore@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: startsWith(matrix.preset, 'CUDA ')
name: Install CUDA ${{ matrix.cuda-version }}
@@ -149,6 +157,18 @@ jobs:
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'Vulkan'
name: Install Vulkan ${{ matrix.rocm-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
}
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -159,6 +179,7 @@ jobs:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
@@ -171,7 +192,7 @@ jobs:
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env:
CMAKE_GENERATOR: Ninja
@@ -312,13 +333,13 @@ jobs:
include:
- os: linux
arch: amd64
target: archive_novulkan
target: archive
- os: linux
arch: amd64
target: rocm
- os: linux
arch: arm64
target: archive_novulkan
target: archive
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -374,14 +395,12 @@ jobs:
include:
- os: linux
arch: arm64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
- os: linux
arch: amd64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
@@ -394,14 +413,6 @@ jobs:
CGO_CXXFLAGS
GOFLAGS
FLAVOR=rocm
- os: linux
arch: amd64
suffix: '-vulkan'
target: default
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -419,7 +430,6 @@ jobs:
with:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
target: ${{ matrix.preset }}
build-args: ${{ matrix.build-args }}
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest

View File

@@ -172,6 +172,7 @@ jobs:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4

View File

@@ -159,32 +159,7 @@ ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
# Temporary opt-out stages for Vulkan
FROM --platform=linux/amd64 scratch AS amd64_novulkan
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
FROM arm64 AS arm64_novulkan
FROM ${FLAVOR}_novulkan AS archive_novulkan
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04 AS novulkan
RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive_novulkan /bin /usr/bin
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
COPY --from=archive_novulkan /lib/ollama /usr/lib/ollama
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV NVIDIA_VISIBLE_DEVICES=all
ENV OLLAMA_HOST=0.0.0.0:11434
EXPOSE 11434
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]
FROM ubuntu:24.04 AS default
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \

View File

@@ -321,6 +321,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LibreChat](https://github.com/danny-avila/LibreChat)
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [AI-UI](https://github.com/bajahaw/ai-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
@@ -387,7 +388,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VS Code extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -419,7 +420,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
- [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VSCode extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VS Code extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
- [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
- [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
@@ -662,5 +663,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
## Security
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -117,6 +117,14 @@ type GenerateRequest struct {
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
// Logprobs specifies whether to return log probabilities of the output tokens.
Logprobs bool `json:"logprobs,omitempty"`
// TopLogprobs is the number of most likely tokens to return at each token position,
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -159,6 +167,14 @@ type ChatRequest struct {
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
// Logprobs specifies whether to return log probabilities of the output tokens.
Logprobs bool `json:"logprobs,omitempty"`
// TopLogprobs is the number of most likely tokens to return at each token position,
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
}
type Tools []Tool
@@ -343,6 +359,27 @@ func (t *ToolFunction) String() string {
return string(bts)
}
// TokenLogprob represents log probability information for a single token alternative.
type TokenLogprob struct {
// Token is the text representation of the token.
Token string `json:"token"`
// Logprob is the log probability of this token.
Logprob float64 `json:"logprob"`
// Bytes contains the raw byte representation of the token
Bytes []int `json:"bytes,omitempty"`
}
// Logprob contains log probability information for a generated token.
type Logprob struct {
TokenLogprob
// TopLogprobs contains the most likely tokens and their log probabilities
// at this position, if requested via TopLogprobs parameter.
TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
@@ -369,6 +406,10 @@ type ChatResponse struct {
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
Metrics
}
@@ -677,6 +718,10 @@ type GenerateResponse struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
}
// ModelDetails provides details about a model.

View File

@@ -48,16 +48,6 @@ The `-dev` flag enables:
- CORS headers for cross-origin requests
- Hot-reload support for UI development
#### Run Storybook
Inside the `ui/app` directory, run:
```bash
npm run storybook
```
For now we're writing stories as siblings of the component they're testing. So for example, `src/components/Message.stories.tsx` is the story for `src/components/Message.tsx`.
## Build

File diff suppressed because it is too large Load Diff

View File

@@ -34,6 +34,7 @@
"rehype-raw": "^7.0.0",
"rehype-sanitize": "^6.0.0",
"remark-math": "^6.0.0",
"streamdown": "^1.4.0",
"unist-builder": "^4.0.0",
"unist-util-parents": "^3.0.0"
},

View File

@@ -205,6 +205,13 @@ export async function* sendMessage(
data: uint8ArrayToBase64(att.data),
}));
// Only send think parameter when actually requesting thinking
// Don't send false as it causes issues with some providers
const shouldSendThink =
think !== undefined &&
((typeof think === "boolean" && think) ||
(typeof think === "string" && think !== ""));
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
method: "POST",
headers: {
@@ -222,7 +229,7 @@ export async function* sendMessage(
web_search: webSearch ?? false,
file_tools: fileTools ?? false,
...(forceUpdate !== undefined ? { forceUpdate } : {}),
...(think !== undefined ? { think } : {}),
...(shouldSendThink ? { think } : {}),
}),
),
signal,

File diff suppressed because one or more lines are too long

View File

@@ -1,522 +0,0 @@
import { expect, test, suite } from "vitest";
import { processStreamingMarkdown } from "@/utils/processStreamingMarkdown";
suite("common llm outputs that cause issues", () => {
test("prefix of bolded list item shouldn't make a horizontal line", () => {
// we're going to go in order of incrementally adding characters. This
// happens really commonly with LLMs that like to make lists like so:
//
// * **point 1**: explanatory text
// * **point 2**: more explanatory text
//
// Partial rendering of `*` (A), followed by `* *` (B), followed by `* **`
// (C) is a total mess. (A) renders as a single bullet point in an
// otherwise empty list, (B) renders as two nested lists (and therefore
// two bullet points, styled differently by default in html), and (C)
// renders as a horizontal line because in markdown apparently `***` or `*
// * *` horizontal rules don't have as strict whitespace rules as I
// expected them to
// these are alone (i.e., they would be the first list item)
expect(processStreamingMarkdown("*")).toBe("");
expect(processStreamingMarkdown("* *")).toBe("");
expect(processStreamingMarkdown("* **")).toBe("");
// expect(processStreamingMarkdown("* **b")).toBe("* **b**");
// with a list item before them
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"*"
].join("\n"),
),
).toBe("* abc");
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"* *"
].join("\n"),
),
).toBe("* abc");
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"* **"
].join("\n"),
),
).toBe("* abc");
});
test("bolded list items with text should be rendered properly", () => {
expect(processStreamingMarkdown("* **abc**")).toBe("* **abc**");
});
test("partially bolded list items should be autoclosed", () => {
expect(processStreamingMarkdown("* **abc")).toBe("* **abc**");
});
suite(
"partially bolded list items should be autoclosed, even if the last node isn't a text node",
() => {
test("inline code", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function `async`*"),
).toBe("* **Asynchronous Function `async`**");
});
},
);
});
suite("autoclosing bold", () => {
suite("endings with no asterisks", () => {
test("should autoclose bold", () => {
expect(processStreamingMarkdown("**abc")).toBe("**abc**");
expect(processStreamingMarkdown("abc **abc")).toBe("abc **abc**");
});
suite("should autoclose, even if the last node isn't a text node", () => {
test("inline code", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function `async`"),
).toBe("* **Asynchronous Function `async`**");
});
test("opening ** is at the end of the text", () => {
expect(processStreamingMarkdown("abc **`def` jhk [lmn](opq)")).toBe(
"abc **`def` jhk [lmn](opq)**",
);
});
test("if there's a space after the **, it should NOT be autoclosed", () => {
expect(processStreamingMarkdown("abc ** `def` jhk [lmn](opq)")).toBe(
"abc \\*\\* `def` jhk [lmn](opq)",
);
});
});
test("should autoclose bold, even if the last node isn't a text node", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function ( `async`"),
).toBe("* **Asynchronous Function ( `async`**");
});
test("whitespace fakeouts should not be modified", () => {
expect(processStreamingMarkdown("** abc")).toBe("\\*\\* abc");
});
// TODO(drifkin): arguably this should just be removed entirely, but empty
// isn't so bad
test("should handle empty bolded items", () => {
expect(processStreamingMarkdown("**")).toBe("");
});
});
suite("partially closed bolded items", () => {
test("simple partial", () => {
expect(processStreamingMarkdown("**abc*")).toBe("**abc**");
});
test("partial with non-text node at end", () => {
expect(processStreamingMarkdown("**abc`def`*")).toBe("**abc`def`**");
});
test("partial with multiply nested ending nodes", () => {
expect(processStreamingMarkdown("**abc[abc](`def`)*")).toBe(
"**abc[abc](`def`)**",
);
});
test("normal emphasis should not be affected", () => {
expect(processStreamingMarkdown("*abc*")).toBe("*abc*");
});
test("normal emphasis with nested code should not be affected", () => {
expect(processStreamingMarkdown("*`abc`*")).toBe("*`abc`*");
});
});
test.skip("shouldn't autoclose immediately if there's a space before the closing *", () => {
expect(processStreamingMarkdown("**abc *")).toBe("**abc**");
});
// skipping for now because this requires partial link completion as well
suite.skip("nested blocks that each need autoclosing", () => {
test("emph nested in link nested in strong nested in list item", () => {
expect(processStreamingMarkdown("* **[abc **def")).toBe(
"* **[abc **def**]()**",
);
});
test("* **[ab *`def`", () => {
expect(processStreamingMarkdown("* **[ab *`def`")).toBe(
"* **[ab *`def`*]()**",
);
});
});
});
suite("numbered list items", () => {
test("should remove trailing numbers", () => {
expect(processStreamingMarkdown("1. First\n2")).toBe("1. First");
});
test("should remove trailing numbers with breaks before", () => {
expect(processStreamingMarkdown("1. First \n2")).toBe("1. First");
});
test("should remove trailing numbers that form a new paragraph", () => {
expect(processStreamingMarkdown("1. First\n\n2")).toBe("1. First");
});
test("but should leave list items separated by two newlines", () => {
expect(processStreamingMarkdown("1. First\n\n2. S")).toBe(
"1. First\n\n2. S",
);
});
});
// TODO(drifkin):slop tests ahead, some are decent, but need to manually go
// through them as I implement
/*
describe("StreamingMarkdownContent - processStreamingMarkdown", () => {
describe("Ambiguous endings removal", () => {
it("should remove list markers at the end", () => {
expect(processStreamingMarkdown("Some text\n* ")).toBe("Some text");
expect(processStreamingMarkdown("Some text\n*")).toBe("Some text");
expect(processStreamingMarkdown("* Item 1\n- ")).toBe("* Item 1");
expect(processStreamingMarkdown("* Item 1\n-")).toBe("* Item 1");
expect(processStreamingMarkdown("Text\n+ ")).toBe("Text");
expect(processStreamingMarkdown("Text\n+")).toBe("Text");
expect(processStreamingMarkdown("1. First\n2. ")).toBe("1. First");
});
it("should remove heading markers at the end", () => {
expect(processStreamingMarkdown("Some text\n# ")).toBe("Some text");
expect(processStreamingMarkdown("Some text\n#")).toBe("Some text\n#"); // # without space is not removed
expect(processStreamingMarkdown("# Title\n## ")).toBe("# Title");
expect(processStreamingMarkdown("# Title\n##")).toBe("# Title\n##"); // ## without space is not removed
});
it("should remove ambiguous bold markers at the end", () => {
expect(processStreamingMarkdown("Text **")).toBe("Text ");
expect(processStreamingMarkdown("Some text\n**")).toBe("Some text");
});
it("should remove code block markers at the end", () => {
expect(processStreamingMarkdown("Text\n```")).toBe("Text");
expect(processStreamingMarkdown("```")).toBe("");
});
it("should remove single backtick at the end", () => {
expect(processStreamingMarkdown("Text `")).toBe("Text ");
expect(processStreamingMarkdown("`")).toBe("");
});
it("should remove single asterisk at the end", () => {
expect(processStreamingMarkdown("Text *")).toBe("Text ");
expect(processStreamingMarkdown("*")).toBe("");
});
it("should handle empty content", () => {
expect(processStreamingMarkdown("")).toBe("");
});
it("should handle single line removals correctly", () => {
expect(processStreamingMarkdown("* ")).toBe("");
expect(processStreamingMarkdown("# ")).toBe("");
expect(processStreamingMarkdown("**")).toBe("");
expect(processStreamingMarkdown("`")).toBe("");
});
it("shouldn't have this regexp capture group bug", () => {
expect(
processStreamingMarkdown("Here's a shopping list:\n*"),
).not.toContain("0*");
expect(processStreamingMarkdown("Here's a shopping list:\n*")).toBe(
"Here's a shopping list:",
);
});
});
describe("List markers", () => {
it("should preserve complete list items", () => {
expect(processStreamingMarkdown("* Complete item")).toBe(
"* Complete item",
);
expect(processStreamingMarkdown("- Another item")).toBe("- Another item");
expect(processStreamingMarkdown("+ Plus item")).toBe("+ Plus item");
expect(processStreamingMarkdown("1. Numbered item")).toBe(
"1. Numbered item",
);
});
it("should handle indented list markers", () => {
expect(processStreamingMarkdown(" * ")).toBe(" ");
expect(processStreamingMarkdown(" - ")).toBe(" ");
expect(processStreamingMarkdown("\t+ ")).toBe("\t");
});
});
describe("Heading markers", () => {
it("should preserve complete headings", () => {
expect(processStreamingMarkdown("# Complete Heading")).toBe(
"# Complete Heading",
);
expect(processStreamingMarkdown("## Subheading")).toBe("## Subheading");
expect(processStreamingMarkdown("### H3 Title")).toBe("### H3 Title");
});
it("should not affect # in other contexts", () => {
expect(processStreamingMarkdown("C# programming")).toBe("C# programming");
expect(processStreamingMarkdown("Issue #123")).toBe("Issue #123");
});
});
describe("Bold text", () => {
it("should close incomplete bold text", () => {
expect(processStreamingMarkdown("This is **bold text")).toBe(
"This is **bold text**",
);
expect(processStreamingMarkdown("Start **bold and more")).toBe(
"Start **bold and more**",
);
expect(processStreamingMarkdown("**just bold")).toBe("**just bold**");
});
it("should not affect complete bold text", () => {
expect(processStreamingMarkdown("**complete bold**")).toBe(
"**complete bold**",
);
expect(processStreamingMarkdown("Text **bold** more")).toBe(
"Text **bold** more",
);
});
it("should handle nested bold correctly", () => {
expect(processStreamingMarkdown("**bold** and **another")).toBe(
"**bold** and **another**",
);
});
});
describe("Italic text", () => {
it("should close incomplete italic text", () => {
expect(processStreamingMarkdown("This is *italic text")).toBe(
"This is *italic text*",
);
expect(processStreamingMarkdown("Start *italic and more")).toBe(
"Start *italic and more*",
);
});
it("should differentiate between list markers and italic", () => {
expect(processStreamingMarkdown("* Item\n* ")).toBe("* Item");
expect(processStreamingMarkdown("Some *italic text")).toBe(
"Some *italic text*",
);
expect(processStreamingMarkdown("*just italic")).toBe("*just italic*");
});
it("should not affect complete italic text", () => {
expect(processStreamingMarkdown("*complete italic*")).toBe(
"*complete italic*",
);
expect(processStreamingMarkdown("Text *italic* more")).toBe(
"Text *italic* more",
);
});
});
describe("Code blocks", () => {
it("should close incomplete code blocks", () => {
expect(processStreamingMarkdown("```javascript\nconst x = 42;")).toBe(
"```javascript\nconst x = 42;\n```",
);
expect(processStreamingMarkdown("```\ncode here")).toBe(
"```\ncode here\n```",
);
});
it("should not affect complete code blocks", () => {
expect(processStreamingMarkdown("```\ncode\n```")).toBe("```\ncode\n```");
expect(processStreamingMarkdown("```js\nconst x = 1;\n```")).toBe(
"```js\nconst x = 1;\n```",
);
});
it("should handle nested code blocks correctly", () => {
expect(processStreamingMarkdown("```\ncode\n```\n```python")).toBe(
"```\ncode\n```\n```python\n```",
);
});
it("should not process markdown inside code blocks", () => {
expect(processStreamingMarkdown("```\n* not a list\n**not bold**")).toBe(
"```\n* not a list\n**not bold**\n```",
);
});
});
describe("Inline code", () => {
it("should close incomplete inline code", () => {
expect(processStreamingMarkdown("This is `inline code")).toBe(
"This is `inline code`",
);
expect(processStreamingMarkdown("Use `console.log")).toBe(
"Use `console.log`",
);
});
it("should not affect complete inline code", () => {
expect(processStreamingMarkdown("`complete code`")).toBe(
"`complete code`",
);
expect(processStreamingMarkdown("Use `code` here")).toBe(
"Use `code` here",
);
});
it("should handle multiple inline codes correctly", () => {
expect(processStreamingMarkdown("`code` and `more")).toBe(
"`code` and `more`",
);
});
it("should not confuse inline code with code blocks", () => {
expect(processStreamingMarkdown("```\nblock\n```\n`inline")).toBe(
"```\nblock\n```\n`inline`",
);
});
});
describe("Complex streaming scenarios", () => {
it("should handle progressive streaming of a heading", () => {
const steps = [
{ input: "#", expected: "#" }, // # alone is not removed (needs space)
{ input: "# ", expected: "" },
{ input: "# H", expected: "# H" },
{ input: "# Hello", expected: "# Hello" },
];
steps.forEach(({ input, expected }) => {
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
it("should handle progressive streaming of bold text", () => {
const steps = [
{ input: "*", expected: "" },
{ input: "**", expected: "" },
{ input: "**b", expected: "**b**" },
{ input: "**bold", expected: "**bold**" },
{ input: "**bold**", expected: "**bold**" },
];
steps.forEach(({ input, expected }) => {
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
it("should handle multiline content with various patterns", () => {
const multiline = `# Title
This is a paragraph with **bold text** and *italic text*.
* Item 1
* Item 2
* `;
const expected = `# Title
This is a paragraph with **bold text** and *italic text*.
* Item 1
* Item 2`;
expect(processStreamingMarkdown(multiline)).toBe(expected);
});
it("should only fix the last line", () => {
expect(processStreamingMarkdown("# Complete\n# Another\n# ")).toBe(
"# Complete\n# Another",
);
expect(processStreamingMarkdown("* Item 1\n* Item 2\n* ")).toBe(
"* Item 1\n* Item 2",
);
});
it("should handle mixed content correctly", () => {
const input = `# Header
This has **bold** text and *italic* text.
\`\`\`js
const x = 42;
\`\`\`
Now some \`inline code\` and **unclosed bold`;
const expected = `# Header
This has **bold** text and *italic* text.
\`\`\`js
const x = 42;
\`\`\`
Now some \`inline code\` and **unclosed bold**`;
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
describe("Edge cases with escaping", () => {
it("should handle escaped asterisks (future enhancement)", () => {
// Note: Current implementation doesn't handle escaping
// This is a known limitation - escaped characters still trigger closing
expect(processStreamingMarkdown("Text \\*not italic")).toBe(
"Text \\*not italic*",
);
});
it("should handle escaped backticks (future enhancement)", () => {
// Note: Current implementation doesn't handle escaping
// This is a known limitation - escaped characters still trigger closing
expect(processStreamingMarkdown("Text \\`not code")).toBe(
"Text \\`not code`",
);
});
});
describe("Code block edge cases", () => {
it("should handle triple backticks in the middle of lines", () => {
expect(processStreamingMarkdown("Text ``` in middle")).toBe(
"Text ``` in middle\n```",
);
expect(processStreamingMarkdown("```\nText ``` in code\nmore")).toBe(
"```\nText ``` in code\nmore\n```",
);
});
it("should properly close code blocks with language specifiers", () => {
expect(processStreamingMarkdown("```typescript")).toBe(
"```typescript\n```",
);
expect(processStreamingMarkdown("```typescript\nconst x = 1")).toBe(
"```typescript\nconst x = 1\n```",
);
});
it("should remove a completely empty partial code block", () => {
expect(processStreamingMarkdown("```\n")).toBe("");
});
});
});
*/

View File

@@ -1,66 +1,123 @@
import React from "react";
import Markdown from "react-markdown";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
import rehypeRaw from "rehype-raw";
import rehypeSanitize, { defaultSchema } from "rehype-sanitize";
import rehypePrismPlus from "rehype-prism-plus";
import rehypeKatex from "rehype-katex";
import remarkStreamingMarkdown, {
type LastNodeInfo,
} from "@/utils/remarkStreamingMarkdown";
import type { PluggableList } from "unified";
import { Streamdown, defaultRemarkPlugins } from "streamdown";
import remarkCitationParser from "@/utils/remarkCitationParser";
import CopyButton from "./CopyButton";
import type { BundledLanguage } from "shiki";
import { highlighter } from "@/lib/highlighter";
interface StreamingMarkdownContentProps {
content: string;
isStreaming?: boolean;
size?: "sm" | "md" | "lg";
onLastNode?: (info: LastNodeInfo) => void;
browserToolResult?: any; // TODO: proper type
}
// Helper to extract text from React nodes
const extractText = (node: React.ReactNode): string => {
if (typeof node === "string") return node;
if (typeof node === "number") return String(node);
if (!node) return "";
if (React.isValidElement(node)) {
const props = node.props as any;
if (props?.children) {
return extractText(props.children as React.ReactNode);
}
}
if (Array.isArray(node)) {
return node.map(extractText).join("");
}
return "";
};
const CodeBlock = React.memo(
({ children, className, ...props }: React.HTMLAttributes<HTMLPreElement>) => {
const extractText = React.useCallback((node: React.ReactNode): string => {
if (typeof node === "string") return node;
if (typeof node === "number") return String(node);
if (!node) return "";
({ children }: React.HTMLAttributes<HTMLPreElement>) => {
// Extract code and language from children
const codeElement = children as React.ReactElement<{
className?: string;
children: React.ReactNode;
}>;
const language =
codeElement.props.className?.replace(/language-/, "") || "";
const codeText = extractText(codeElement.props.children);
if (React.isValidElement(node)) {
if (
node.props &&
typeof node.props === "object" &&
"children" in node.props
) {
return extractText(node.props.children as React.ReactNode);
}
// Synchronously highlight code using the pre-loaded highlighter
const tokens = React.useMemo(() => {
if (!highlighter) return null;
try {
return {
light: highlighter.codeToTokensBase(codeText, {
lang: language as BundledLanguage,
theme: "one-light" as any,
}),
dark: highlighter.codeToTokensBase(codeText, {
lang: language as BundledLanguage,
theme: "one-dark" as any,
}),
};
} catch (error) {
console.error("Failed to highlight code:", error);
return null;
}
if (Array.isArray(node)) {
return node.map(extractText).join("");
}
return "";
}, []);
const language = className?.replace(/language-/, "") || "";
}, [codeText, language]);
return (
<div className="relative bg-neutral-100 dark:bg-neutral-800 rounded-2xl overflow-hidden my-6">
<div className="flex justify-between select-none">
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
{language}
</div>
<div className="flex select-none">
{language && (
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
{language}
</div>
)}
<CopyButton
content={extractText(children)}
content={codeText}
showLabels={true}
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800"
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 ml-auto"
/>
</div>
<pre className={className} {...props}>
{children}
{/* Light mode */}
<pre className="dark:hidden m-0 bg-neutral-100 text-sm overflow-x-auto p-4">
<code className="font-mono text-sm">
{tokens?.light
? tokens.light.map((line: any, i: number) => (
<React.Fragment key={i}>
{line.map((token: any, j: number) => (
<span
key={j}
style={{
color: token.color,
}}
>
{token.content}
</span>
))}
{i < tokens.light.length - 1 && "\n"}
</React.Fragment>
))
: codeText}
</code>
</pre>
{/* Dark mode */}
<pre className="hidden dark:block m-0 bg-neutral-800 text-sm overflow-x-auto p-4">
<code className="font-mono text-sm">
{tokens?.dark
? tokens.dark.map((line: any, i: number) => (
<React.Fragment key={i}>
{line.map((token: any, j: number) => (
<span
key={j}
style={{
color: token.color,
}}
>
{token.content}
</span>
))}
{i < tokens.dark.length - 1 && "\n"}
</React.Fragment>
))
: codeText}
</code>
</pre>
</div>
);
@@ -68,65 +125,19 @@ const CodeBlock = React.memo(
);
const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
React.memo(
({ content, isStreaming = false, size, onLastNode, browserToolResult }) => {
// Build the remark plugins array
const remarkPlugins = React.useMemo(() => {
const plugins: PluggableList = [
remarkGfm,
[remarkMath, { singleDollarTextMath: false }],
remarkCitationParser,
];
React.memo(({ content, isStreaming = false, size, browserToolResult }) => {
// Build the remark plugins array - keep default GFM and Math, add citations
const remarkPlugins = React.useMemo(() => {
return [
defaultRemarkPlugins.gfm,
defaultRemarkPlugins.math,
remarkCitationParser,
];
}, []);
// Add streaming plugin when in streaming mode
if (isStreaming) {
plugins.push([remarkStreamingMarkdown, { debug: true, onLastNode }]);
}
return plugins;
}, [isStreaming, onLastNode]);
// Create a custom sanitization schema that allows math elements
const sanitizeSchema = React.useMemo(() => {
return {
...defaultSchema,
attributes: {
...defaultSchema.attributes,
span: [
...(defaultSchema.attributes?.span || []),
["className", /^katex/],
],
div: [
...(defaultSchema.attributes?.div || []),
["className", /^katex/],
],
"ol-citation": ["cursor", "start", "end"],
},
tagNames: [
...(defaultSchema.tagNames || []),
"math",
"mrow",
"mi",
"mo",
"mn",
"msup",
"msub",
"mfrac",
"mover",
"munder",
"msqrt",
"mroot",
"merror",
"mspace",
"mpadded",
"ol-citation",
],
};
}, []);
return (
<div
className={`
return (
<div
className={`
max-w-full
${size === "sm" ? "prose-sm" : size === "lg" ? "prose-lg" : ""}
prose
@@ -144,7 +155,27 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
prose-pre:my-0
prose-pre:max-w-full
prose-pre:pt-1
[&_code:not(pre_code)]:text-neutral-700
[&_table]:border-collapse
[&_table]:w-full
[&_table]:border
[&_table]:border-neutral-200
[&_table]:rounded-lg
[&_table]:overflow-hidden
[&_th]:px-3
[&_th]:py-2
[&_th]:text-left
[&_th]:font-semibold
[&_th]:border-b
[&_th]:border-r
[&_th]:border-neutral-200
[&_th:last-child]:border-r-0
[&_td]:px-3
[&_td]:py-2
[&_td]:border-r
[&_td]:border-neutral-200
[&_td:last-child]:border-r-0
[&_tbody_tr:not(:last-child)_td]:border-b
[&_code:not(pre_code)]:text-neutral-700
[&_code:not(pre_code)]:bg-neutral-100
[&_code:not(pre_code)]:font-normal
[&_code:not(pre_code)]:px-1.5
@@ -160,6 +191,10 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
dark:prose-strong:text-neutral-200
dark:prose-pre:text-neutral-200
dark:prose:pre:text-neutral-200
dark:[&_table]:border-neutral-700
dark:[&_thead]:bg-neutral-800
dark:[&_th]:border-neutral-700
dark:[&_td]:border-neutral-700
dark:[&_code:not(pre_code)]:text-neutral-200
dark:[&_code:not(pre_code)]:bg-neutral-800
dark:[&_code:not(pre_code)]:font-normal
@@ -167,104 +202,86 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
dark:prose-li:marker:text-neutral-300
break-words
`}
>
<StreamingMarkdownErrorBoundary
content={content}
isStreaming={isStreaming}
>
<StreamingMarkdownErrorBoundary
content={content}
isStreaming={isStreaming}
>
<Markdown
remarkPlugins={remarkPlugins}
rehypePlugins={
[
[rehypeRaw, { allowDangerousHtml: true }],
[rehypeSanitize, sanitizeSchema],
[rehypePrismPlus, { ignoreMissing: true }],
[
rehypeKatex,
{
errorColor: "#000000", // Black instead of red for errors
strict: false, // Be more lenient with parsing
throwOnError: false,
},
],
] as PluggableList
}
components={{
pre: CodeBlock,
table: ({
children,
...props
}: React.HTMLAttributes<HTMLTableElement>) => (
<div className="overflow-x-auto max-w-full">
<table {...props}>{children}</table>
</div>
),
// @ts-expect-error: custom type
"ol-citation": ({
cursor,
// start,
// end,
}: {
cursor: number;
start: number;
end: number;
}) => {
// Check if we have a page_stack and if the cursor is valid
const pageStack = browserToolResult?.page_stack;
const hasValidPage = pageStack && cursor < pageStack.length;
const pageUrl = hasValidPage ? pageStack[cursor] : null;
<Streamdown
parseIncompleteMarkdown={isStreaming}
isAnimating={isStreaming}
remarkPlugins={remarkPlugins}
controls={false}
components={{
pre: CodeBlock,
table: ({
children,
...props
}: React.HTMLAttributes<HTMLTableElement>) => (
<div className="overflow-x-auto max-w-full">
<table
{...props}
className="border-collapse w-full border border-neutral-200 dark:border-neutral-700 rounded-lg overflow-hidden"
>
{children}
</table>
</div>
),
// @ts-expect-error: custom citation type
"ol-citation": ({
cursor,
}: {
cursor: number;
start: number;
end: number;
}) => {
const pageStack = browserToolResult?.page_stack;
const hasValidPage = pageStack && cursor < pageStack.length;
const pageUrl = hasValidPage ? pageStack[cursor] : null;
// Extract a readable title from the URL if possible
const getPageTitle = (url: string) => {
if (url.startsWith("search_results_")) {
const searchTerm = url.substring(
"search_results_".length,
);
return `Search: ${searchTerm}`;
}
// For regular URLs, try to extract domain or use full URL
try {
const urlObj = new URL(url);
return urlObj.hostname;
} catch {
// If not a valid URL, return as is
return url;
}
};
const citationElement = (
<span className="text-xs text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 rounded-full px-2 py-1 ml-1">
[{cursor}]
</span>
);
// If we have a valid page URL, wrap in a link
if (pageUrl && pageUrl.startsWith("http")) {
return (
<a
href={pageUrl}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center hover:opacity-80 transition-opacity no-underline"
title={getPageTitle(pageUrl)}
>
{citationElement}
</a>
);
const getPageTitle = (url: string) => {
if (url.startsWith("search_results_")) {
const searchTerm = url.substring("search_results_".length);
return `Search: ${searchTerm}`;
}
try {
const urlObj = new URL(url);
return urlObj.hostname;
} catch {
return url;
}
};
// Otherwise, just return the citation without a link
return citationElement;
},
}}
>
{content}
</Markdown>
</StreamingMarkdownErrorBoundary>
</div>
);
},
);
const citationElement = (
<span className="text-xs text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 rounded-full px-2 py-1 ml-1">
[{cursor}]
</span>
);
if (pageUrl && pageUrl.startsWith("http")) {
return (
<a
href={pageUrl}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center hover:opacity-80 transition-opacity no-underline"
title={getPageTitle(pageUrl)}
>
{citationElement}
</a>
);
}
return citationElement;
},
}}
>
{content}
</Streamdown>
</StreamingMarkdownErrorBoundary>
</div>
);
});
interface StreamingMarkdownErrorBoundaryProps {
content: string;

View File

@@ -73,8 +73,9 @@ export default function Thinking({
// Calculate max height for smooth animations
const getMaxHeight = () => {
if (isCollapsed) {
return finishedThinking ? "0px" : "12rem"; // 8rem = 128px (same as max-h-32)
return finishedThinking ? "0px" : "12rem";
}
// When expanded, use the content height or grow naturally
return contentHeight ? `${contentHeight}px` : "none";
};
@@ -131,10 +132,11 @@ export default function Thinking({
</div>
<div
ref={wrapperRef}
className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md overflow-hidden
transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2`}
className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md
transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2
${isCollapsed ? "overflow-hidden" : "overflow-y-auto"}`}
style={{
maxHeight: getMaxHeight(),
maxHeight: isCollapsed ? getMaxHeight() : undefined,
opacity: isCollapsed && finishedThinking ? 0 : 1,
}}
>

View File

@@ -16,793 +16,6 @@
--text-color: #ffffff;
}
}
@media (prefers-color-scheme: light) {
.prose {
/**
* One Light theme for prism.js
* Based on Atom's One Light theme: https://github.com/atom/atom/tree/master/packages/one-light-syntax
*/
/**
* One Light colours (accurate as of commit eb064bf on 19 Feb 2021)
* From colors.less
* --mono-1: hsl(230, 8%, 24%);
* --mono-2: hsl(230, 6%, 44%);
* --mono-3: hsl(230, 4%, 64%)
* --hue-1: hsl(198, 99%, 37%);
* --hue-2: hsl(221, 87%, 60%);
* --hue-3: hsl(301, 63%, 40%);
* --hue-4: hsl(119, 34%, 47%);
* --hue-5: hsl(5, 74%, 59%);
* --hue-5-2: hsl(344, 84%, 43%);
* --hue-6: hsl(35, 99%, 36%);
* --hue-6-2: hsl(35, 99%, 40%);
* --syntax-fg: hsl(230, 8%, 24%);
* --syntax-bg: hsl(230, 1%, 98%);
* --syntax-gutter: hsl(230, 1%, 62%);
* --syntax-guide: hsla(230, 8%, 24%, 0.2);
* --syntax-accent: hsl(230, 100%, 66%);
* From syntax-variables.less
* --syntax-selection-color: hsl(230, 1%, 90%);
* --syntax-gutter-background-color-selected: hsl(230, 1%, 90%);
* --syntax-cursor-line: hsla(230, 8%, 24%, 0.05);
*/
.token.comment,
.token.prolog,
.token.cdata {
color: hsl(230, 4%, 64%);
}
.token.doctype,
.token.punctuation,
.token.entity {
color: hsl(230, 8%, 24%);
}
.token.attr-name,
.token.class-name,
.token.boolean,
.token.constant,
.token.number,
.token.atrule {
color: hsl(35, 99%, 36%);
}
.token.keyword {
color: hsl(301, 63%, 40%);
}
.token.property,
.token.tag,
.token.symbol,
.token.deleted,
.token.important {
color: hsl(5, 74%, 59%);
}
.token.selector,
.token.string,
.token.char,
.token.builtin,
.token.inserted,
.token.regex,
.token.attr-value,
.token.attr-value > .token.punctuation {
color: hsl(119, 34%, 47%);
}
.token.variable,
.token.operator,
.token.function {
color: hsl(221, 87%, 60%);
}
.token.url {
color: hsl(198, 99%, 37%);
}
/* HTML overrides */
.token.attr-value > .token.punctuation.attr-equals,
.token.special-attr > .token.attr-value > .token.value.css {
color: hsl(230, 8%, 24%);
}
/* CSS overrides */
.language-css .token.selector {
color: hsl(5, 74%, 59%);
}
.language-css .token.property {
color: hsl(230, 8%, 24%);
}
.language-css .token.function,
.language-css .token.url > .token.function {
color: hsl(198, 99%, 37%);
}
.language-css .token.url > .token.string.url {
color: hsl(119, 34%, 47%);
}
.language-css .token.important,
.language-css .token.atrule .token.rule {
color: hsl(301, 63%, 40%);
}
/* JS overrides */
.language-javascript .token.operator {
color: hsl(301, 63%, 40%);
}
.language-javascript
.token.template-string
> .token.interpolation
> .token.interpolation-punctuation.punctuation {
color: hsl(344, 84%, 43%);
}
/* JSON overrides */
.language-json .token.operator {
color: hsl(230, 8%, 24%);
}
.language-json .token.null.keyword {
color: hsl(35, 99%, 36%);
}
/* MD overrides */
.language-markdown .token.url,
.language-markdown .token.url > .token.operator,
.language-markdown .token.url-reference.url > .token.string {
color: hsl(230, 8%, 24%);
}
.language-markdown .token.url > .token.content {
color: hsl(221, 87%, 60%);
}
.language-markdown .token.url > .token.url,
.language-markdown .token.url-reference.url {
color: hsl(198, 99%, 37%);
}
.language-markdown .token.blockquote.punctuation,
.language-markdown .token.hr.punctuation {
color: hsl(230, 4%, 64%);
font-style: italic;
}
.language-markdown .token.code-snippet {
color: hsl(119, 34%, 47%);
}
.language-markdown .token.bold .token.content {
color: hsl(35, 99%, 36%);
}
.language-markdown .token.italic .token.content {
color: hsl(301, 63%, 40%);
}
.language-markdown .token.strike .token.content,
.language-markdown .token.strike .token.punctuation,
.language-markdown .token.list.punctuation,
.language-markdown .token.title.important > .token.punctuation {
color: hsl(5, 74%, 59%);
}
/* General */
.token.bold {
font-weight: bold;
}
.token.comment,
.token.italic {
font-style: italic;
}
.token.entity {
cursor: help;
}
.token.namespace {
opacity: 0.8;
}
/* Plugin overrides */
/* Selectors should have higher specificity than those in the plugins' default stylesheets */
/* Show Invisibles plugin overrides */
.token.token.tab:not(:empty):before,
.token.token.cr:before,
.token.token.lf:before,
.token.token.space:before {
color: hsla(230, 8%, 24%, 0.2);
}
/* Toolbar plugin overrides */
/* Space out all buttons and move them away from the right edge of the code block */
div.code-toolbar > .toolbar.toolbar > .toolbar-item {
margin-right: 0.4em;
}
/* Styling the buttons */
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span {
background: hsl(230, 1%, 90%);
color: hsl(230, 6%, 44%);
padding: 0.1em 0.4em;
border-radius: 0.3em;
}
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:focus {
background: hsl(230, 1%, 78%); /* custom: darken(--syntax-bg, 20%) */
color: hsl(230, 8%, 24%);
}
/* Line Highlight plugin overrides */
/* The highlighted line itself */
.line-highlight.line-highlight {
background: hsla(230, 8%, 24%, 0.05);
}
/* Default line numbers in Line Highlight plugin */
.line-highlight.line-highlight:before,
.line-highlight.line-highlight[data-end]:after {
background: hsl(230, 1%, 90%);
color: hsl(230, 8%, 24%);
padding: 0.1em 0.6em;
border-radius: 0.3em;
box-shadow: 0 2px 0 0 rgba(0, 0, 0, 0.2); /* same as Toolbar plugin default */
}
/* Hovering over a linkable line number (in the gutter area) */
/* Requires Line Numbers plugin as well */
pre[id].linkable-line-numbers.linkable-line-numbers
span.line-numbers-rows
> span:hover:before {
background-color: hsla(230, 8%, 24%, 0.05);
}
/* Line Numbers and Command Line plugins overrides */
/* Line separating gutter from coding area */
.line-numbers.line-numbers .line-numbers-rows,
.command-line .command-line-prompt {
border-right-color: hsla(230, 8%, 24%, 0.2);
}
/* Stuff in the gutter */
.line-numbers .line-numbers-rows > span:before,
.command-line .command-line-prompt > span:before {
color: hsl(230, 1%, 62%);
}
/* Match Braces plugin overrides */
/* Note: Outline colour is inherited from the braces */
.rainbow-braces .token.token.punctuation.brace-level-1,
.rainbow-braces .token.token.punctuation.brace-level-5,
.rainbow-braces .token.token.punctuation.brace-level-9 {
color: hsl(5, 74%, 59%);
}
.rainbow-braces .token.token.punctuation.brace-level-2,
.rainbow-braces .token.token.punctuation.brace-level-6,
.rainbow-braces .token.token.punctuation.brace-level-10 {
color: hsl(119, 34%, 47%);
}
.rainbow-braces .token.token.punctuation.brace-level-3,
.rainbow-braces .token.token.punctuation.brace-level-7,
.rainbow-braces .token.token.punctuation.brace-level-11 {
color: hsl(221, 87%, 60%);
}
.rainbow-braces .token.token.punctuation.brace-level-4,
.rainbow-braces .token.token.punctuation.brace-level-8,
.rainbow-braces .token.token.punctuation.brace-level-12 {
color: hsl(301, 63%, 40%);
}
/* Diff Highlight plugin overrides */
/* Taken from https://github.com/atom/github/blob/master/styles/variables.less */
pre.diff-highlight > code .token.token.deleted:not(.prefix),
pre > code.diff-highlight .token.token.deleted:not(.prefix) {
background-color: hsla(353, 100%, 66%, 0.15);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.deleted:not(.prefix)
*::-moz-selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.deleted:not(.prefix)
*::-moz-selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.deleted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix) *::selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix),
pre > code.diff-highlight .token.token.inserted:not(.prefix) {
background-color: hsla(137, 100%, 55%, 0.15);
}
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)
*::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)
*::-moz-selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.inserted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix) *::selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
/* Previewers plugin overrides */
/* Based on https://github.com/atom-community/atom-ide-datatip/blob/master/styles/atom-ide-datatips.less and https://github.com/atom/atom/blob/master/packages/one-light-ui */
/* Border around popup */
.prism-previewer.prism-previewer:before,
.prism-previewer-gradient.prism-previewer-gradient div {
border-color: hsl(0, 0, 95%);
}
/* Angle and time should remain as circles and are hence not included */
.prism-previewer-color.prism-previewer-color:before,
.prism-previewer-gradient.prism-previewer-gradient div,
.prism-previewer-easing.prism-previewer-easing:before {
border-radius: 0.3em;
}
/* Triangles pointing to the code */
.prism-previewer.prism-previewer:after {
border-top-color: hsl(0, 0, 95%);
}
.prism-previewer-flipped.prism-previewer-flipped.after {
border-bottom-color: hsl(0, 0, 95%);
}
/* Background colour within the popup */
.prism-previewer-angle.prism-previewer-angle:before,
.prism-previewer-time.prism-previewer-time:before,
.prism-previewer-easing.prism-previewer-easing {
background: hsl(0, 0%, 100%);
}
/* For angle, this is the positive area (eg. 90deg will display one quadrant in this colour) */
/* For time, this is the alternate colour */
.prism-previewer-angle.prism-previewer-angle circle,
.prism-previewer-time.prism-previewer-time circle {
stroke: hsl(230, 8%, 24%);
stroke-opacity: 1;
}
/* Stroke colours of the handle, direction point, and vector itself */
.prism-previewer-easing.prism-previewer-easing circle,
.prism-previewer-easing.prism-previewer-easing path,
.prism-previewer-easing.prism-previewer-easing line {
stroke: hsl(230, 8%, 24%);
}
/* Fill colour of the handle */
.prism-previewer-easing.prism-previewer-easing circle {
fill: transparent;
}
}
}
@media (prefers-color-scheme: dark) {
.prose {
.token.comment,
.token.prolog,
.token.cdata {
color: hsl(220, 10%, 40%);
}
.token.doctype,
.token.punctuation,
.token.entity {
color: hsl(220, 14%, 71%);
}
.token.attr-name,
.token.class-name,
.token.boolean,
.token.constant,
.token.number,
.token.atrule {
color: hsl(29, 54%, 61%);
}
.token.keyword {
color: hsl(286, 60%, 67%);
}
.token.property,
.token.tag,
.token.symbol,
.token.deleted,
.token.important {
color: hsl(355, 65%, 65%);
}
.token.selector,
.token.string,
.token.char,
.token.builtin,
.token.inserted,
.token.regex,
.token.attr-value,
.token.attr-value > .token.punctuation {
color: hsl(95, 38%, 62%);
}
.token.variable,
.token.operator,
.token.function {
color: hsl(207, 82%, 66%);
}
.token.url {
color: hsl(187, 47%, 55%);
}
/* HTML overrides */
.token.attr-value > .token.punctuation.attr-equals,
.token.special-attr > .token.attr-value > .token.value.css {
color: hsl(220, 14%, 71%);
}
/* CSS overrides */
.language-css .token.selector {
color: hsl(355, 65%, 65%);
}
.language-css .token.property {
color: hsl(220, 14%, 71%);
}
.language-css .token.function,
.language-css .token.url > .token.function {
color: hsl(187, 47%, 55%);
}
.language-css .token.url > .token.string.url {
color: hsl(95, 38%, 62%);
}
.language-css .token.important,
.language-css .token.atrule .token.rule {
color: hsl(286, 60%, 67%);
}
/* JS overrides */
.language-javascript .token.operator {
color: hsl(286, 60%, 67%);
}
.language-javascript
.token.template-string
> .token.interpolation
> .token.interpolation-punctuation.punctuation {
color: hsl(5, 48%, 51%);
}
/* JSON overrides */
.language-json .token.operator {
color: hsl(220, 14%, 71%);
}
.language-json .token.null.keyword {
color: hsl(29, 54%, 61%);
}
/* MD overrides */
.language-markdown .token.url,
.language-markdown .token.url > .token.operator,
.language-markdown .token.url-reference.url > .token.string {
color: hsl(220, 14%, 71%);
}
.language-markdown .token.url > .token.content {
color: hsl(207, 82%, 66%);
}
.language-markdown .token.url > .token.url,
.language-markdown .token.url-reference.url {
color: hsl(187, 47%, 55%);
}
.language-markdown .token.blockquote.punctuation,
.language-markdown .token.hr.punctuation {
color: hsl(220, 10%, 40%);
font-style: italic;
}
.language-markdown .token.code-snippet {
color: hsl(95, 38%, 62%);
}
.language-markdown .token.bold .token.content {
color: hsl(29, 54%, 61%);
}
.language-markdown .token.italic .token.content {
color: hsl(286, 60%, 67%);
}
.language-markdown .token.strike .token.content,
.language-markdown .token.strike .token.punctuation,
.language-markdown .token.list.punctuation,
.language-markdown .token.title.important > .token.punctuation {
color: hsl(355, 65%, 65%);
}
/* General */
.token.bold {
font-weight: bold;
}
.token.comment,
.token.italic {
font-style: italic;
}
.token.entity {
cursor: help;
}
.token.namespace {
opacity: 0.8;
}
/* Plugin overrides */
/* Selectors should have higher specificity than those in the plugins' default stylesheets */
/* Show Invisibles plugin overrides */
.token.token.tab:not(:empty):before,
.token.token.cr:before,
.token.token.lf:before,
.token.token.space:before {
color: hsla(220, 14%, 71%, 0.15);
text-shadow: none;
}
/* Toolbar plugin overrides */
/* Space out all buttons and move them away from the right edge of the code block */
div.code-toolbar > .toolbar.toolbar > .toolbar-item {
margin-right: 0.4em;
}
/* Styling the buttons */
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span {
background: hsl(220, 13%, 26%);
color: hsl(220, 9%, 55%);
padding: 0.1em 0.4em;
border-radius: 0.3em;
}
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:focus {
background: hsl(220, 13%, 28%);
color: hsl(220, 14%, 71%);
}
/* Line Highlight plugin overrides */
/* The highlighted line itself */
.line-highlight.line-highlight {
background: hsla(220, 100%, 80%, 0.04);
}
/* Default line numbers in Line Highlight plugin */
.line-highlight.line-highlight:before,
.line-highlight.line-highlight[data-end]:after {
background: hsl(220, 13%, 26%);
color: hsl(220, 14%, 71%);
padding: 0.1em 0.6em;
border-radius: 0.3em;
box-shadow: 0 2px 0 0 rgba(0, 0, 0, 0.2); /* same as Toolbar plugin default */
}
/* Hovering over a linkable line number (in the gutter area) */
/* Requires Line Numbers plugin as well */
pre[id].linkable-line-numbers.linkable-line-numbers
span.line-numbers-rows
> span:hover:before {
background-color: hsla(220, 100%, 80%, 0.04);
}
/* Line Numbers and Command Line plugins overrides */
/* Line separating gutter from coding area */
.line-numbers.line-numbers .line-numbers-rows,
.command-line .command-line-prompt {
border-right-color: hsla(220, 14%, 71%, 0.15);
}
/* Stuff in the gutter */
.line-numbers .line-numbers-rows > span:before,
.command-line .command-line-prompt > span:before {
color: hsl(220, 14%, 45%);
}
/* Match Braces plugin overrides */
/* Note: Outline colour is inherited from the braces */
.rainbow-braces .token.token.punctuation.brace-level-1,
.rainbow-braces .token.token.punctuation.brace-level-5,
.rainbow-braces .token.token.punctuation.brace-level-9 {
color: hsl(355, 65%, 65%);
}
.rainbow-braces .token.token.punctuation.brace-level-2,
.rainbow-braces .token.token.punctuation.brace-level-6,
.rainbow-braces .token.token.punctuation.brace-level-10 {
color: hsl(95, 38%, 62%);
}
.rainbow-braces .token.token.punctuation.brace-level-3,
.rainbow-braces .token.token.punctuation.brace-level-7,
.rainbow-braces .token.token.punctuation.brace-level-11 {
color: hsl(207, 82%, 66%);
}
.rainbow-braces .token.token.punctuation.brace-level-4,
.rainbow-braces .token.token.punctuation.brace-level-8,
.rainbow-braces .token.token.punctuation.brace-level-12 {
color: hsl(286, 60%, 67%);
}
/* Diff Highlight plugin overrides */
/* Taken from https://github.com/atom/github/blob/master/styles/variables.less */
pre.diff-highlight > code .token.token.deleted:not(.prefix),
pre > code.diff-highlight .token.token.deleted:not(.prefix) {
background-color: hsla(353, 100%, 66%, 0.15);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.deleted:not(.prefix)
*::-moz-selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.deleted:not(.prefix)
*::-moz-selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.deleted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix) *::selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix),
pre > code.diff-highlight .token.token.inserted:not(.prefix) {
background-color: hsla(137, 100%, 55%, 0.15);
}
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)
*::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)
*::-moz-selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.inserted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix) *::selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
/* Previewers plugin overrides */
/* Based on https://github.com/atom-community/atom-ide-datatip/blob/master/styles/atom-ide-datatips.less and https://github.com/atom/atom/blob/master/packages/one-dark-ui */
/* Border around popup */
.prism-previewer.prism-previewer:before,
.prism-previewer-gradient.prism-previewer-gradient div {
border-color: hsl(224, 13%, 17%);
}
/* Angle and time should remain as circles and are hence not included */
.prism-previewer-color.prism-previewer-color:before,
.prism-previewer-gradient.prism-previewer-gradient div,
.prism-previewer-easing.prism-previewer-easing:before {
border-radius: 0.3em;
}
/* Triangles pointing to the code */
.prism-previewer.prism-previewer:after {
border-top-color: hsl(224, 13%, 17%);
}
.prism-previewer-flipped.prism-previewer-flipped.after {
border-bottom-color: hsl(224, 13%, 17%);
}
/* Background colour within the popup */
.prism-previewer-angle.prism-previewer-angle:before,
.prism-previewer-time.prism-previewer-time:before,
.prism-previewer-easing.prism-previewer-easing {
background: hsl(219, 13%, 22%);
}
/* For angle, this is the positive area (eg. 90deg will display one quadrant in this colour) */
/* For time, this is the alternate colour */
.prism-previewer-angle.prism-previewer-angle circle,
.prism-previewer-time.prism-previewer-time circle {
stroke: hsl(220, 14%, 71%);
stroke-opacity: 1;
}
/* Stroke colours of the handle, direction point, and vector itself */
.prism-previewer-easing.prism-previewer-easing circle,
.prism-previewer-easing.prism-previewer-easing path,
.prism-previewer-easing.prism-previewer-easing line {
stroke: hsl(220, 14%, 71%);
}
/* Fill colour of the handle */
.prism-previewer-easing.prism-previewer-easing circle {
fill: transparent;
}
}
}
.prose pre {
contain: layout style;
}
/* Or more aggressively */
.prose pre code {
contain: layout style paint;
}
/* messaging-style typing indicator animation */
@keyframes typing {

View File

@@ -0,0 +1,156 @@
import { createHighlighter } from "shiki";
import type { ThemeRegistration } from "shiki";
const oneLightTheme: ThemeRegistration = {
name: "one-light",
type: "light",
colors: {
"editor.background": "#fafafa",
"editor.foreground": "#383a42",
},
tokenColors: [
{
scope: ["comment", "punctuation.definition.comment"],
settings: { foreground: "#a0a1a7" },
},
{
scope: ["keyword", "storage.type", "storage.modifier"],
settings: { foreground: "#a626a4" },
},
{ scope: ["string", "string.quoted"], settings: { foreground: "#50a14f" } },
{
scope: ["function", "entity.name.function", "support.function"],
settings: { foreground: "#4078f2" },
},
{
scope: [
"constant.numeric",
"constant.language",
"constant.character",
"number",
],
settings: { foreground: "#c18401" },
},
{
scope: ["variable", "support.variable"],
settings: { foreground: "#e45649" },
},
{
scope: ["entity.name.tag", "entity.name.type", "entity.name.class"],
settings: { foreground: "#e45649" },
},
{
scope: ["entity.other.attribute-name"],
settings: { foreground: "#c18401" },
},
{
scope: ["keyword.operator", "operator"],
settings: { foreground: "#a626a4" },
},
{ scope: ["punctuation"], settings: { foreground: "#383a42" } },
{
scope: ["markup.heading"],
settings: { foreground: "#e45649", fontStyle: "bold" },
},
{
scope: ["markup.bold"],
settings: { foreground: "#c18401", fontStyle: "bold" },
},
{
scope: ["markup.italic"],
settings: { foreground: "#a626a4", fontStyle: "italic" },
},
],
};
const oneDarkTheme: ThemeRegistration = {
name: "one-dark",
type: "dark",
colors: {
"editor.background": "#282c34",
"editor.foreground": "#abb2bf",
},
tokenColors: [
{
scope: ["comment", "punctuation.definition.comment"],
settings: { foreground: "#5c6370" },
},
{
scope: ["keyword", "storage.type", "storage.modifier"],
settings: { foreground: "#c678dd" },
},
{ scope: ["string", "string.quoted"], settings: { foreground: "#98c379" } },
{
scope: ["function", "entity.name.function", "support.function"],
settings: { foreground: "#61afef" },
},
{
scope: [
"constant.numeric",
"constant.language",
"constant.character",
"number",
],
settings: { foreground: "#d19a66" },
},
{
scope: ["variable", "support.variable"],
settings: { foreground: "#e06c75" },
},
{
scope: ["entity.name.tag", "entity.name.type", "entity.name.class"],
settings: { foreground: "#e06c75" },
},
{
scope: ["entity.other.attribute-name"],
settings: { foreground: "#d19a66" },
},
{
scope: ["keyword.operator", "operator"],
settings: { foreground: "#c678dd" },
},
{ scope: ["punctuation"], settings: { foreground: "#abb2bf" } },
{
scope: ["markup.heading"],
settings: { foreground: "#e06c75", fontStyle: "bold" },
},
{
scope: ["markup.bold"],
settings: { foreground: "#d19a66", fontStyle: "bold" },
},
{
scope: ["markup.italic"],
settings: { foreground: "#c678dd", fontStyle: "italic" },
},
],
};
export let highlighter: Awaited<ReturnType<typeof createHighlighter>> | null =
null;
export const highlighterPromise = createHighlighter({
themes: [oneLightTheme, oneDarkTheme],
langs: [
"javascript",
"typescript",
"python",
"bash",
"shell",
"json",
"html",
"css",
"tsx",
"jsx",
"go",
"rust",
"java",
"c",
"cpp",
"sql",
"yaml",
"markdown",
],
}).then((h) => {
highlighter = h;
return h;
});

View File

@@ -1,24 +0,0 @@
import { remark } from "remark";
import remarkStringify from "remark-stringify";
import remarkStreamingMarkdown from "./remarkStreamingMarkdown";
/**
* Process markdown content for streaming display using the remark plugin.
* This is primarily used for testing the remark plugin with string inputs/outputs.
*/
export function processStreamingMarkdown(content: string): string {
if (!content) return content;
const result = remark()
.use(remarkStreamingMarkdown, { debug: false })
.use(remarkStringify)
.processSync(content);
// remove trailing newline to keep tests cleaner
let output = result.toString();
if (output.endsWith("\n")) {
output = output.slice(0, -1);
}
return output;
}

View File

@@ -1,447 +0,0 @@
import { parents, type Proxy } from "unist-util-parents";
import type { Plugin } from "unified";
import type {
Emphasis,
Node,
Parent,
Root,
RootContent,
Text,
Strong,
PhrasingContent,
Paragraph,
} from "mdast";
import { u } from "unist-builder";
declare module "unist" {
interface Node {
/** Added by `unist-util-parents` (or your own walk). */
parent?: Proxy & Parent;
}
}
// interface SimpleTextRule {
// pattern: RegExp;
// transform: (matches: RegExpExecArray[], lastNode: Proxy) => void;
// }
// const simpleTextRules: SimpleTextRule[] = [
// // TODO(drifkin): generalize this for `__`/`_`/`~~`/`~` etc.
// {
// pattern: /(\*\*)(?=\S|$)/g,
// transform: (matchesIterator, lastNode) => {
// const textNode = lastNode.node as Text;
// const matches = [...matchesIterator];
// const lastMatch = matches[matches.length - 1];
// const origValue = textNode.value;
// const start = lastMatch.index;
// const sep = lastMatch[1];
// const before = origValue.slice(0, start);
// const after = origValue.slice(start + sep.length);
// if (lastNode.parent) {
// const index = (lastNode.parent.node as Parent).children.indexOf(
// lastNode.node as RootContent,
// );
// const shouldRemove = before.length === 0;
// if (!shouldRemove) {
// textNode.value = before;
// }
// const newNode = u("strong", {
// children: [u("text", { value: after })],
// });
// (lastNode.parent.node as Parent).children.splice(
// index + (shouldRemove ? 0 : 1),
// shouldRemove ? 1 : 0,
// newNode,
// );
// }
// },
// },
// ];
interface Options {
debug?: boolean;
onLastNode?: (info: LastNodeInfo) => void;
}
export interface LastNodeInfo {
path: string[];
type: string;
value?: string;
lastChars?: string;
fullNode: Node;
}
/**
* Removes `child` from `parent` in-place.
* @returns `true` if the child was found and removed; `false` otherwise.
*/
export function removeChildFromParent(
child: RootContent,
parent: Node,
): boolean {
if (!isParent(parent)) return false; // parent isnt a Parent → nothing to do
const idx = parent.children.indexOf(child);
if (idx < 0) return false; // not a child → nothing to remove
parent.children.splice(idx, 1);
return true; // removal successful
}
/** Narrow a generic `Node` to a `Parent` (i.e. one that really has children). */
function isParent(node: Node): node is Parent {
// A `Parent` always has a `children` array; make sure it's an array first.
return Array.isArray((node as Partial<Parent>).children);
}
/**
* Follow “last-child” pointers until you reach a leaf.
* Returns the right-most, deepest node in source order.
*/
export function findRightmostDeepestNode(root: Node): Node {
let current: Node = root;
// While the current node *is* a Parent and has at least one child…
while (isParent(current) && current.children.length > 0) {
const lastIndex = current.children.length - 1;
current = current.children[lastIndex];
}
return current; // Leaf: no further children
}
const remarkStreamingMarkdown: Plugin<[Options?], Root> = () => {
return (tree) => {
const treeWithParents = parents(tree);
const lastNode = findRightmostDeepestNode(treeWithParents) as Proxy;
const parentNode = lastNode.parent;
const grandparentNode = parentNode?.parent;
let ruleMatched = false;
// handling `* *` -> ``
//
// if the last node is part of a <list item (otherwise empty)> ->
// <list (otherwise empty)> -> <list item (last node, empty)>, then we need to
// remove everything up to and including the first list item. This happens
// when we have `* *`, which can become a bolded list item OR a horizontal
// line
if (
lastNode.type === "listItem" &&
parentNode &&
grandparentNode &&
parentNode.type === "list" &&
grandparentNode.type === "listItem" &&
parentNode.children.length === 1 &&
grandparentNode.children.length === 1
) {
ruleMatched = true;
if (grandparentNode.parent) {
removeChildFromParent(
grandparentNode.node as RootContent,
grandparentNode.parent.node,
);
}
// Handle `*` -> ``:
//
// if the last node is just an empty list item, we need to remove it
// because it could become something else (e.g., a horizontal line)
} else if (
lastNode.type === "listItem" &&
parentNode &&
parentNode.type === "list"
) {
ruleMatched = true;
removeChildFromParent(lastNode.node as RootContent, parentNode.node);
} else if (lastNode.type === "thematicBreak") {
ruleMatched = true;
const parent = lastNode.parent;
if (parent) {
removeChildFromParent(lastNode.node as RootContent, parent.node);
}
} else if (lastNode.type === "text") {
const textNode = lastNode.node as Text;
if (textNode.value.endsWith("**")) {
ruleMatched = true;
textNode.value = textNode.value.slice(0, -2);
// if there's a newline then a number, this is very very likely a
// numbered list item. Let's just hide it until the period comes (or
// other text disambiguates it)
} else {
const match = textNode.value.match(/^([0-9]+)$/m);
if (match) {
const number = match[1];
textNode.value = textNode.value.slice(0, -number.length - 1);
ruleMatched = true;
// if the text node is now empty, then we might want to remove other
// elements, like a now-empty containing paragraph, or a break that
// might disappear once more tokens come in
if (textNode.value.length === 0) {
if (
lastNode.parent?.type === "paragraph" &&
lastNode.parent.children.length === 1
) {
// remove the whole paragraph if it's now empty (otherwise it'll
// cause an extra newline that might not last)
removeChildFromParent(
lastNode.parent.node as Paragraph,
lastNode.parent.parent?.node as Node,
);
} else {
const prev = prevSibling(lastNode);
if (prev?.type === "break") {
removeChildFromParent(
prev.node as RootContent,
lastNode.parent?.node as Node,
);
removeChildFromParent(
lastNode.node as RootContent,
lastNode.parent?.node as Node,
);
}
}
}
}
}
}
if (ruleMatched) {
return tree;
}
// we need to
// a case like
// - *def `abc` [abc **def**](abc)*
// is pretty tricky, because if we land just after def, then we actually
// have two separate tags to process at two different parents. Maybe we
// need to keep iterating up until we find a paragraph, but process each
// parent on the way up. Hmm, well actually after `def` we won't even be a proper link yet
// TODO(drifkin): it's really if the last node's parent is a paragraph, for which the following is a sub-cas where the lastNode is a text node.
// And instead of just processing simple text rules, they need to operate on the whole paragraph
// like `**[abc](def)` needs to become `**[abc](def)**`
// if we're just text at the end, then we should remove some ambiguous characters
if (lastNode.parent) {
const didChange = processParent(lastNode.parent as Parent & Proxy);
if (didChange) {
// TODO(drifkin): need to fix up the tree, but not sure lastNode will still exist? Check all the transforms to see if it's safe to find the last node again
//
// need to regen the tree w/ parents since reparenting could've happened
// treeWithParents = parents(tree);
}
}
const grandparent = lastNode.parent?.parent;
// TODO(drifkin): let's go arbitrarily high up the tree, but limiting it
// to 2 levels for now until I think more about the stop condition
if (grandparent) {
processParent(grandparent as Parent & Proxy);
}
// console.log("ruleMatched", ruleMatched);
// } else if (lastNode.parent?.type === "paragraph") {
// console.log("!!! paragraph");
// console.log("lastNode.parent", lastNode.parent);
// // Handle `**abc*` -> `**abc**`:
// // We detect this when the last child is an emphasis node, and it's preceded by a text node that ends with `*`
// const paragraph = lastNode.parent as Proxy & Paragraph;
// if (paragraph.children.length >= 2) {
// const lastChild = paragraph.children[paragraph.children.length - 1];
// if (lastChild.type === "emphasis") {
// const sibling = paragraph.children[paragraph.children.length - 2];
// if (sibling.type === "text") {
// const siblingText = sibling as Text & Proxy;
// if (siblingText.value.endsWith("*")) {
// ruleMatched = true;
// const textNode = (lastNode as Proxy).node as Text;
// textNode.value = textNode.value.slice(0, -1);
// paragraph.node.type = "strong";
// }
// }
// }
// }
// } else if (lastNode.type === "text") {
// // Handle `**abc*` -> `**abc**`:
// //
// // this gets parsed as a text node ending in `*` followed by an emphasis
// // node. So if we're in text, we need to check if our parent is emphasis,
// // and then get our parent's sibling before it and check if it ends with
// // `*`
// const parent = lastNode.parent;
// if (parent && parent.type === "emphasis") {
// const grandparent = parent.parent;
// if (grandparent) {
// const index = (grandparent.node as Parent).children.indexOf(
// parent.node as RootContent,
// );
// if (index > 0) {
// const prevNode = grandparent.children[index - 1];
// if (
// prevNode.type === "text" &&
// (prevNode as Text).value.endsWith("*")
// ) {
// ruleMatched = true;
// const textNode = (prevNode as Proxy).node as Text;
// textNode.value = textNode.value.slice(0, -1);
// parent.node.type = "strong";
// }
// }
// }
// }
// if (!ruleMatched) {
// // if the last node is just text, then we process it in order to fix up certain unclosed items
// // e.g., `**abc` -> `**abc**`
// const textNode = lastNode.node as Text;
// for (const rule of simpleTextRules) {
// const matchesIterator = textNode.value.matchAll(rule.pattern);
// const matches = [...matchesIterator];
// if (matches.length > 0) {
// rule.transform(matches, lastNode);
// ruleMatched = true;
// break;
// }
// }
// }
// } else if (!ruleMatched) {
// // console.log("no rule matched", lastNode);
// }
return tree;
};
};
function processParent(parent: Parent & Proxy): boolean {
if (parent.type === "emphasis") {
// Handle `**abc*` -> `**abc**`:
// We detect this when we end with an emphasis node, and it's preceded by
// a text node that ends with `*`
// TODO(drifkin): the last node can be more deeply nested (e.g., a code
// literal in a link), so we probably need to walk up the tree until we
// find an emphasis node or a block? For now we'll just go up one layer to
// catch the most common cases
const emphasisNode = parent as Emphasis & Proxy;
const grandparent = emphasisNode.parent;
if (grandparent) {
const indexOfEmphasisNode = (grandparent.node as Parent).children.indexOf(
emphasisNode.node as RootContent,
);
if (indexOfEmphasisNode >= 0) {
const nodeBefore = grandparent.children[indexOfEmphasisNode - 1] as
| (Node & Proxy)
| undefined;
if (nodeBefore?.type === "text") {
const textNode = nodeBefore.node as Text;
if (textNode.value.endsWith("*")) {
const strBefore = textNode.value.slice(0, -1);
textNode.value = strBefore;
const strongNode = u("strong", {
children: emphasisNode.children,
});
(grandparent.node as Parent).children.splice(
indexOfEmphasisNode,
1,
strongNode,
);
return true;
}
}
}
}
}
// Let's check if we have any bold items to close
for (let i = parent.children.length - 1; i >= 0; i--) {
const child = parent.children[i];
if (child.type === "text") {
const textNode = child as Text & Proxy;
const sep = "**";
const index = textNode.value.lastIndexOf(sep);
if (index >= 0) {
let isValidOpening = false;
if (index + sep.length < textNode.value.length) {
const charAfter = textNode.value[index + sep.length];
if (!isWhitespace(charAfter)) {
isValidOpening = true;
}
} else {
if (i < parent.children.length - 1) {
// TODO(drifkin): I'm not sure that this check is strict enough.
// We're trying to detect cases like `**[abc]()` where the char
// after the opening ** is indeed a non-whitespace character. We're
// using the heuristic that there's another item after the current
// one, but I'm not sure if that is good enough. In a well
// constructed tree, there aren't two text nodes in a row, so this
// _seems_ good, but I should think through it more
isValidOpening = true;
}
}
if (isValidOpening) {
// TODO(drifkin): close the bold
const strBefore = textNode.value.slice(0, index);
const strAfter = textNode.value.slice(index + sep.length);
(textNode.node as Text).value = strBefore;
// TODO(drifkin): the node above could be empty in which case we probably want to delete it
const children: PhrasingContent[] = [
...(strAfter.length > 0 ? [u("text", { value: strAfter })] : []),
];
const strongNode: Strong = u("strong", {
children,
});
const nodesAfter = (parent.node as Parent).children.splice(
i + 1,
parent.children.length - i - 1,
strongNode,
);
// TODO(drifkin): this cast seems iffy, should see if we can cast the
// parent instead, which would also help us check some of our
// assumptions
strongNode.children.push(...(nodesAfter as PhrasingContent[]));
return true;
}
}
}
}
return false;
}
function prevSibling(node: Node & Proxy): (Node & Proxy) | null {
const parent = node.parent;
if (parent) {
const index = parent.children.indexOf(node);
return parent.children[index - 1] as Node & Proxy;
}
return null;
}
function isWhitespace(str: string) {
return str.trim() === "";
}
// function debugPrintTreeNoPos(tree: Node) {
// console.log(
// JSON.stringify(
// tree,
// (key, value) => {
// if (key === "position") {
// return undefined;
// }
// return value;
// },
// 2,
// ),
// );
// }
export default remarkStreamingMarkdown;

View File

@@ -1794,13 +1794,14 @@ func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, ava
var thinkValue *api.ThinkValue
if think != nil {
// Only set Think if it's actually requesting thinking
if boolValue, ok := think.(bool); ok {
thinkValue = &api.ThinkValue{
Value: boolValue,
if boolValue {
thinkValue = &api.ThinkValue{Value: boolValue}
}
} else if stringValue, ok := think.(string); ok {
thinkValue = &api.ThinkValue{
Value: stringValue,
if stringValue != "" && stringValue != "none" {
thinkValue = &api.ThinkValue{Value: stringValue}
}
}
}

View File

@@ -110,9 +110,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape()
if !strings.HasSuffix(name, ".weight") {
name = name + ".weight"
}
if strings.Contains(name, "ffn_down_exps") {
out = append(out, &ggml.Tensor{
Name: name + ".weight",
Name: name,
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4,
@@ -121,12 +124,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
Name: strings.Replace(name, "gate_up", "gate", 1),
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
}, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
Name: strings.Replace(name, "gate_up", "up", 1),
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),

View File

@@ -2,10 +2,12 @@ package convert
import (
"cmp"
"errors"
"io"
"iter"
"path"
"slices"
"strconv"
"strings"
"github.com/pdevine/tensor"
@@ -94,6 +96,26 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
return matched
})
slices.SortStableFunc(matched, func(a, b Tensor) int {
x := strings.Split(a.Name(), ".")
y := strings.Split(b.Name(), ".")
if len(x) != len(y) {
return cmp.Compare(len(x), len(y))
}
vals := make([]int, len(x))
for i := range x {
vals[i] = strings.Compare(x[i], y[i])
m, err := strconv.ParseInt(x[i], 0, 0)
n, err2 := strconv.ParseInt(y[i], 0, 0)
if errors.Join(err, err2) == nil {
vals[i] = cmp.Compare(m, n)
}
}
return cmp.Or(vals...)
})
if len(matched) > 0 {
out = append(out, &ggml.Tensor{
Name: merges[i].name,

View File

@@ -3,8 +3,10 @@ package convert
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"iter"
"math/rand/v2"
"slices"
"strings"
"testing"
@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) {
}
})
}
func TestMergeOrder(t *testing.T) {
for range 8 {
t.Run("", func(t *testing.T) {
tensors := make([]Tensor, 16)
for i := range tensors {
tensors[i] = &fakeTensor{
name: fmt.Sprintf("layer.%d.weight", i),
shape: []uint64{1},
data: []float32{float32(i)},
}
}
rand.Shuffle(len(tensors), func(i, j int) {
tensors[i], tensors[j] = tensors[j], tensors[i]
})
matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
if len(unmatched) != 0 {
t.Error("expected no remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
var b bytes.Buffer
if _, err := matched[0].WriteTo(&b); err != nil {
t.Fatal(err)
}
var f32s [16]float32
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.IsSorted(f32s[:]) {
t.Errorf("merged tensor data is not in order: %+v", f32s)
}
})
}
}

View File

@@ -94,6 +94,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
continue
}
dirs = []string{ml.LibOllamaPath, dir}
} else {

View File

@@ -13,9 +13,23 @@ Embeddings turn text into numeric vectors you can store in a vector database, se
## Generate embeddings
Use `/api/embed` with a single string.
<Tabs>
<Tab title="CLI">
Generate embeddings directly from the command line:
```shell
ollama run embeddinggemma "Hello world"
```
You can also pipe text to generate embeddings:
```shell
echo "Hello world" | ollama run embeddinggemma
```
Output is a JSON array.
</Tab>
<Tab title="cURL">
```shell
curl -X POST http://localhost:11434/api/embed \

View File

@@ -68,6 +68,15 @@ To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following c
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
```
## Vulkan Support
Vulkan is bundled into the `ollama/ollama` image.
```shell
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 -e OLLAMA_VULKAN=1 --name ollama ollama/ollama
```
## Run model locally
Now you can run a model:
@@ -79,3 +88,4 @@ docker exec -it ollama ollama run llama3.2
## Try different models
More models can be found on the [Ollama library](https://ollama.com/library).

View File

@@ -63,6 +63,10 @@
{
"source": "/api/openai",
"destination": "/api/openai-compatibility"
},
{
"source": "/api",
"destination": "/api/introduction"
}
],
"navigation": {
@@ -130,7 +134,7 @@
{
"group": "API Reference",
"pages": [
"/api/index",
"/api/introduction",
"/api/authentication",
"/api/streaming",
"/api/usage",

View File

@@ -223,7 +223,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
## How can I use Ollama in Visual Studio Code?
There is already a large collection of plugins available for VSCode as well as other editors that leverage Ollama. See the list of [extensions & plugins](https://github.com/ollama/ollama#extensions--plugins) at the bottom of the main repository readme.
There is already a large collection of plugins available for VS Code as well as other editors that leverage Ollama. See the list of [extensions & plugins](https://github.com/ollama/ollama#extensions--plugins) at the bottom of the main repository readme.
## How do I use Ollama with GPU acceleration in Docker?

View File

@@ -52,7 +52,11 @@ sudo modprobe nvidia_uvm`
## AMD Radeon
Ollama supports the following AMD GPUs:
Ollama supports the following AMD GPUs via the ROCm library:
> [!NOTE]
> Additional AMD GPU support is provided by the Vulkan Library - see below.
### Linux Support
@@ -121,6 +125,42 @@ In some Linux distributions, SELinux can prevent containers from
accessing the AMD GPU devices. On the host system you can run
`sudo setsebool container_use_devices=1` to allow containers to use devices.
### Metal (Apple GPUs)
## Metal (Apple GPUs)
Ollama supports GPU acceleration on Apple devices via the Metal API.
## Vulkan GPU Support
> [!NOTE]
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
bundled with Vulkan support and require no additional setup steps. Most Linux
distributions require installing additional components, and you may have
multiple options for Vulkan drivers between Mesa and GPU Vendor specific packages
- Linux Intel GPU Instructions - https://dgpu-docs.intel.com/driver/client/overview.html
- Linux AMD GPU Instructions - https://amdgpu-install.readthedocs.io/en/latest/install-script.html#specifying-a-vulkan-implementation
For AMD GPUs on some Linux distributions, you may need to add the `ollama` user to the `render` group.
The Ollama scheduler leverages available VRAM data reported by the GPU libraries to
make optimal scheduling decisions. Vulkan requires additional capabilities or
running as root to expose this available VRAM data. If neither root access or this
capability are granted, Ollama will use approximate sizes of the models
to make best effort scheduling decisions.
```bash
sudo setcap cap_perfmon+ep /usr/local/bin/ollama
```
### GPU Selection
To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -4,7 +4,7 @@ title: VS Code
## Install
Install [VSCode](https://code.visualstudio.com/download).
Install [VS Code](https://code.visualstudio.com/download).
## Usage with Ollama
@@ -12,7 +12,7 @@ Install [VSCode](https://code.visualstudio.com/download).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-sidebar.png"
alt="VSCode chat Sidebar"
alt="VS Code chat Sidebar"
width="75%"
/>
</div>
@@ -20,7 +20,7 @@ Install [VSCode](https://code.visualstudio.com/download).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-models.png"
alt="VSCode model picker"
alt="VS Code model picker"
width="75%"
/>
</div>
@@ -28,7 +28,7 @@ Install [VSCode](https://code.visualstudio.com/download).
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-model-options.png"
alt="VSCode model options dropdown"
alt="VS Code model options dropdown"
width="75%"
/>
</div>

View File

@@ -2,12 +2,15 @@ openapi: 3.1.0
info:
title: Ollama API
version: 0.1.0
license:
name: MIT
url: https://opensource.org/licenses/MIT
description: |
OpenAPI specification for the Ollama HTTP API
servers:
- url: http://localhost:11434
description: Local Ollama instance
description: Ollama
security: []
components:
securitySchemes:
bearerAuth:
@@ -93,8 +96,11 @@ components:
type: boolean
default: true
think:
type: boolean
description: When true, returns separate thinking output in addition to content
oneOf:
- type: boolean
- type: string
enum: [high, medium, low]
description: When true, returns separate thinking output in addition to content. Can be a boolean (true/false) or a string ("high", "medium", "low") for supported models.
raw:
type: boolean
description: When true, returns the raw response from the model without any prompt templating
@@ -271,8 +277,11 @@ components:
type: boolean
default: true
think:
type: boolean
description: When true, returns separate thinking output in addition to content
oneOf:
- type: boolean
- type: string
enum: [high, medium, low]
description: When true, returns separate thinking output in addition to content. Can be a boolean (true/false) or a string ("high", "medium", "low") for supported models.
keep_alive:
oneOf:
- type: string
@@ -310,7 +319,6 @@ components:
type: array
items:
type: string
nullable: true
description: Optional base64-encoded images in the response
done:
type: boolean
@@ -367,7 +375,6 @@ components:
type: array
items:
type: string
nullable: true
description: Partial base64-encoded images, when present
done:
type: boolean
@@ -543,6 +550,9 @@ components:
license:
type: string
description: The license of the model
modified_at:
type: string
description: Last modified timestamp in ISO 8601 format
details:
type: object
description: High-level model details
@@ -622,6 +632,9 @@ components:
size_vram:
type: integer
description: VRAM usage in bytes
context_length:
type: integer
description: Context length for the running model
PsResponse:
type: object
properties:
@@ -1275,6 +1288,9 @@ paths:
example:
source: gemma3
destination: gemma3-backup
responses:
"200":
description: Model successfully copied
/api/pull:
post:
summary: Pull a model
@@ -1382,16 +1398,7 @@ paths:
model: gemma3
responses:
"200":
description: Deletion status updates.
content:
application/json:
schema:
$ref: "#/components/schemas/StatusResponse"
example:
status: "success"
application/x-ndjson:
schema:
$ref: "#/components/schemas/StatusEvent"
description: Model successfully deleted
/api/version:
get:
summary: Get version

View File

@@ -196,8 +196,6 @@ var (
NoPrune = Bool("OLLAMA_NOPRUNE")
// SchedSpread allows scheduling models across all GPUs.
SchedSpread = Bool("OLLAMA_SCHED_SPREAD")
// IntelGPU enables experimental Intel GPU detection.
IntelGPU = Bool("OLLAMA_INTEL_GPU")
// MultiUserCache optimizes prompt caching for multi-user scenarios
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
// Enable the new Ollama engine
@@ -206,6 +204,8 @@ var (
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
// Enable Vulkan backend
EnableVulkan = Bool("OLLAMA_VULKAN")
)
func String(s string) func() string {
@@ -314,7 +314,7 @@ func AsMap() map[string]EnvVar {
ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"}
ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"}
ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}
ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
}
return ret

View File

@@ -797,73 +797,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
return
}
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() {
case "mllama":
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 +
imageSize*imageSize*numChannels*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3", "mistral3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
case "qwen25vl":
maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
numPatches := maxPixels / (patchSize * patchSize)
graphSize = 4 * (maxPixels*numChannels + // Original image storage
// Normalized pixels
maxPixels*numChannels +
// Patches storage (numPatches * channels * patchSize^2)
numPatches*numChannels*patchSize*patchSize +
// Self-attention calculations
numPatches*numPatches*headCount +
// Additional buffer for processing
embeddingLength*numPatches)
case "llama4":
// vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph
}
return weights, graphSize
}
// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
if cacheType == "" || cacheType == "f16" {

View File

@@ -14,6 +14,23 @@ import (
"github.com/ollama/ollama/api"
)
func assertBytesMatchToken(t *testing.T, label, token string, ints []int) {
t.Helper()
raw := []byte(token)
if len(ints) != len(raw) {
t.Errorf("%s expected %d bytes for token %q, got %d (%v)", label, len(raw), token, len(ints), ints)
return
}
for i, b := range raw {
if ints[i] != int(b) {
t.Errorf("%s byte[%d] mismatch for token %q: got %d want %d", label, i, token, ints[i], int(b))
return
}
}
}
func TestAPIGenerate(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
@@ -381,3 +398,182 @@ func TestAPIShowModel(t *testing.T) {
t.Errorf("%s missing modified_at: %#v", modelName, resp)
}
}
func TestAPIGenerateLogprobs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
enableLogprobs := true
noStream := false
tests := []struct {
name string
logprobs *bool
topLogprobs int
expectCount int
}{
{
name: "no_logprobs",
logprobs: nil,
topLogprobs: 0,
expectCount: 0,
},
{
name: "logprobs_only",
logprobs: &enableLogprobs,
topLogprobs: 0,
expectCount: 1,
},
{
name: "logprobs_with_top_5",
logprobs: &enableLogprobs,
topLogprobs: 5,
expectCount: 1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := api.GenerateRequest{
Model: smol,
Prompt: "Why is the sky blue?",
Stream: &noStream,
Logprobs: test.logprobs != nil && *test.logprobs,
TopLogprobs: test.topLogprobs,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_predict": 10,
},
}
var response api.GenerateResponse
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
if resp.Done {
response = resp
}
return nil
})
if err != nil {
t.Fatalf("generate failed: %s", err)
}
// Check logprobs based on expectation
if test.expectCount == 0 {
if len(response.Logprobs) > 0 {
t.Errorf("expected no logprobs but got %d", len(response.Logprobs))
}
} else {
if len(response.Logprobs) == 0 {
t.Errorf("expected logprobs but got none")
}
// Validate each logprob entry
for i, lp := range response.Logprobs {
if lp.Token == "" {
t.Errorf("logprob[%d] has empty token", i)
}
if lp.Logprob > 0 {
t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob)
}
assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d]", i), lp.Token, lp.Bytes)
// Check top_logprobs if requested
if test.topLogprobs > 0 {
if len(lp.TopLogprobs) == 0 {
t.Errorf("logprob[%d] expected top_logprobs but got none", i)
}
if len(lp.TopLogprobs) > test.topLogprobs {
t.Errorf("logprob[%d] has %d top_logprobs, expected max %d", i, len(lp.TopLogprobs), test.topLogprobs)
}
// Verify top_logprobs are sorted by probability (descending)
for j := 1; j < len(lp.TopLogprobs); j++ {
if lp.TopLogprobs[j-1].Logprob < lp.TopLogprobs[j].Logprob {
t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob)
}
}
for j, top := range lp.TopLogprobs {
assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
}
} else if len(lp.TopLogprobs) > 0 {
t.Errorf("logprob[%d] has top_logprobs but none were requested", i)
}
}
}
})
}
}
func TestAPIChatLogprobs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
enableLogprobs := true
noStream := false
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{Role: "user", Content: "Say hello in one word"},
},
Stream: &noStream,
Logprobs: enableLogprobs,
TopLogprobs: 3,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_predict": 5,
},
}
var response api.ChatResponse
err := client.Chat(ctx, &req, func(resp api.ChatResponse) error {
if resp.Done {
response = resp
}
return nil
})
if err != nil {
t.Fatalf("chat failed: %s", err)
}
if len(response.Logprobs) == 0 {
t.Fatal("expected logprobs in response but got none")
}
t.Logf("received %d logprobs for chat response", len(response.Logprobs))
for i, lp := range response.Logprobs {
if lp.Token == "" {
t.Errorf("logprob[%d] has empty token", i)
}
if lp.Logprob > 0 {
t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob)
}
assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d]", i), lp.Token, lp.Bytes)
if len(lp.TopLogprobs) == 0 {
t.Errorf("logprob[%d] expected top_logprobs but got none", i)
}
if len(lp.TopLogprobs) > 3 {
t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs))
}
for j, top := range lp.TopLogprobs {
assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
}
}
}

View File

@@ -63,8 +63,13 @@ func BackendInit() {
C.llama_backend_init()
}
func EnumerateGPUs() []ml.DeviceID {
var ids []ml.DeviceID
type Devices struct {
ml.DeviceID
LlamaID uint64
}
func EnumerateGPUs() []Devices {
var ids []Devices
for i := range C.ggml_backend_dev_count() {
device := C.ggml_backend_dev_get(i)
@@ -74,9 +79,12 @@ func EnumerateGPUs() []ml.DeviceID {
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(device, &props)
ids = append(ids, ml.DeviceID{
ID: C.GoString(props.id),
Library: C.GoString(props.library),
ids = append(ids, Devices{
DeviceID: ml.DeviceID{
ID: C.GoString(props.id),
Library: C.GoString(props.library),
},
LlamaID: uint64(i),
})
}
}
@@ -217,7 +225,21 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
return embeddings
}
// GetLogitsIth gets the logits for the ith token
func (c *Context) GetLogitsIth(i int) []float32 {
logits := unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int32_t(i)))
if logits == nil {
return nil
}
vocabSize := c.Model().NumVocab()
result := make([]float32, vocabSize)
_ = copy(result, unsafe.Slice((*float32)(logits), vocabSize))
return result
}
type ModelParams struct {
Devices []uint64
NumGpuLayers int
MainGpu int
UseMmap bool
@@ -241,6 +263,21 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
cparams.use_mmap = C.bool(params.UseMmap)
cparams.vocab_only = C.bool(params.VocabOnly)
var devices []C.ggml_backend_dev_t
for _, llamaID := range params.Devices {
devices = append(devices, C.ggml_backend_dev_get(C.size_t(llamaID)))
}
if len(devices) > 0 {
devices = append(devices, C.ggml_backend_dev_t(C.NULL))
devicesData := &devices[0]
var devicesPin runtime.Pinner
devicesPin.Pin(devicesData)
defer devicesPin.Unpin()
cparams.devices = devicesData
}
if len(params.TensorSplit) > 0 {
tensorSplitData := &params.TensorSplit[0]

View File

@@ -0,0 +1,32 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 03:53:04 -0500
Subject: [PATCH] vulkan: Call ggml_vk_buffer_write_2d from ggml_vk_buffer_copy
(#16793)
This lets the copy to the destination device use the host-visible
vidmem optimization.
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 221e29509..18b7cbccf 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5654,14 +5654,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
// Copy device to device
ggml_vk_ensure_sync_staging_buffer(src->device, size);
- ggml_vk_ensure_sync_staging_buffer(dst->device, size);
// Copy to src staging buffer
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
- // memcpy to dst staging buffer
- memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
// Copy to dst buffer
- ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
+ ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,657 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 08:44:29 -0500
Subject: [PATCH] vulkan: Update topk_moe fusion to handle gpt's late softmax
(#16656)
* vulkan: Update topk_moe fusion to handle gpt's late softmax
Based on #16649.
* Add ggml_check_edges
* Add sync logging to show fusion effects
* handle clamp added in #16655
* Update ggml/src/ggml-impl.h
Co-authored-by: Diego Devesa <slarengh@gmail.com>
---
ggml/src/ggml-impl.h | 16 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 304 +++++++++++-------
.../ggml-vulkan/vulkan-shaders/topk_moe.comp | 90 ++++--
3 files changed, 272 insertions(+), 138 deletions(-)
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 639d551a2..e5c446d1d 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release();
#endif
#ifdef __cplusplus
+#include <array>
#include <initializer_list>
#include <vector>
@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
+ int start_idx,
+ std::initializer_list<std::array<int, 3>> edges) {
+ for (const auto & edge : edges) {
+ int dst_node = edge[0];
+ int src_idx = edge[1];
+ int src_node = edge[2];
+ if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+ return false;
+ }
+ }
+ return true;
+}
+
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 53b57c179..b2855b078 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -387,12 +387,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
-static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
- GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
- GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
-static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
- GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+ GGML_OP_RESHAPE };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
+ GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
+//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
+//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
+//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
+//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
+//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
+//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
+//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
+ { 1, 0, 0 }, // reshape->src[0] == softmax
+ { 2, 0, 0 }, // argsort->src[0] == softmax
+ { 3, 0, 2 }, // view->src[0] == argsort
+ { 4, 0, 1 }, // get_rows->src[0] == reshape
+ { 4, 1, 3 }, // get_rows->src[1] == view
+ { 5, 0, 4 }, // reshape->src[0] == get_rows
+ { 6, 0, 5 }, // sum_rows->src[0] == reshape
+ { 7, 0, 6 }, // clamp->src[0] == sum_rows
+ { 8, 0, 5 }, // div->src[0] == reshape
+ { 8, 1, 7 }, // div->src[1] == clamp
+ { 9, 0, 8 }, // reshape->src[0] == div
+};
+
+// same as early_softmax_norm but ending after the get_rows
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
+ { 1, 0, 0 }, // reshape->src[0] == softmax
+ { 2, 0, 0 }, // argsort->src[0] == softmax
+ { 3, 0, 2 }, // view->src[0] == argsort
+ { 4, 0, 1 }, // get_rows->src[0] == reshape
+ { 4, 1, 3 }, // get_rows->src[1] == view
+};
+//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
+//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
+//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
+//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
+//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
+//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
+ { 1, 0, 0 }, // view->src[0] == argsort
+ { 2, 1, 1 }, // get_rows->src[1] == view
+ { 3, 0, 2 }, // reshape->src[0] == get_rows
+ { 4, 0, 3 }, // soft_max->src[0] == reshape
+ { 5, 0, 4 }, // reshape->src[0] == soft_max
+};
+
+enum topk_moe_mode {
+ TOPK_MOE_EARLY_SOFTMAX,
+ TOPK_MOE_EARLY_SOFTMAX_NORM,
+ TOPK_MOE_LATE_SOFTMAX,
+ TOPK_MOE_COUNT,
+};
+
+static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
+ topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
+ num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
+ TOPK_MOE_LATE_SOFTMAX;
+ return mode;
+}
struct vk_device_struct {
std::recursive_mutex mutex;
@@ -607,8 +671,7 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
- // [2] is {!norm, norm}
- vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
+ vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
std::vector<vk_pipeline_ref> all_pipelines;
@@ -956,6 +1019,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_expert_used;
+ float clamp_min;
+ float clamp_max;
};
struct vk_op_add_id_push_constants {
@@ -3806,8 +3871,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
}
for (auto &c : compiles) {
@@ -8085,8 +8151,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
- return ctx->device->pipeline_topk_moe[idx][with_norm];
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ return ctx->device->pipeline_topk_moe[idx][mode];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8141,6 +8207,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
}
case GGML_OP_ARGSORT:
+ if (ctx->num_additional_fused_ops) {
+ uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+ GGML_ASSERT(idx < num_topk_moe_pipelines);
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ return ctx->device->pipeline_topk_moe[idx][mode];
+ }
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
@@ -9676,10 +9749,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
- bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+ topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
- ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
- ggml_tensor * ids = cgraph->nodes[node_idx + 3];
+ ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
+ (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
+ cgraph->nodes[node_idx + 5];
+ ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
@@ -9738,9 +9813,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
GGML_ASSERT(d_ids != nullptr);
}
- vk_op_topk_moe_push_constants pc;
+ vk_op_topk_moe_push_constants pc {};
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
+ if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
+ ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+ pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+ pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+ }
GGML_ASSERT(n_expert_used <= n_experts);
@@ -11335,7 +11415,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
+
+#define ENABLE_SYNC_LOGGING 0
+
if (need_sync) {
+#if ENABLE_SYNC_LOGGING
+ std::cerr << "sync" << std::endl;
+#endif
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -11353,6 +11439,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
+#if ENABLE_SYNC_LOGGING
+ if (!dryrun) {
+ for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+ auto *n = cgraph->nodes[node_idx + i];
+ std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
+ if (n->op == GGML_OP_GLU) {
+ std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
+ }
+ std::cerr << std::endl;
+ }
+ }
+#endif
switch (node->op) {
case GGML_OP_REPEAT:
@@ -11531,7 +11629,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ARGSORT:
- ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+ if (ctx->num_additional_fused_ops) {
+ ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
+ } else {
+ ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+ }
break;
case GGML_OP_SUM:
@@ -12329,30 +12431,27 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
- int node_idx, bool with_norm) {
+ int node_idx, topk_moe_mode mode) {
- if (with_norm) {
- if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
- return false;
- }
- for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
- if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
- return false;
- }
- }
- } else {
- if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
- return false;
- }
- for (size_t i = 0; i < topk_moe.size(); ++i) {
- if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
- return false;
- }
- }
- }
+ const ggml_tensor * softmax;
+ const ggml_tensor * weights;
- const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
- const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+ switch (mode) {
+ case TOPK_MOE_EARLY_SOFTMAX_NORM:
+ softmax = cgraph->nodes[node_idx + 0];
+ weights = cgraph->nodes[node_idx + 9];
+ break;
+ case TOPK_MOE_EARLY_SOFTMAX:
+ softmax = cgraph->nodes[node_idx + 0];
+ weights = cgraph->nodes[node_idx + 4];
+ break;
+ case TOPK_MOE_LATE_SOFTMAX:
+ softmax = cgraph->nodes[node_idx + 4];
+ weights = cgraph->nodes[node_idx + 5];
+ break;
+ default:
+ return false;
+ }
const float * op_params = (const float *)softmax->op_params;
@@ -12378,60 +12477,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
return false;
}
- // Check that the nodes don't have any unexpected uses
- const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
- const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
- const ggml_tensor * view = cgraph->nodes[node_idx + 3];
- const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
- const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
- const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
- const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
- const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
-
- // softmax is used by reshape and argsort
- if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
- reshape1->src[0] != softmax ||
- argsort->src[0] != softmax) {
- return false;
- }
- // reshape is used by get_rows
- if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
- get_rows->src[0] != reshape1) {
- return false;
- }
- // argsort is used by view
- if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
- view->src[0] != argsort) {
- return false;
- }
- // view is written (via argsort), we can skip checking it
-
- if (with_norm) {
- // get_rows is used by reshape
- if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
- reshape5->src[0] != get_rows) {
- return false;
- }
-
- // reshape is used by sum_rows and div
- if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
- sum_rows->src[0] != reshape5 ||
- div->src[0] != reshape5) {
- return false;
- }
-
- // sum_rows is used by div
- if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
- div->src[1] != sum_rows) {
- return false;
- }
-
- // div/reshape are written
- if (reshape8->src[0] != div) {
- return false;
- }
- }
-
if (!ctx->device->subgroup_arithmetic ||
!ctx->device->subgroup_shuffle ||
!ctx->device->subgroup_require_full_support ||
@@ -12517,10 +12562,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12618,10 +12671,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
- ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
- } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
- ctx->num_additional_fused_ops = topk_moe.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+ ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
@@ -12754,25 +12815,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
- // Avoid reordering topk_moe_norm
- if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
- bool is_topk_moe_norm = true;
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
- if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
- is_topk_moe_norm = false;
+ // Check for fusion patterns and avoid reordering them
+ auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
+ if (start + (int)pattern.size() <= graph->n_nodes) {
+ bool is_pattern = true;
+ for (size_t j = 0; j < pattern.size(); ++j) {
+ if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
+ is_pattern = false;
+ }
}
+ return is_pattern;
}
- if (is_topk_moe_norm) {
- for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+ return false;
+ };
+
+ auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
+ if (match_pattern(pattern, first_unused)) {
+ for (size_t j = 0; j < pattern.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
- continue;
+ return true;
}
+ return false;
+ };
+
+ if (keep_pattern(topk_moe_early_softmax_norm)) {
+ continue;
+ }
+ if (keep_pattern(topk_moe_early_softmax)) {
+ continue;
}
+ if (keep_pattern(topk_moe_late_softmax)) {
+ continue;
+ }
+
// First, grab the next unused node.
current_set.push_back(first_unused);
@@ -12790,6 +12870,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (is_empty(graph->nodes[j])) {
continue;
}
+ // Don't pull forward nodes from fusion patterns
+ if (match_pattern(topk_moe_early_softmax_norm, j) ||
+ match_pattern(topk_moe_early_softmax, j) ||
+ match_pattern(topk_moe_late_softmax, j)) {
+ continue;
+ }
bool ok = true;
for (int c = first_unused; c < j; ++c) {
if (!used[c] &&
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index 9e56d5f8a..bc1c278bf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
+ float clamp_min;
+ float clamp_max;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
+layout(constant_id = 3) const bool late_softmax = false;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
@@ -25,53 +28,72 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
-void main() {
- const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
- if (row >= n_rows) {
- return;
- }
+const float INFINITY = 1.0 / 0.0;
- const uint logits_offset = n_experts * row;
- const uint weights_offset = n_expert_used * row;
- const uint ids_offset = n_experts * row;
-
- float logits_r[experts_per_thread];
-
- const float INFINITY = 1.0 / 0.0;
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
+ float max_val = -INFINITY;
[[unroll]]
- for (uint i = 0; i < n_experts; i += WARP_SIZE) {
- const uint expert = i + gl_LocalInvocationID.x;
- logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ max_val = max(max_val, vals[i]);
+ }
}
- float max_val = logits_r[0];
+ max_val = subgroupMax(max_val);
+
+ float sum = 0.f;
[[unroll]]
- for (int i = 1; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- max_val = max(val, max_val);
+ for (int i = 0; i < experts_per_thread; i++) {
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ const float val = exp(vals[i] - max_val);
+ vals[i] = val;
+ sum += val;
+ } else {
+ vals[i] = 0.f;
+ }
}
- max_val = subgroupMax(max_val);
+ sum = subgroupAdd(sum);
- float wt[experts_per_thread];
- float tmp = 0.f;
+ const float inv_sum = 1.0f / sum;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- wt[i] = exp(val - max_val);
- tmp += wt[i];
+ const uint idx = lane + i * WARP_SIZE;
+ const bool is_active = !use_limit || (idx < limit);
+ if (is_active) {
+ vals[i] *= inv_sum;
+ }
}
+}
- tmp = subgroupAdd(tmp);
+void main() {
+ const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+ if (row >= n_rows) {
+ return;
+ }
- const float inv_sum = 1.0f / tmp;
+ const uint logits_offset = n_experts * row;
+ const uint weights_offset = n_expert_used * row;
+ const uint ids_offset = n_experts * row;
+
+ float wt[experts_per_thread];
[[unroll]]
- for (int i = 0; i < experts_per_thread; i++) {
- wt[i] = wt[i] * inv_sum;
+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+ const uint expert = i + gl_LocalInvocationID.x;
+ wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+ }
+
+ if (!late_softmax) {
+ softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
}
// at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@ void main() {
float output_weights[experts_per_thread];
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ output_weights[i] = 0.f;
+ }
+
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@ void main() {
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
+ wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
@@ -129,6 +157,10 @@ void main() {
}
}
+ if (late_softmax) {
+ softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+ }
+
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,85 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Thu, 30 Oct 2025 01:27:41 -0500
Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++++
ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | 16 ++++++++++++----
2 files changed, 16 insertions(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index aaf4334b5..3604ceb04 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {
struct vk_op_argsort_push_constants {
uint32_t ncols;
+ uint32_t nrows;
int32_t order;
};
@@ -8710,6 +8711,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_IM2COL:
{
@@ -9952,9 +9954,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
int32_t * op_params = (int32_t *)dst->op_params;
uint32_t ncols = src0->ne[0];
+ uint32_t nrows = ggml_nrows(src0);
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
+ nrows,
op_params[0],
}, dryrun);
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
index c81b84452..c4e68bc02 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
+ uint nrows;
uint order;
} p;
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
dst_row[idx1] = tmp;
}
-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
- const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
void main() {
if (p.ncols == BLOCK_SIZE) {
- argsort(false);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(false, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
} else {
- argsort(true);
+ uint row = gl_WorkGroupID.y;
+ while (row < p.nrows) {
+ argsort(true, row);
+ row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
}
}

View File

@@ -0,0 +1,77 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam <picard12@live.de>
Date: Fri, 31 Oct 2025 08:14:49 +0100
Subject: [PATCH] vulkan: fix shmem overrun in mmq id shader (#16873)
* vulkan: fix shmem overrun in mmq id shader
* metal : fix mul_mm_id
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---
ggml/src/ggml-metal/ggml-metal-device.cpp | 2 +-
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp | 4 ++++
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl | 2 +-
tests/test-backend-ops.cpp | 3 +++
4 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 758116342..c78082ac3 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
char name[256];
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
- snprintf(name, 256, "%s", base);
+ snprintf(name, 256, "%s_ne02=%d", base, ne02);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 8b238ac4b..d955b4fc7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;
#include "mul_mmq_shmem_types.glsl"
+#ifdef MUL_MAT_ID
+#define BK_STEP 1
+#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
+#endif
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
index 72fec4404..1c0f5306f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -27,7 +27,7 @@ struct block_a_cache {
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
-#define BK_STEP 1
+// #define BK_STEP 1
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 657b6cc2f..1f8dda383 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6722,6 +6722,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
+ // gpt-oss issue with Vulkan mmq_id
+ test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
+
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {4, 8}) {

View File

@@ -0,0 +1,80 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Masato Nakasaka <masato.nakasaka@intel.com>
Date: Fri, 31 Oct 2025 16:18:59 +0900
Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not
supported (#16796)
* Experimenting crash fix
* added assert for aborting and fixed comment
* changed to check if a pipeline is empty or not
* Moved function in class definition
* replaced with is_empty
* Modified is_empty to check only unaligned pipelines
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++-------
1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3604ceb04..80185d9f0 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
struct vk_matmul_pipeline_struct {
vk_pipeline l, m, s;
vk_pipeline a_l, a_m, a_s;
+ // Returns true when all unaligned pipelines are null.
+ // We only check for unaligned variants since one of the unaligned pipelines must exist
+ // while aligned pipelines are optional
+ bool is_empty() const {
+ return l == nullptr && m == nullptr && s == nullptr;
+ }
};
-
typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
struct vk_matmul_pipeline2 {
@@ -5080,7 +5085,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ if (pipelines->is_empty()) {
return nullptr;
}
@@ -5229,7 +5234,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
- if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ if (pipelines->is_empty()) {
return nullptr;
}
@@ -5264,16 +5269,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
return nullptr;
}
+ vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
// XXX TODO 'prec' is not actually allowed in mul_mat_id.
bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
- bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
- bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
+ bool support_fp16acc = !mmp.f16acc->is_empty();
+ bool support_fp32acc = !mmp.f32acc->is_empty();
if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
+ return mmp.f16acc;
} else {
GGML_ASSERT(support_fp32acc);
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
+ return mmp.f32acc;
}
}

View File

@@ -1,516 +0,0 @@
package llm
import (
"fmt"
"log/slog"
"os"
"slices"
"sort"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
)
// pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits
// The list of GPUs returned will always be the same brand (library)
// If the model can not be fit fully within the available GPU(s) nil is returned
func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, gpus []ml.DeviceInfo, numParallel int) []ml.DeviceInfo {
for _, gl := range ml.ByLibrary(gpus) {
sgl := append(make([]ml.DeviceInfo, 0, len(gl)), gl...)
// TODO - potentially sort by performance capability, existing models loaded, etc.
// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
// Note: at present, this will favor most current available VRAM descending and ignoring faster GPU speed in mixed setups
sort.Sort(sort.Reverse(ml.ByFreeMemory(sgl)))
if !envconfig.SchedSpread() {
// Try to pack into as few GPUs as possible, starting from 1 GPU
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
gpuSubset := sgl[:numGPUs]
ok, estimatedVRAM := predictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
if ok {
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
"model", modelPath,
"library", sgl[0].Library,
"parallel", numParallel,
"required", format.HumanBytes2(estimatedVRAM),
"gpus", numGPUs)
return gpuSubset
}
}
} else {
// TODO future refinements
// - if multiple Libraries, see if any single GPU in any Library will fit
// - try subsets of GPUs instead of just falling back to 1 or all in a family
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
if ok, estimatedVRAM := predictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
slog.Info("new model will fit in available VRAM, loading",
"model", modelPath,
"library", sgl[0].Library,
"parallel", numParallel,
"required", format.HumanBytes2(estimatedVRAM),
"gpus", len(sgl))
return sgl
}
}
}
return nil
}
// If multiple Libraries are detected, pick the Library which loads the most layers for the model
func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []string, opts api.Options, gpus []ml.DeviceInfo, numParallel int) []ml.DeviceInfo {
byLibrary := ml.ByLibrary(gpus)
if len(byLibrary) <= 1 {
return gpus
}
var bestEstimate uint64
var bestFit int
for i, gl := range byLibrary {
_, estimatedVRAM := predictServerFit(gl, f, adapters, projectors, opts, numParallel)
if estimatedVRAM > bestEstimate {
bestEstimate = estimatedVRAM
bestFit = i
}
}
return byLibrary[bestFit]
}
// This algorithm looks for a complete fit to determine if we need to unload other models
func predictServerFit(allGpus []ml.DeviceInfo, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
// Split up the GPUs by type and try them
var estimatedVRAM uint64
for _, gpus := range ml.ByLibrary(allGpus) {
var layerCount int
estimate := estimateGPULayers(gpus, f, projectors, opts, numParallel)
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
if opts.NumGPU < 0 {
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
return true, estimatedVRAM
}
} else {
if layerCount > 0 && layerCount >= opts.NumGPU {
return true, estimatedVRAM
}
}
}
return false, estimatedVRAM
}
func verifyCPUFit(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, systemInfo ml.SystemInfo, numParallel int) bool {
estimate := estimateGPULayers(nil, f, projectors, opts, numParallel)
if estimate.TotalSize > systemInfo.FreeMemory {
return false
}
slog.Info("new model will fit in available system memory for CPU inference, loading",
"model", modelPath,
"parallel", numParallel,
"required", format.HumanBytes2(estimate.TotalSize),
)
return true
}
type MemoryEstimate struct {
// How many layers we predict we can load
Layers int
// The size of the graph which occupies the main GPU
Graph uint64
// How much VRAM will be allocated given the number of layers we predict
VRAMSize uint64
// The total size of the model if loaded into VRAM. If all layers are loaded, VRAMSize == TotalSize
TotalSize uint64
// For multi-GPU scenarios, this provides the tensor split parameter
TensorSplit []int
// For multi-GPU scenarios, this is the size in bytes per GPU
GPUSizes []uint64
// internal fields for logging purposes
inferenceLibrary string
layersRequested int
layersModel int
availableList []string
kv uint64
allocationsList []string
memoryWeights uint64
memoryLayerOutput uint64
graphFullOffload uint64
graphPartialOffload uint64
projectorWeights, projectorGraph uint64
}
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
// The GPUs provided must all be the same Library
func estimateGPULayers(gpus []ml.DeviceInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
// Graph size for a partial offload, applies to all GPUs
var graphPartialOffload uint64
// Graph size when all layers are offloaded, applies to all GPUs
var graphFullOffload uint64
// Final graph offload once we know full or partial
var graphOffload uint64
// Projectors loaded into GPU0 only
var llamaEngineProjectorWeights uint64
// Projectors loaded with output layer
var ollamaEngineProjectorWeights uint64
var ollamaEngineProjectorGraph uint64
// Conditional output size on GPU 0
var memoryLayerOutput uint64
// The sizes of a layer
var layerSize uint64
// The sum of all the layer sizes (just for logging)
var memoryWeights uint64
// True if all the layers are loaded
var fullyLoaded bool
// Overflow that didn't fit into the GPU
var overflow uint64
overhead := envconfig.GpuOverhead()
availableList := make([]string, len(gpus))
libraries := []string{}
for i, gpu := range gpus {
availableList[i] = format.HumanBytes2(gpu.FreeMemory)
if !slices.Contains(libraries, gpu.Library) {
libraries = append(libraries, gpu.Library)
}
}
if len(libraries) == 0 {
libraries = []string{"cpu"}
}
slog.Debug("evaluating", "library", strings.Join(libraries, ","), "gpu_count", len(gpus), "available", availableList)
for _, projector := range projectors {
llamaEngineProjectorWeights += projectorMemoryRequirements(projector)
}
if llamaEngineProjectorWeights == 0 {
ollamaEngineProjectorWeights, ollamaEngineProjectorGraph = f.VisionGraphSize()
}
layers := f.Tensors().GroupLayers()
// add one layer worth of memory as a buffer
if blk0, ok := layers["blk.0"]; ok {
layerSize = blk0.Size()
} else {
slog.Warn("model missing blk.0 layer size")
}
useFlashAttention := envconfig.FlashAttention(f.FlashAttention()) &&
ml.FlashAttentionSupported(gpus) &&
f.SupportsFlashAttention()
var kvct string
if useFlashAttention {
requested := strings.ToLower(envconfig.KvCacheType())
if f.SupportsKVCacheType(requested) {
kvct = requested
}
}
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
if len(kv) > 0 {
layerSize += kv[0]
}
var kvTotal uint64
for _, kvLayer := range kv {
kvTotal += kvLayer
}
if graphPartialOffload == 0 {
headsKV := f.KV().HeadCountKVMin()
if headsKV == 0 {
headsKV = 1
}
gqa := f.KV().HeadCountMax() / headsKV
graphPartialOffload = gqa * kvTotal / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
// on metal there's no partial offload overhead
if len(gpus) > 0 && gpus[0].Library == "Metal" {
graphPartialOffload = graphFullOffload
} else if len(gpus) > 1 {
// multigpu should always use the partial graph size
graphFullOffload = graphPartialOffload
}
// Output layer handled at the end if we have space
if layer, ok := layers["output_norm"]; ok {
memoryLayerOutput += layer.Size()
}
if layer, ok := layers["output"]; ok {
memoryLayerOutput += layer.Size()
} else if layer, ok := layers["token_embd"]; ok {
memoryLayerOutput += layer.Size()
}
gpuZeroOverhead := llamaEngineProjectorWeights
// Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
var layerCount int
tensorSplit := make([]int, len(gpus))
gpuAllocations := make([]uint64, len(gpus))
type gs struct {
i int
g *ml.DeviceInfo
}
gpusWithSpace := []gs{}
for i := range gpus {
var gzo uint64
if len(gpusWithSpace) == 0 {
gzo = gpuZeroOverhead
}
// Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least more layer
if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory()+2*layerSize {
slog.Debug("gpu has too little memory to allocate any layers",
"id", gpus[i].ID,
"library", gpus[i].Library,
"compute", gpus[i].Compute(),
"driver", fmt.Sprintf("%d.%d", gpus[i].DriverMajor, gpus[i].DriverMinor),
"name", gpus[i].Name,
"total", format.HumanBytes2(gpus[i].TotalMemory),
"available", format.HumanBytes2(gpus[i].FreeMemory),
"minimum_memory", gpus[i].MinimumMemory,
"layer_size", format.HumanBytes2(layerSize),
"gpu_zer_overhead", format.HumanBytes2(gzo),
"partial_offload", format.HumanBytes2(graphPartialOffload),
"full_offload", format.HumanBytes2(graphFullOffload),
)
continue
}
gpusWithSpace = append(gpusWithSpace, gs{i, &gpus[i]})
gpuAllocations[i] += gpus[i].MinimumMemory() + layerSize // We hold off on graph until we know partial vs. full
}
var gpuZeroID int
if len(gpusWithSpace) > 0 {
gpuZeroID = gpusWithSpace[0].i
gpuAllocations[gpuZeroID] += gpuZeroOverhead
} else {
overflow += gpuZeroOverhead
}
// For all the layers, find where they can fit on the GPU(s)
for i := int(f.KV().BlockCount()) - 1; i >= 0; i-- {
// Some models have inconsistent layer sizes
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
layerSize = blk.Size()
layerSize += kv[i]
memoryWeights += blk.Size()
}
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
// Stop allocating on GPU(s) once we hit the users target NumGPU
overflow += layerSize
continue
}
// distribute the layers across the GPU(s) that have space
for j := len(gpusWithSpace); j > 0; j-- {
g := gpusWithSpace[i%j]
used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
if g.g.FreeMemory > overhead+used+layerSize {
gpuAllocations[g.i] += layerSize
tensorSplit[g.i]++
layerCount++
break
} else {
gpusWithSpace = append(gpusWithSpace[:i%j], gpusWithSpace[i%j+1:]...)
}
}
if len(gpusWithSpace) == 0 {
overflow += layerSize
}
}
if layerCount >= int(f.KV().BlockCount()) {
fullyLoaded = true
}
// Determine if we need to consider output then find where it fits
memoryLastLayer := memoryLayerOutput + ollamaEngineProjectorWeights + ollamaEngineProjectorGraph
if memoryLastLayer > 0 {
if opts.NumGPU < 0 || layerCount < opts.NumGPU {
for j := len(gpusWithSpace); j > 0; j-- {
g := gpusWithSpace[layerCount%j]
used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
if g.g.FreeMemory > overhead+used+memoryLastLayer {
gpuAllocations[g.i] += memoryLastLayer
tensorSplit[g.i]++
layerCount++
break
}
}
}
if layerCount < int(f.KV().BlockCount())+1 {
fullyLoaded = false
overflow += memoryLastLayer
}
}
// Add the applicable (full or partial) graph allocations
for i := range gpus {
if tensorSplit[i] <= 0 {
continue
}
if fullyLoaded {
gpuAllocations[i] += graphFullOffload
} else {
gpuAllocations[i] += graphPartialOffload
}
}
if fullyLoaded {
graphOffload = graphFullOffload
} else {
graphOffload = graphPartialOffload
}
// Summaries for the log
var memoryRequiredPartial, memoryRequiredTotal uint64
for i := range gpuAllocations {
memoryRequiredPartial += gpuAllocations[i]
}
memoryRequiredTotal = memoryRequiredPartial + overflow
allocationsList := []string{}
for _, a := range gpuAllocations {
allocationsList = append(allocationsList, format.HumanBytes2(a))
}
estimate := MemoryEstimate{
TotalSize: memoryRequiredTotal,
Layers: 0,
Graph: 0,
VRAMSize: 0,
GPUSizes: []uint64{},
inferenceLibrary: strings.Join(libraries, ","),
layersRequested: opts.NumGPU,
layersModel: int(f.KV().BlockCount()) + 1,
availableList: availableList,
kv: kvTotal,
allocationsList: allocationsList,
memoryWeights: memoryWeights,
memoryLayerOutput: memoryLayerOutput,
graphFullOffload: graphFullOffload,
graphPartialOffload: graphPartialOffload,
projectorWeights: llamaEngineProjectorWeights + ollamaEngineProjectorWeights,
projectorGraph: ollamaEngineProjectorGraph,
}
if len(gpus) == 0 {
return estimate
}
if layerCount == 0 {
slog.Debug("insufficient VRAM to load any model layers")
return estimate
}
estimate.Layers = layerCount
estimate.Graph = graphOffload
estimate.VRAMSize = memoryRequiredPartial
estimate.TotalSize = memoryRequiredTotal
estimate.TensorSplit = tensorSplit
estimate.GPUSizes = gpuAllocations
return estimate
}
func (m MemoryEstimate) LogValue() slog.Value {
attrs := []slog.Attr{
slog.String("library", m.inferenceLibrary),
slog.Group(
"layers",
// requested number of layers to offload
"requested", m.layersRequested,
// The number of layers the model has (including output)
"model", m.layersModel,
// estimated number of layers that can be offloaded
"offload", m.Layers,
// multi-gpu split for tensors
"split", m.TensorSplit,
),
slog.Group(
"memory",
// memory available by GPU for offloading
"available", m.availableList,
"gpu_overhead", format.HumanBytes2(envconfig.GpuOverhead()),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(m.TotalSize),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(m.VRAMSize),
// memory of KV cache
"kv", format.HumanBytes2(m.kv),
// Allocations across the GPUs
"allocations", m.allocationsList,
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
// memory of repeating layers
"repeating", format.HumanBytes2(m.memoryWeights),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(m.graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(m.graphPartialOffload),
),
),
}
if m.projectorWeights > 0 {
attrs = append(attrs, slog.Group(
"projector",
"weights", format.HumanBytes2(m.projectorWeights),
"graph", format.HumanBytes2(m.projectorGraph),
))
}
return slog.GroupValue(attrs...)
}
func projectorMemoryRequirements(filename string) (weights uint64) {
file, err := os.Open(filename)
if err != nil {
return 0
}
defer file.Close()
ggml, err := ggml.Decode(file, 1024)
if err != nil {
return 0
}
for _, layer := range ggml.Tensors().GroupLayers() {
weights += layer.Size()
}
return weights
}

View File

@@ -1,130 +0,0 @@
package llm
import (
"bytes"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
)
func TestEstimateGPULayers(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", "1")
t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
t.Setenv("OLLAMA_CONTEXT_LENGTH", "2048")
modelName := "dummy"
f, err := os.CreateTemp(t.TempDir(), modelName)
require.NoError(t, err)
defer f.Close()
inputLayerCount := 5
tensors := []*ggml.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
}
assert.Len(t, tensors, inputLayerCount+1)
err = ggml.WriteGGUF(f, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(inputLayerCount),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, tensors)
require.NoError(t, err)
ggml, err := LoadModel(f.Name(), 0)
if err != nil {
t.Fatal(err)
}
// Simple CPU scenario
gpus := []ml.DeviceInfo{}
projectors := []string{}
opts := api.DefaultOptions()
t.Run("cpu", func(t *testing.T) {
estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1)
assert.Equal(t, 0, estimate.Layers)
assert.Equal(t, uint64(0), estimate.Graph)
})
// derived from the dummy ggml file above
graphPartialOffload := uint64(202377216)
graphFullOffload := uint64(171968512)
layerSize := uint64(33554436)
projectorSize := uint64(0)
memoryLayerOutput := uint64(4)
// Dual CUDA scenario with asymmetry
gpuMinimumMemory := uint64(457 * format.MebiByte)
gpus = []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{
Library: "CUDA",
},
},
{
DeviceID: ml.DeviceID{
Library: "CUDA",
},
},
}
// Nested array: GPU0 layer space, GPU1 layer space, expected gpu0, expected gpu1
for i, s := range []struct {
layer0, layer1 uint64
expect0, expect1 int
}{
{1, 1, 1, 1},
{2, 1, 2, 1},
{2, 2, 2, 2},
{1, 2, 1, 2},
{3, 3, 3, 3},
{4, 4, 3, 3},
{6, 6, 3, 3},
{0, 3, 0, 3},
} {
t.Run(fmt.Sprintf("%v", s), func(t *testing.T) {
gpus[0].FreeMemory = 0
gpus[1].FreeMemory = 0
gpus[0].FreeMemory += projectorSize
if s.layer0 > 0 {
gpus[0].FreeMemory += memoryLayerOutput
} else {
gpus[1].FreeMemory += memoryLayerOutput
}
gpus[0].FreeMemory += gpuMinimumMemory + layerSize + s.layer0*layerSize + 1
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1)
assert.Equal(t, s.expect0+s.expect1, estimate.Layers, "scenario %d: %v", i, s)
assert.Equal(t, []int{s.expect0, s.expect1}, estimate.TensorSplit, "scenario %d: %v", i, s)
var layerSums uint64
for _, b := range estimate.GPUSizes {
layerSums += b
}
if estimate.Layers < inputLayerCount+1 {
assert.Less(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
assert.Equal(t, estimate.VRAMSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
} else {
assert.Equal(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
assert.Equal(t, estimate.TotalSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
}
})
}
}

View File

@@ -84,25 +84,21 @@ type LlamaServer interface {
// llmServer is an instance of a runner hosting a single model
type llmServer struct {
port int
cmd *exec.Cmd
done chan error // Channel to signal when the process exits
status *StatusWriter
options api.Options
numParallel int
modelPath string
port int
cmd *exec.Cmd
done chan error // Channel to signal when the process exits
status *StatusWriter
options api.Options
modelPath string
loadRequest LoadRequest // Parameters used to initialize the runner
loadRequest LoadRequest // Parameters used to initialize the runner
mem *ml.BackendMemory // Memory allocations for this model
// llamaModel is an instance of the cgo llama.cpp model definition
// nil if this server is running the new engine
llamaModel *llama.Model
llamaModelLock *sync.Mutex
// textProcessor handles text encoding/decoding for the model in the Ollama engine
// nil if this server is running the llama.cpp based engine
textProcessor model.TextProcessor
totalLayers uint64
loadStart time.Time // Record how long it took the model to load
loadProgress float32
@@ -113,15 +109,13 @@ type llmServer struct {
type llamaServer struct {
llmServer
ggml *ggml.GGML
gpus []ml.DeviceInfo // The set of GPUs covered by the memory estimate
estimate MemoryEstimate
ggml *ggml.GGML
}
type ollamaServer struct {
llmServer
mem *ml.BackendMemory
textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -245,8 +239,6 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
loadRequest: loadRequest,
llamaModel: llamaModel,
llamaModelLock: &sync.Mutex{},
textProcessor: textProcessor,
numParallel: numParallel,
sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: f.KV().BlockCount() + 1,
loadStart: time.Now(),
@@ -281,7 +273,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}()
if textProcessor != nil {
return &ollamaServer{llmServer: s}, nil
return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -463,169 +455,226 @@ type LoadResponse struct {
var ErrLoadRequiredFull = errors.New("unable to load full model on GPU")
func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
systemTotalMemory := systemInfo.TotalMemory
systemFreeMemory := systemInfo.FreeMemory
systemSwapFreeMemory := systemInfo.FreeSwap
slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)
if len(gpus) == 0 || s.options.NumGPU == 0 {
if !verifyCPUFit(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, systemInfo, s.numParallel) {
slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
return nil, fmt.Errorf("model requires more system memory than is currently available %w", ErrLoadRequiredFull)
gpus := append(make([]ml.DeviceInfo, 0, len(systemGPUs)), systemGPUs...)
// Synthesize memory allocation information based on our estimates
s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{
Name: "CPU",
Weights: make([]uint64, s.totalLayers),
Cache: make([]uint64, s.totalLayers),
}, GPUs: make([]ml.DeviceMemory, len(gpus))}
for i := range s.mem.GPUs {
s.mem.GPUs[i].Name = gpus[i].Name
s.mem.GPUs[i].DeviceID = gpus[i].DeviceID
s.mem.GPUs[i].Weights = make([]uint64, s.totalLayers)
s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
}
kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)
// Use the size of one layer as a buffer
layers := s.ggml.Tensors().GroupLayers()
if blk0, ok := layers["blk.0"]; ok {
for i := range gpus {
gpus[i].FreeMemory -= blk0.Size() + kv[0]
}
} else {
g := pickBestFullFitByLibrary(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
if g == nil {
if !requireFull {
g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
} else {
slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
return nil, ErrLoadRequiredFull
slog.Warn("model missing blk.0 layer size")
}
// Assign all the layers to the CPU for now, they will get reassigned later
for i := range s.ggml.KV().BlockCount() {
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
s.mem.CPU.Weights[i] = blk.Size()
s.mem.CPU.Cache[i] += kv[i]
}
}
// We historically haven't included InputWeights in the model size
var outputWeights uint64
if layer, ok := layers["output_norm"]; ok {
outputWeights += layer.Size()
}
if layer, ok := layers["output"]; ok {
outputWeights += layer.Size()
} else if layer, ok := layers["token_embd"]; ok {
outputWeights += layer.Size()
}
s.mem.CPU.Weights[s.totalLayers-1] = outputWeights
// The vision projector is always loaded on the first GPU if available.
// This can't be assigned by us, so just subtract it from free space
projectorGPU := -1
var projectorWeights uint64
if len(gpus) > 0 {
for _, projector := range s.loadRequest.LoraPath {
projectorWeights += projectorMemoryRequirements(projector)
}
// llama.cpp uses the first discrete GPU if available, otherwise the first iGPU
firstIntegrated := -1
for i := range gpus {
if !gpus[i].Integrated {
projectorGPU = i
break
}
if firstIntegrated == -1 {
firstIntegrated = i
}
}
gpus = g
}
s.estimate = estimateGPULayers(gpus, s.ggml, []string{s.loadRequest.ProjectorPath}, s.options, s.numParallel)
if len(gpus) >= 1 {
switch {
case s.options.NumGPU == 0:
gpus = []ml.DeviceInfo{}
case gpus[0].Library == "Metal" && s.estimate.VRAMSize > systemInfo.TotalMemory:
// disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system
s.options.NumGPU = 0
gpus = []ml.DeviceInfo{}
case gpus[0].Library != "Metal" && s.estimate.Layers == 0:
// Don't bother loading into the GPU if no layers can fit
gpus = []ml.DeviceInfo{}
case s.options.NumGPU < 0 && s.estimate.Layers > 0:
s.options.NumGPU = s.estimate.Layers
if projectorGPU == -1 {
projectorGPU = firstIntegrated
}
} else {
s.options.NumGPU = 0
gpus[projectorGPU].FreeMemory -= projectorWeights
}
// On linux and windows, over-allocating CPU memory will almost always result in an error
// Darwin has fully dynamic swap so has no direct concept of free swap space
if runtime.GOOS != "darwin" {
systemMemoryRequired := s.estimate.TotalSize - s.estimate.VRAMSize
available := systemInfo.FreeMemory + systemInfo.FreeSwap
if systemMemoryRequired > available {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap))
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
var kvTotal uint64
for _, kvLayer := range kv {
kvTotal += kvLayer
}
if graphPartialOffload == 0 {
headsKV := s.ggml.KV().HeadCountKVMin()
if headsKV == 0 {
headsKV = 1
}
gqa := s.ggml.KV().HeadCountMax() / headsKV
graphPartialOffload = gqa * kvTotal / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
// On Metal there's no partial offload overhead
if len(gpus) > 0 && gpus[0].Library == "Metal" {
graphPartialOffload = graphFullOffload
}
// Create a layout based on the memory data that we've built. The compute graph
// for GPUs is iteratively assigned based on the number of GPUs that are required.
var gpuLayers ml.GPULayersList
for {
prevGPULayers := gpuLayers
var err error
gpuLayers, err = s.createLayout(systemInfo, gpus, s.mem, requireFull, 0)
if err != nil {
return nil, err
}
if len(gpuLayers) > len(prevGPULayers) {
for _, gl := range gpuLayers {
for i := range s.mem.GPUs {
if gl.DeviceID == s.mem.GPUs[i].DeviceID {
s.mem.GPUs[i].Graph = max(graphPartialOffload, graphFullOffload)
break
}
}
}
} else {
break
}
}
slog.Info("offload", "", s.estimate)
// This maintains the historical assignment of graph sizes, though it isn't fully accurate
graphSize := graphFullOffload
if gpuLayers.Sum() < int(s.totalLayers) {
graphSize = graphPartialOffload
}
s.gpus = gpus
s.loadRequest.GPULayers = createGPULayers(s.estimate, s.ggml, gpus, s.options.NumGPU)
// For all layers that we have assigned to GPUs, move them in the memory data so
// that it is reported accurately
for _, gl := range gpuLayers {
for i := range s.mem.GPUs {
if gl.DeviceID == s.mem.GPUs[i].DeviceID {
for _, l := range gl.Layers {
s.mem.GPUs[i].Weights[l] = s.mem.CPU.Weights[l]
s.mem.GPUs[i].Cache[l] = s.mem.CPU.Cache[l]
// Mmap is only supported on the llama engine
if s.textProcessor == nil {
s.loadRequest.UseMmap = true
s.mem.CPU.Weights[l] = 0
s.mem.CPU.Cache[l] = 0
}
// mmap has issues with partial offloading on metal
for _, g := range gpus {
if g.Library == "Metal" &&
uint64(s.options.NumGPU) > 0 &&
uint64(s.options.NumGPU) < s.ggml.KV().BlockCount()+1 {
s.options.UseMMap = new(bool)
*s.options.UseMMap = false
s.mem.GPUs[i].Graph = graphSize
break
}
}
}
// Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing
// For CPU loads we want the memory to be allocated, not FS cache
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) ||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
s.loadRequest.UseMmap = false
if projectorGPU > 0 && len(s.mem.GPUs[projectorGPU].Weights) > 0 {
s.mem.GPUs[projectorGPU].Weights[s.totalLayers-1] += projectorWeights
}
slog.Debug("memory", "estimate", s.mem)
s.mem.Log(slog.LevelInfo)
// The llama engine uses mmap by default
s.loadRequest.UseMmap = true
// mmap has issues with partial offloading on metal
for _, g := range gpus {
if g.Library == "Metal" &&
uint64(s.options.NumGPU) > 0 &&
uint64(s.options.NumGPU) < s.totalLayers {
s.options.UseMMap = new(bool)
*s.options.UseMMap = false
}
}
// Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing
// For CPU loads we want the memory to be allocated, not FS cache
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
s.loadRequest.UseMmap = false
}
if err := s.waitUntilRunnerLaunched(ctx); err != nil {
return nil, err
}
s.loadRequest.GPULayers = gpuLayers
resp, err := s.initModel(ctx, s.loadRequest, LoadOperationCommit)
if err != nil {
return nil, err
}
// On the Ollama engine, we can print out a summary of the memory allocations.
// We don't have this for the llama engine but it does something similar itself.
if s.textProcessor != nil {
resp.Memory.Log(slog.LevelInfo)
}
if !resp.Success {
slog.Warn("failed to allocate memory for model", "memory", resp.Memory)
return nil, errors.New("failed to allocate memory for model")
}
// The llama engine does its memory allocations together with model loading, so we
// need to wait until it is done to ensure that we have accurate memory data before
// loading the next model
if s.textProcessor == nil {
return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx)
} else {
return uniqueDeviceIDs(s.loadRequest.GPULayers), nil
}
return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx)
}
// createGPULayers maps from the tensor splits assigned by the memory estimates to explicit assignment
// of particular layers onto GPUs
func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus []ml.DeviceInfo, numGPU int) ml.GPULayersList {
if numGPU <= 0 || len(gpus) == 0 {
return nil
func projectorMemoryRequirements(filename string) (weights uint64) {
file, err := os.Open(filename)
if err != nil {
return 0
}
defer file.Close()
ggml, err := ggml.Decode(file, 1024)
if err != nil {
return 0
}
gpuLayers := make(ml.GPULayersList, len(gpus))
for i := range gpuLayers {
gpuLayers[i].DeviceID = gpus[i].DeviceID
for _, layer := range ggml.Tensors().GroupLayers() {
weights += layer.Size()
}
var sum float32
splits := make([]float32, len(estimate.TensorSplit))
// cumulative sum of all splits
for i := range splits {
sum += float32(estimate.TensorSplit[i])
splits[i] = sum
}
if sum <= 0 {
return nil
}
// normalize splits
for i := range splits {
splits[i] /= sum
}
blocks := int(ggml.KV().BlockCount())
gpuRangeStart := max(0, blocks-numGPU)
gpuRangeStop := min(gpuRangeStart+numGPU, blocks+1)
for i := range blocks + 1 {
if i < gpuRangeStart || i >= gpuRangeStop {
continue
}
index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
if index < 0 || index >= len(gpus) {
continue
}
gpuLayers[index].Layers = append(gpuLayers[index].Layers, i)
}
return gpuLayers
return weights
}
// Load finds the optimal layout of layers to offload on GPUs based on no initial information about the size of the model
@@ -652,23 +701,6 @@ func (s *ollamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus
slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)
systemTotalMemory := systemInfo.TotalMemory
systemFreeMemory := systemInfo.FreeMemory
systemSwapFreeMemory := systemInfo.FreeSwap
slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
for _, gpu := range gpus {
available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory()
if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory() {
available = 0
}
slog.Info("gpu memory", "id", gpu.ID, "library", gpu.Library,
"available", format.HumanBytes2(available),
"free", format.HumanBytes2(gpu.FreeMemory),
"minimum", format.HumanBytes2(gpu.MinimumMemory()),
"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
}
pastAllocations := make(map[uint64]struct{})
var backoff float32
@@ -834,25 +866,22 @@ func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID {
// - Calculating how much space each GPU has available for layers, based on free memory and space occupied by the graph
// - Assigning layers
// - Ensuring that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *ollamaServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) {
func (s *llmServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) {
if memory == nil {
memory = &ml.BackendMemory{CPU: ml.DeviceMemory{
Weights: make([]uint64, s.totalLayers),
Cache: make([]uint64, s.totalLayers),
}}
}
gpuLayers, layers, err := s.buildLayout(systemGPUs, memory, requireFull, backoff)
if err != nil {
return nil, err
}
err = s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
gpuLayers, layers := s.buildLayout(systemGPUs, memory, requireFull, backoff)
err := s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
if err != nil {
return nil, err
}
return gpuLayers, nil
}
func (s *ollamaServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, []uint64, error) {
func (s *llmServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, []uint64) {
gpus := append(make([]ml.DeviceInfo, 0, len(systemGPUs)), systemGPUs...)
sort.Sort(sort.Reverse(ml.ByFreeMemory(gpus)))
@@ -910,11 +939,11 @@ func (s *ollamaServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.Backen
gpuLayers = libraryGpuLayers
}
}
return gpuLayers, layers, nil
return gpuLayers, layers
}
// verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *ollamaServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
// These sizes will only increase as we go through additional iterations and get additional information.
cpuSize := memory.InputWeights + memory.CPU.Graph
var vramSize uint64
@@ -942,11 +971,13 @@ nextLayer:
if requireFull {
if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
slog.Info("model requires more memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
return ErrLoadRequiredFull
}
if cpuSize > systemInfo.FreeMemory {
return ErrLoadRequiredFull
slog.Info("model requires more system memory than is currently available, evicting a model to make space", "required", cpuSize, "free", systemInfo.FreeMemory)
return fmt.Errorf("model requires more system memory than is currently available %w", ErrLoadRequiredFull)
}
}
@@ -976,6 +1007,13 @@ nextLayer:
// assignLayers packs the maximum number of layers onto the smallest set of GPUs and comes up with a layer assignment
func assignLayers(layers []uint64, gpus []ml.DeviceInfo, requireFull bool, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) {
// If the user is manually overriding parameters, treat all GPUs equally so they split according to VRAM
if requestedLayers >= 0 || envconfig.SchedSpread() {
for i := range gpus {
gpus[i].Integrated = false
}
}
// If we can't fit everything then prefer offloading layers other than the output layer
for range 2 {
// requestedLayers may be -1 if nothing was requested
@@ -1008,33 +1046,38 @@ func assignLayers(layers []uint64, gpus []ml.DeviceInfo, requireFull bool, reque
// findBestFit binary searches to find the smallest capacity factor that can fit
// the max number of layers. The capacity factor is multiplied by the free space on
// each GPU and a small one will force even balancing.
// each GPU and a small one will force even balancing. Higher performance GPUs are
// used first.
func findBestFit(layers []uint64, gpus []ml.DeviceInfo, requestedLayers int, forceRequest bool) (gpuLayers ml.GPULayersList) {
var high float32 = 1
var low float32 = 0
for _, gl := range ml.ByPerformance(gpus) {
var high float32 = 1
var low float32 = 0
// If we need to fulfill the requested number of layers, pretend we have almost infinite VRAM
if requestedLayers >= 0 && forceRequest {
high = 1000
}
bestAssignments := greedyFit(layers, gpus, high, requestedLayers)
maxNumGPU := bestAssignments.Sum()
if maxNumGPU == 0 {
return bestAssignments
}
for high-low > 1e-6 {
mid := (low + high) / 2
assignments := greedyFit(layers, gpus, mid, requestedLayers)
if assignments.Sum() == maxNumGPU {
high = mid
bestAssignments = assignments
} else {
low = mid
// If we need to fulfill the requested number of layers, pretend we have almost infinite VRAM
if requestedLayers >= 0 && forceRequest {
high = 1000
}
bestAssignments := greedyFit(layers, gl, high, requestedLayers)
maxNumGPU := bestAssignments.Sum()
for high-low > 1e-6 {
mid := (low + high) / 2
assignments := greedyFit(layers, gl, mid, requestedLayers)
if assignments.Sum() == maxNumGPU {
high = mid
bestAssignments = assignments
} else {
low = mid
}
}
layers = layers[:len(layers)-bestAssignments.Sum()]
requestedLayers -= bestAssignments.Sum()
gpuLayers = append(bestAssignments, gpuLayers...)
}
return bestAssignments
return gpuLayers
}
// greedyFit assigns layers incrementally to GPUs, spilling over as each runs out of free space
@@ -1362,6 +1405,12 @@ type CompletionRequest struct {
Grammar string // set before sending the request to the subprocess
Shift bool
Truncate bool
// Logprobs specifies whether to include log probabilities in the response
Logprobs bool
// TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
TopLogprobs int
}
// DoneReason represents the reason why a completion response is done
@@ -1387,6 +1436,18 @@ func (d DoneReason) String() string {
}
}
// TokenLogprob represents log probability information for a single token alternative.
type TokenLogprob struct {
Token string `json:"token"`
Logprob float64 `json:"logprob"`
}
// Logprob contains log probability information for a generated token.
type Logprob struct {
TokenLogprob
TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}
type CompletionResponse struct {
Content string `json:"content"`
DoneReason DoneReason `json:"done_reason"`
@@ -1395,6 +1456,9 @@ type CompletionResponse struct {
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
// Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -1530,7 +1594,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
Content: c.Content,
Logprobs: c.Logprobs,
})
}
@@ -1623,68 +1688,59 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
return e.Embedding, nil
}
type TokenizeRequest struct {
Content string `json:"content"`
}
type TokenizeResponse struct {
Tokens []int `json:"tokens"`
}
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
s.llamaModelLock.Lock()
defer s.llamaModelLock.Unlock()
if s.llamaModel != nil {
return s.llamaModel.Tokenize(content, false, true)
if s.llamaModel == nil {
return nil, fmt.Errorf("no tokenizer configured")
}
if s.textProcessor != nil {
tokens, err := s.textProcessor.Encode(content, false)
if err != nil {
return nil, err
}
toks := make([]int, len(tokens))
for i, t := range tokens {
toks[i] = int(t)
}
return toks, nil
return s.llamaModel.Tokenize(content, false, true)
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
tokens, err := s.textProcessor.Encode(content, false)
if err != nil {
return nil, err
}
// not reached
return nil, fmt.Errorf("no tokenizer configured")
toks := make([]int, len(tokens))
for i, t := range tokens {
toks[i] = int(t)
}
return toks, nil
}
type DetokenizeRequest struct {
Tokens []int `json:"tokens"`
}
type DetokenizeResponse struct {
Content string `json:"content"`
}
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
func (s *llamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
s.llamaModelLock.Lock()
defer s.llamaModelLock.Unlock()
if s.llamaModel != nil {
var resp string
for _, token := range tokens {
resp += s.llamaModel.TokenToPiece(token)
}
return resp, nil
if s.llamaModel == nil {
return "", fmt.Errorf("no tokenizer configured")
}
if s.textProcessor != nil {
toks := make([]int32, len(tokens))
for i, t := range tokens {
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
if err != nil {
return "", err
}
return content, nil
var resp string
for _, token := range tokens {
resp += s.llamaModel.TokenToPiece(token)
}
// not reached
return "", fmt.Errorf("no tokenizer configured")
return resp, nil
}
func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
toks := make([]int32, len(tokens))
for i, t := range tokens {
toks[i] = int32(t)
}
content, err := s.textProcessor.Decode(toks)
if err != nil {
return "", err
}
return content, nil
}
func (s *llmServer) Close() error {
@@ -1712,31 +1768,12 @@ func (s *llmServer) Close() error {
return nil
}
func (s *llamaServer) VRAMSize() uint64 {
return s.estimate.VRAMSize
}
func (s *llamaServer) TotalSize() uint64 {
return s.estimate.TotalSize
}
func (s *llamaServer) VRAMByGPU(id ml.DeviceID) uint64 {
for i, gpu := range s.gpus {
if gpu.DeviceID == id {
if i < len(s.estimate.GPUSizes) {
return s.estimate.GPUSizes[i]
}
}
}
return 0
}
func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
slog.Debug("llamarunner free vram reporting not supported")
return nil
}
func (s *ollamaServer) VRAMSize() uint64 {
func (s *llmServer) VRAMSize() uint64 {
if s.mem == nil {
return 0
}
@@ -1764,7 +1801,7 @@ func (s *ollamaServer) VRAMSize() uint64 {
return mem
}
func (s *ollamaServer) TotalSize() uint64 {
func (s *llmServer) TotalSize() uint64 {
if s.mem == nil {
return 0
}
@@ -1778,7 +1815,7 @@ func (s *ollamaServer) TotalSize() uint64 {
return mem
}
func (s *ollamaServer) VRAMByGPU(id ml.DeviceID) uint64 {
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
if s.mem == nil {
return 0
}

View File

@@ -14,16 +14,11 @@ import (
)
func TestLLMServerFitGPU(t *testing.T) {
type gpu struct {
id ml.DeviceID
free int
}
minMemory := 457 * format.MebiByte
tests := []struct {
name string
gpus []gpu
gpus []ml.DeviceInfo
layers []int
numGPU int
requireFull bool
@@ -38,91 +33,91 @@ func TestLLMServerFitGPU(t *testing.T) {
},
{
name: "Full single GPU",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}},
},
{
name: "Partial single GPU",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
},
{
name: "Single GPU with numGPU 1",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: 1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}},
},
{
name: "Single GPU with numGPU 0",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: 0,
expected: ml.GPULayersList{},
},
{
name: "Single GPU with numGPU 999",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: 999,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2, 3}}},
},
{
name: "Multi GPU fits on one",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1, 2}}},
},
{
name: "Multi GPU split",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
},
{
name: "Multi GPU partial",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}},
},
{
name: "Multi GPU numGPU 1",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: 1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}},
},
{
name: "Multi GPU numGPU 2",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: 2,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}},
},
{
name: "Multi GPU numGPU 999",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
numGPU: 999,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
},
{
name: "Multi GPU different libraries",
gpus: []gpu{{id: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{128 * format.MebiByte, 128 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1", Library: "ROCm"}, Layers: []int{0, 1}}},
},
{
name: "requireFull",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: -1,
requireFull: true,
@@ -130,12 +125,54 @@ func TestLLMServerFitGPU(t *testing.T) {
},
{
name: "requireFull numGPU",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256 * format.MebiByte)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: 4,
requireFull: true,
expectedErr: ErrLoadRequiredFull,
},
{
name: "iGPU",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}},
},
{
name: "iGPU + dGPU",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
},
{
name: "iGPU + dGPU fits on one",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1}}},
},
{
name: "iGPU + dGPU partial",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
},
{
name: "iGPU + dGPU numGPU 1",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: 1,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
},
{
name: "iGPU + dGPU numGPU 999",
gpus: []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: 999,
expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1, 2, 3}}},
},
}
for _, tt := range tests {
@@ -145,12 +182,6 @@ func TestLLMServerFitGPU(t *testing.T) {
systemInfo.FreeMemory = 512 * format.MebiByte
systemInfo.FreeSwap = 256 * format.MebiByte
gpus := make([]ml.DeviceInfo, len(tt.gpus))
for i := range tt.gpus {
gpus[i].DeviceID = tt.gpus[i].id
gpus[i].FreeMemory = uint64(tt.gpus[i].free)
}
s := &ollamaServer{
llmServer: llmServer{
totalLayers: uint64(len(tt.layers)),
@@ -165,19 +196,19 @@ func TestLLMServerFitGPU(t *testing.T) {
s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{
Weights: make([]uint64, s.totalLayers),
Cache: make([]uint64, s.totalLayers),
}, GPUs: make([]ml.DeviceMemory, len(gpus))}
}, GPUs: make([]ml.DeviceMemory, len(tt.gpus))}
for i := range tt.layers {
s.mem.CPU.Weights[i] = uint64(tt.layers[i])
}
for i := range s.mem.GPUs {
s.mem.GPUs[i].DeviceID = gpus[i].DeviceID
s.mem.GPUs[i].DeviceID = tt.gpus[i].DeviceID
s.mem.GPUs[i].Weights = make([]uint64, s.totalLayers)
s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
}
gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, tt.requireFull, 0)
gpuLayers, err := s.createLayout(systemInfo, tt.gpus, s.mem, tt.requireFull, 0)
if err != tt.expectedErr {
t.Fatalf("fitGPU returned error: %v", err)
}

View File

@@ -1,16 +0,0 @@
{
"env": {
"browser": true,
"es6": true,
"node": true
},
"extends": [
"eslint:recommended",
"plugin:@typescript-eslint/eslint-recommended",
"plugin:@typescript-eslint/recommended",
"plugin:import/recommended",
"plugin:import/electron",
"plugin:import/typescript"
],
"parser": "@typescript-eslint/parser"
}

92
macapp/.gitignore vendored
View File

@@ -1,92 +0,0 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
.DS_Store
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# TypeScript v1 declaration files
typings/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variables file
.env
.env.test
# parcel-bundler cache (https://parceljs.org/)
.cache
# next.js build output
.next
# nuxt.js build output
.nuxt
# vuepress build output
.vuepress/dist
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# Webpack
.webpack/
# Vite
.vite/
# Electron-Forge
out/

View File

@@ -1,21 +0,0 @@
# Desktop
This app builds upon Ollama to provide a desktop experience for running models.
## Developing
First, build the `ollama` binary:
```shell
cd ..
go build .
```
Then run the desktop app with `npm start`:
```shell
cd macapp
npm install
npm start
```

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 402 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 741 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 440 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 763 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 447 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 891 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 443 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 844 B

View File

@@ -1,79 +0,0 @@
import type { ForgeConfig } from '@electron-forge/shared-types'
import { MakerSquirrel } from '@electron-forge/maker-squirrel'
import { MakerZIP } from '@electron-forge/maker-zip'
import { PublisherGithub } from '@electron-forge/publisher-github'
import { AutoUnpackNativesPlugin } from '@electron-forge/plugin-auto-unpack-natives'
import { WebpackPlugin } from '@electron-forge/plugin-webpack'
import * as path from 'path'
import * as fs from 'fs'
import { mainConfig } from './webpack.main.config'
import { rendererConfig } from './webpack.renderer.config'
const packageJson = JSON.parse(fs.readFileSync(path.resolve(__dirname, './package.json'), 'utf8'))
const config: ForgeConfig = {
packagerConfig: {
appVersion: process.env.VERSION || packageJson.version,
asar: true,
icon: './assets/icon.icns',
extraResource: [
path.join(__dirname, '../dist/darwin/ollama'),
...fs.readdirSync(path.join(__dirname, '../dist/darwin-amd64/lib/ollama')).map(f => path.join(__dirname, '../dist/darwin-amd64/lib/ollama', f)),
path.join(__dirname, './assets/iconTemplate.png'),
path.join(__dirname, './assets/iconTemplate@2x.png'),
path.join(__dirname, './assets/iconUpdateTemplate.png'),
path.join(__dirname, './assets/iconUpdateTemplate@2x.png'),
path.join(__dirname, './assets/iconDarkTemplate.png'),
path.join(__dirname, './assets/iconDarkTemplate@2x.png'),
path.join(__dirname, './assets/iconDarkUpdateTemplate.png'),
path.join(__dirname, './assets/iconDarkUpdateTemplate@2x.png'),
],
...(process.env.SIGN
? {
osxSign: {
identity: process.env.APPLE_IDENTITY,
},
osxNotarize: {
tool: 'notarytool',
appleId: process.env.APPLE_ID || '',
appleIdPassword: process.env.APPLE_PASSWORD || '',
teamId: process.env.APPLE_TEAM_ID || '',
},
}
: {}),
osxUniversal: {
x64ArchFiles: '*',
},
},
rebuildConfig: {},
makers: [new MakerSquirrel({}), new MakerZIP({}, ['darwin'])],
hooks: {
readPackageJson: async (_, packageJson) => {
return { ...packageJson, version: process.env.VERSION || packageJson.version }
},
},
plugins: [
new AutoUnpackNativesPlugin({}),
new WebpackPlugin({
mainConfig,
devContentSecurityPolicy: `default-src * 'unsafe-eval' 'unsafe-inline'; img-src data: 'self'`,
renderer: {
config: rendererConfig,
nodeIntegration: true,
entryPoints: [
{
html: './src/index.html',
js: './src/renderer.tsx',
name: 'main_window',
preload: {
js: './src/preload.ts',
},
},
],
},
}),
],
}
export default config

16604
macapp/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,80 +0,0 @@
{
"name": "ollama",
"productName": "Ollama",
"version": "0.0.0",
"description": "ollama",
"main": ".webpack/main",
"scripts": {
"start": "electron-forge start",
"package": "electron-forge package --arch universal",
"package:sign": "SIGN=1 electron-forge package --arch universal",
"make": "electron-forge make --arch universal",
"make:sign": "SIGN=1 electron-forge make --arch universal",
"publish": "SIGN=1 electron-forge publish",
"lint": "eslint --ext .ts,.tsx ."
},
"keywords": [],
"author": {
"name": "Jeffrey Morgan",
"email": "jmorganca@gmail.com"
},
"license": "MIT",
"devDependencies": {
"@babel/core": "^7.22.5",
"@babel/preset-react": "^7.22.5",
"@electron-forge/cli": "^6.2.1",
"@electron-forge/maker-deb": "^6.2.1",
"@electron-forge/maker-rpm": "^6.2.1",
"@electron-forge/maker-squirrel": "^6.2.1",
"@electron-forge/maker-zip": "^6.2.1",
"@electron-forge/plugin-auto-unpack-natives": "^6.2.1",
"@electron-forge/plugin-webpack": "^6.2.1",
"@electron-forge/publisher-github": "^6.2.1",
"@electron/universal": "^1.4.1",
"@svgr/webpack": "^8.0.1",
"@types/chmodr": "^1.0.0",
"@types/node": "^20.4.0",
"@types/react": "^18.2.14",
"@types/react-dom": "^18.2.6",
"@types/uuid": "^9.0.2",
"@typescript-eslint/eslint-plugin": "^5.60.0",
"@typescript-eslint/parser": "^5.60.0",
"@vercel/webpack-asset-relocator-loader": "^1.7.3",
"babel-loader": "^9.1.2",
"chmodr": "^1.2.0",
"copy-webpack-plugin": "^11.0.0",
"css-loader": "^6.8.1",
"electron": "25.9.2",
"eslint": "^8.43.0",
"eslint-plugin-import": "^2.27.5",
"fork-ts-checker-webpack-plugin": "^7.3.0",
"node-loader": "^2.0.0",
"postcss": "^8.4.24",
"postcss-import": "^15.1.0",
"postcss-loader": "^7.3.3",
"postcss-preset-env": "^8.5.1",
"style-loader": "^3.3.3",
"svg-inline-loader": "^0.8.2",
"tailwindcss": "^3.3.2",
"ts-loader": "^9.4.3",
"ts-node": "^10.9.1",
"typescript": "~4.5.4",
"url-loader": "^4.1.1",
"webpack": "^5.88.0",
"webpack-cli": "^5.1.4",
"webpack-dev-server": "^4.15.1"
},
"dependencies": {
"@electron/remote": "^2.0.10",
"@heroicons/react": "^2.0.18",
"@segment/analytics-node": "^1.0.0",
"copy-to-clipboard": "^3.3.3",
"electron-squirrel-startup": "^1.0.0",
"electron-store": "^8.1.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"uuid": "^9.0.0",
"winston": "^3.10.0",
"winston-daily-rotate-file": "^4.7.1"
}
}

View File

@@ -1,7 +0,0 @@
module.exports = {
plugins: {
'postcss-import': {},
tailwindcss: {},
autoprefixer: {},
},
}

View File

@@ -1,34 +0,0 @@
@tailwind base;
@tailwind components;
@tailwind utilities;
html,
body {
background: transparent;
}
.drag {
-webkit-app-region: drag;
}
.no-drag {
-webkit-app-region: no-drag;
}
.blink {
-webkit-animation: 1s blink step-end infinite;
-moz-animation: 1s blink step-end infinite;
-ms-animation: 1s blink step-end infinite;
-o-animation: 1s blink step-end infinite;
animation: 1s blink step-end infinite;
}
@keyframes blink {
from,
to {
color: transparent;
}
50% {
color: black;
}
}

View File

@@ -1,122 +0,0 @@
import { useState } from 'react'
import copy from 'copy-to-clipboard'
import { CheckIcon, DocumentDuplicateIcon } from '@heroicons/react/24/outline'
import Store from 'electron-store'
import { getCurrentWindow, app } from '@electron/remote'
import { install } from './install'
import OllamaIcon from './ollama.svg'
const store = new Store()
enum Step {
WELCOME = 0,
CLI,
FINISH,
}
export default function () {
const [step, setStep] = useState<Step>(Step.WELCOME)
const [commandCopied, setCommandCopied] = useState<boolean>(false)
const command = 'ollama run llama3.2'
return (
<div className='drag'>
<div className='mx-auto flex min-h-screen w-full flex-col justify-between bg-white px-4 pt-16'>
{step === Step.WELCOME && (
<>
<div className='mx-auto text-center'>
<h1 className='mb-6 mt-4 text-2xl tracking-tight text-gray-900'>Welcome to Ollama</h1>
<p className='mx-auto w-[65%] text-sm text-gray-400'>
Let's get you up and running with your own large language models.
</p>
<button
onClick={() => setStep(Step.CLI)}
className='no-drag rounded-dm mx-auto my-8 w-[40%] rounded-md bg-black px-4 py-2 text-sm text-white hover:brightness-110'
>
Next
</button>
</div>
<div className='mx-auto'>
<OllamaIcon />
</div>
</>
)}
{step === Step.CLI && (
<>
<div className='mx-auto flex flex-col space-y-28 text-center'>
<h1 className='mt-4 text-2xl tracking-tight text-gray-900'>Install the command line</h1>
<pre className='mx-auto text-4xl text-gray-400'>&gt; ollama</pre>
<div className='mx-auto'>
<button
onClick={async () => {
try {
await install()
setStep(Step.FINISH)
} catch (e) {
console.error('could not install: ', e)
} finally {
getCurrentWindow().show()
getCurrentWindow().focus()
}
}}
className='no-drag rounded-dm mx-auto w-[60%] rounded-md bg-black px-4 py-2 text-sm text-white hover:brightness-110'
>
Install
</button>
<p className='mx-auto my-4 w-[70%] text-xs text-gray-400'>
You will be prompted for administrator access
</p>
</div>
</div>
</>
)}
{step === Step.FINISH && (
<>
<div className='mx-auto flex flex-col space-y-20 text-center'>
<h1 className='mt-4 text-2xl tracking-tight text-gray-900'>Run your first model</h1>
<div className='flex flex-col'>
<div className='group relative flex items-center'>
<pre className='language-none text-2xs w-full rounded-md bg-gray-100 px-4 py-3 text-start leading-normal'>
{command}
</pre>
<button
className={`no-drag absolute right-[5px] px-2 py-2 ${
commandCopied
? 'text-gray-900 opacity-100 hover:cursor-auto'
: 'text-gray-200 opacity-50 hover:cursor-pointer'
} hover:font-bold hover:text-gray-900 group-hover:opacity-100`}
onClick={() => {
copy(command)
setCommandCopied(true)
setTimeout(() => setCommandCopied(false), 3000)
}}
>
{commandCopied ? (
<CheckIcon className='h-4 w-4 font-bold text-gray-500' />
) : (
<DocumentDuplicateIcon className='h-4 w-4 text-gray-500' />
)}
</button>
</div>
<p className='mx-auto my-4 w-[70%] text-xs text-gray-400'>
Run this command in your favorite terminal.
</p>
</div>
<button
onClick={() => {
store.set('first-time-run', true)
window.close()
}}
className='no-drag rounded-dm mx-auto w-[60%] rounded-md bg-black px-4 py-2 text-sm text-white hover:brightness-110'
>
Finish
</button>
</div>
</>
)}
</div>
</div>
)
}

View File

@@ -1,4 +0,0 @@
declare module '*.svg' {
const content: string
export default content
}

View File

@@ -1,9 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
</head>
<body>
<div id="app"></div>
</body>
</html>

View File

@@ -1,302 +0,0 @@
import { spawn, ChildProcess } from 'child_process'
import { app, autoUpdater, dialog, Tray, Menu, BrowserWindow, MenuItemConstructorOptions, nativeTheme } from 'electron'
import Store from 'electron-store'
import winston from 'winston'
import 'winston-daily-rotate-file'
import * as path from 'path'
import { v4 as uuidv4 } from 'uuid'
import { installed } from './install'
require('@electron/remote/main').initialize()
if (require('electron-squirrel-startup')) {
app.quit()
}
const store = new Store()
let welcomeWindow: BrowserWindow | null = null
declare const MAIN_WINDOW_WEBPACK_ENTRY: string
const logger = winston.createLogger({
transports: [
new winston.transports.Console(),
new winston.transports.File({
filename: path.join(app.getPath('home'), '.ollama', 'logs', 'server.log'),
maxsize: 1024 * 1024 * 20,
maxFiles: 5,
}),
],
format: winston.format.printf(info => info.message),
})
app.on('ready', () => {
const gotTheLock = app.requestSingleInstanceLock()
if (!gotTheLock) {
app.exit(0)
return
}
app.on('second-instance', () => {
if (app.hasSingleInstanceLock()) {
app.releaseSingleInstanceLock()
}
if (proc) {
proc.off('exit', restart)
proc.kill()
}
app.exit(0)
})
app.focus({ steal: true })
init()
})
function firstRunWindow() {
// Create the browser window.
welcomeWindow = new BrowserWindow({
width: 400,
height: 500,
frame: false,
fullscreenable: false,
resizable: false,
movable: true,
show: false,
webPreferences: {
nodeIntegration: true,
contextIsolation: false,
},
})
require('@electron/remote/main').enable(welcomeWindow.webContents)
welcomeWindow.loadURL(MAIN_WINDOW_WEBPACK_ENTRY)
welcomeWindow.on('ready-to-show', () => welcomeWindow.show())
welcomeWindow.on('closed', () => {
if (process.platform === 'darwin') {
app.dock.hide()
}
})
}
let tray: Tray | null = null
let updateAvailable = false
const assetPath = app.isPackaged ? process.resourcesPath : path.join(__dirname, '..', '..', 'assets')
function trayIconPath() {
return nativeTheme.shouldUseDarkColors
? updateAvailable
? path.join(assetPath, 'iconDarkUpdateTemplate.png')
: path.join(assetPath, 'iconDarkTemplate.png')
: updateAvailable
? path.join(assetPath, 'iconUpdateTemplate.png')
: path.join(assetPath, 'iconTemplate.png')
}
function updateTrayIcon() {
if (tray) {
tray.setImage(trayIconPath())
}
}
function updateTray() {
const updateItems: MenuItemConstructorOptions[] = [
{ label: 'An update is available', enabled: false },
{
label: 'Restart to update',
click: () => autoUpdater.quitAndInstall(),
},
{ type: 'separator' },
]
const menu = Menu.buildFromTemplate([
...(updateAvailable ? updateItems : []),
{ role: 'quit', label: 'Quit Ollama', accelerator: 'Command+Q' },
])
if (!tray) {
tray = new Tray(trayIconPath())
}
tray.setToolTip(updateAvailable ? 'An update is available' : 'Ollama')
tray.setContextMenu(menu)
tray.setImage(trayIconPath())
nativeTheme.off('updated', updateTrayIcon)
nativeTheme.on('updated', updateTrayIcon)
}
let proc: ChildProcess = null
function server() {
const binary = app.isPackaged
? path.join(process.resourcesPath, 'ollama')
: path.resolve(process.cwd(), '..', 'ollama')
proc = spawn(binary, ['serve'])
proc.stdout.on('data', data => {
logger.info(data.toString().trim())
})
proc.stderr.on('data', data => {
logger.error(data.toString().trim())
})
proc.on('exit', restart)
}
function restart() {
setTimeout(server, 1000)
}
app.on('before-quit', () => {
if (proc) {
proc.off('exit', restart)
proc.kill('SIGINT') // send SIGINT signal to the server, which also stops any loaded llms
}
})
const updateURL = `https://ollama.com/api/update?os=${process.platform}&arch=${
process.arch
}&version=${app.getVersion()}&id=${id()}`
let latest = ''
async function isNewReleaseAvailable() {
try {
const response = await fetch(updateURL)
if (!response.ok) {
return false
}
if (response.status === 204) {
return false
}
const data = await response.json()
const url = data?.url
if (!url) {
return false
}
if (latest === url) {
return false
}
latest = url
return true
} catch (error) {
logger.error(`update check failed - ${error}`)
return false
}
}
async function checkUpdate() {
const available = await isNewReleaseAvailable()
if (available) {
logger.info('checking for update')
autoUpdater.checkForUpdates()
}
}
function init() {
if (app.isPackaged) {
checkUpdate()
setInterval(() => {
checkUpdate()
}, 60 * 60 * 1000)
}
updateTray()
if (process.platform === 'darwin') {
if (app.isPackaged) {
if (!app.isInApplicationsFolder()) {
const chosen = dialog.showMessageBoxSync({
type: 'question',
buttons: ['Move to Applications', 'Do Not Move'],
message: 'Ollama works best when run from the Applications directory.',
defaultId: 0,
cancelId: 1,
})
if (chosen === 0) {
try {
app.moveToApplicationsFolder({
conflictHandler: conflictType => {
if (conflictType === 'existsAndRunning') {
dialog.showMessageBoxSync({
type: 'info',
message: 'Cannot move to Applications directory',
detail:
'Another version of Ollama is currently running from your Applications directory. Close it first and try again.',
})
}
return true
},
})
return
} catch (e) {
logger.error(`[Move to Applications] Failed to move to applications folder - ${e.message}}`)
}
}
}
}
}
server()
if (store.get('first-time-run') && installed()) {
if (process.platform === 'darwin') {
app.dock.hide()
}
app.setLoginItemSettings({ openAtLogin: app.getLoginItemSettings().openAtLogin })
return
}
// This is the first run or the CLI is no longer installed
app.setLoginItemSettings({ openAtLogin: true })
firstRunWindow()
}
// Quit when all windows are closed, except on macOS. There, it's common
// for applications and their menu bar to stay active until the user quits
// explicitly with Cmd + Q.
app.on('window-all-closed', () => {
if (process.platform !== 'darwin') {
app.quit()
}
})
function id(): string {
const id = store.get('id') as string
if (id) {
return id
}
const uuid = uuidv4()
store.set('id', uuid)
return uuid
}
autoUpdater.setFeedURL({ url: updateURL })
autoUpdater.on('error', e => {
logger.error(`update check failed - ${e.message}`)
console.error(`update check failed - ${e.message}`)
})
autoUpdater.on('update-downloaded', () => {
updateAvailable = true
updateTray()
})

View File

@@ -1,21 +0,0 @@
import * as fs from 'fs'
import { exec as cbExec } from 'child_process'
import * as path from 'path'
import { promisify } from 'util'
const app = process && process.type === 'renderer' ? require('@electron/remote').app : require('electron').app
const ollama = app.isPackaged ? path.join(process.resourcesPath, 'ollama') : path.resolve(process.cwd(), '..', 'ollama')
const exec = promisify(cbExec)
const symlinkPath = '/usr/local/bin/ollama'
export function installed() {
return fs.existsSync(symlinkPath) && fs.readlinkSync(symlinkPath) === ollama
}
export async function install() {
const command = `do shell script "mkdir -p ${path.dirname(
symlinkPath
)} && ln -F -s \\"${ollama}\\" \\"${symlinkPath}\\"" with administrator privileges`
await exec(`osascript -e '${command}'`)
}

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 17 KiB

View File

View File

@@ -1,7 +0,0 @@
import App from './app'
import './app.css'
import { createRoot } from 'react-dom/client'
const container = document.getElementById('app')
const root = createRoot(container)
root.render(<App />)

View File

@@ -1,6 +0,0 @@
/** @type {import('tailwindcss').Config} */
module.exports = {
content: ['./src/**/*.{js,ts,jsx,tsx,mdx}'],
theme: {},
plugins: [],
}

View File

@@ -1,20 +0,0 @@
{
"compilerOptions": {
"target": "ES6",
"allowJs": true,
"module": "commonjs",
"skipLibCheck": true,
"esModuleInterop": true,
"noImplicitAny": true,
"sourceMap": true,
"baseUrl": ".",
"outDir": "dist",
"moduleResolution": "node",
"resolveJsonModule": true,
"paths": {
"*": ["node_modules/*"]
},
"jsx": "react-jsx"
},
"include": ["src/**/*"]
}

View File

@@ -1,20 +0,0 @@
import type { Configuration } from 'webpack'
import { rules } from './webpack.rules'
import { plugins } from './webpack.plugins'
export const mainConfig: Configuration = {
/**
* This is the main entry point for your application, it's the first file
* that runs in the main process.
*/
entry: './src/index.ts',
// Put your normal webpack config below here
module: {
rules,
},
plugins,
resolve: {
extensions: ['.js', '.ts', '.jsx', '.tsx', '.css', '.json'],
},
}

View File

@@ -1,14 +0,0 @@
import type IForkTsCheckerWebpackPlugin from 'fork-ts-checker-webpack-plugin'
import { DefinePlugin } from 'webpack'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const ForkTsCheckerWebpackPlugin: typeof IForkTsCheckerWebpackPlugin = require('fork-ts-checker-webpack-plugin')
export const plugins = [
new ForkTsCheckerWebpackPlugin({
logger: 'webpack-infrastructure',
}),
new DefinePlugin({
'process.env.TELEMETRY_WRITE_KEY': JSON.stringify(process.env.TELEMETRY_WRITE_KEY),
}),
]

View File

@@ -1,19 +0,0 @@
import type { Configuration } from 'webpack'
import { rules } from './webpack.rules'
import { plugins } from './webpack.plugins'
rules.push({
test: /\.css$/,
use: [{ loader: 'style-loader' }, { loader: 'css-loader' }, { loader: 'postcss-loader' }],
})
export const rendererConfig: Configuration = {
module: {
rules,
},
plugins,
resolve: {
extensions: ['.js', '.ts', '.jsx', '.tsx', '.css'],
},
}

View File

@@ -1,35 +0,0 @@
import type { ModuleOptions } from 'webpack'
export const rules: Required<ModuleOptions>['rules'] = [
// Add support for native node modules
{
// We're specifying native_modules in the test because the asset relocator loader generates a
// "fake" .node file which is really a cjs file.
test: /native_modules[/\\].+\.node$/,
use: 'node-loader',
},
{
test: /[/\\]node_modules[/\\].+\.(m?js|node)$/,
parser: { amd: false },
use: {
loader: '@vercel/webpack-asset-relocator-loader',
options: {
outputAssetBase: 'native_modules',
},
},
},
{
test: /\.tsx?$/,
exclude: /(node_modules|\.webpack)/,
use: {
loader: 'ts-loader',
options: {
transpileOnly: true,
},
},
},
{
test: /\.svg$/,
use: ['@svgr/webpack'],
},
]

View File

@@ -146,7 +146,6 @@ type Tensor interface {
FromFloats([]float32)
FromInts([]int32)
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
@@ -185,7 +184,6 @@ type Tensor interface {
View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context, shape ...int) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
Pad(ctx Context, shape ...int) Tensor
@@ -198,6 +196,10 @@ type Tensor interface {
Copy(ctx Context, t2 Tensor) Tensor
Duplicate(ctx Context) Tensor
Slice(ctx Context, dim, low, high, step int) Tensor
Chunk(ctx Context, dim int, size int) []Tensor
ChunkSections(ctx Context, dim int, sections ...int) []Tensor
TopK(ctx Context, k int) Tensor
Argsort(ctx Context) Tensor
Mean(ctx Context) Tensor
@@ -205,7 +207,6 @@ type Tensor interface {
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Clamp(ctx Context, min, max float32) Tensor
}
// ScaledDotProductAttention implements a fused attention

View File

@@ -1137,13 +1137,6 @@ func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
}
}
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_neg(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1632,20 +1625,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
}
}
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
var tt *C.struct_ggml_tensor
switch len(strides) {
case 0:
tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
case 1:
tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
default:
panic("unsupported number of dimensions")
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {
@@ -1732,9 +1711,65 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
// Slice returns a view of the tensor sliced along dim from low to high in step steps.
// Slice panics if the dimension is invalid or the slice parameters are out of range.
// If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.
func (t *Tensor) Slice(ctx ml.Context, dim int, low, high, step int) ml.Tensor {
if dim < 0 || dim >= C.GGML_MAX_DIMS {
panic("invalid dimension")
} else if low < 0 || high > t.Dim(dim) || low >= high || step < 1 {
panic("invalid slice parameters")
}
if dim == 0 && step > 1 {
// dim=0,step>1 is a special case so handle it here first
return t.View(ctx,
low*t.Stride(0), 1,
step*t.Stride(0), (high-low+1)/step,
t.Stride(1), t.Dim(1),
// preserve dim 3 by merging it into dim 2
t.Stride(2), t.Dim(2)*t.Dim(3),
).Contiguous(ctx, (high-low+1)/step, t.Dim(1), t.Dim(2), t.Dim(3))
}
args := []int{
low * t.Stride(dim), t.Dim(0),
t.Stride(1), t.Dim(1),
t.Stride(2), t.Dim(2),
t.Stride(3), t.Dim(3),
}
if step == 1 {
args[dim*2+1] = high - low
return t.View(ctx, args[0], args[1:]...)
} else {
args[dim*2] = step * t.Stride(dim)
args[dim*2+1] = (high - low + 1) / step
return t.View(ctx, args[0], args[1:]...)
}
}
// Chunk the tensor into chunk sized tensors along dim. Each sub-tensor is a view of
// the original.
func (t *Tensor) Chunk(ctx ml.Context, dim, chunk int) []ml.Tensor {
sections := make([]int, 0, t.Dim(dim)/chunk+1)
for rest := t.Dim(dim); rest > 0; rest -= chunk {
sections = append(sections, min(chunk, rest))
}
return t.ChunkSections(ctx, dim, sections...)
}
// ChunkSections split the tensor into section sized tensors along dim. Each sub-tensor is a
// view of the original. The size of the dim must equal the sum of sections.
func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Tensor {
var offset int
s := make([]ml.Tensor, len(sections))
for i, section := range sections {
s[i] = t.Slice(ctx, dim, offset, offset+section, 1)
offset += section
}
if offset != t.Dim(dim) {
panic("sections do not sum to tensor dimension")
}
return s
}

View File

@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release();
#endif
#ifdef __cplusplus
#include <array>
#include <initializer_list>
#include <vector>
@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
// Return true if the edges in the graph match expectations.
inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
int start_idx,
std::initializer_list<std::array<int, 3>> edges) {
for (const auto & edge : edges) {
int dst_node = edge[0];
int src_idx = edge[1];
int src_node = edge[2];
if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
return false;
}
}
return true;
}
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

View File

@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
char name[256];
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_ne02=%d", base, ne02);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {

File diff suppressed because it is too large Load Diff

View File

@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
uint nrows;
uint order;
} p;
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
dst_row[idx1] = tmp;
}
void argsort(bool needs_bounds_check) {
void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
void main() {
if (p.ncols == BLOCK_SIZE) {
argsort(false);
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
argsort(false, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
} else {
argsort(true);
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
argsort(true, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
}

View File

@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
const vec2 d = vec2(data_a[a_offset + ib].d);
const vec2 dm = vec2(data_a[a_offset + ib].dm);
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);

View File

@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
const f16vec2 d = bl.block.d;
const f16vec2 dm = bl.block.dm;
const uint idx = coordInBlock[1];
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
qs = unpack8(qs)[idx & 1];
const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
return ret;
}
@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
return ret;
}
#endif

View File

@@ -26,7 +26,7 @@ void main() {
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
}
}

View File

@@ -24,8 +24,8 @@ void main() {
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));

View File

@@ -20,8 +20,8 @@ void main() {
const uint is = 2 * il;
const uint n = 4;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;

View File

@@ -19,8 +19,8 @@ void main() {
const uint ir = tid % 16;
const uint is = 2 * il;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
const uint qs_idx = 32*il + 2 * ir;

View File

@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
}
}
}

View File

@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}

View File

@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}

View File

@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_LocalInvocationIndex;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
ballots_sh[gl_SubgroupID] = ballot;
barrier();
uint subgroup_base = 0;
uint total = 0;
for (uint k = 0; k < gl_NumSubgroups; ++k) {
if (k == gl_SubgroupID) {
subgroup_base = total;
}
total += subgroupBallotBitCount(ballots_sh[k]);
}
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#include "mul_mm_id_funcs.glsl"
#include "mul_mm_funcs.glsl"
void main() {

View File

@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const uint scales = data_a[ib].scales[scalesi];
const vec2 d = vec2(data_a[ib].d);
const vec2 dm = vec2(data_a[ib].dm);
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[ib].d);
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(data_a[ib].d);
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;
const float d = e8m0_to_fp32(data_a[ib].e);
const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);

View File

@@ -0,0 +1,70 @@
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_LocalInvocationIndex;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
ballots_sh[gl_SubgroupID] = ballot;
barrier();
uint subgroup_base = 0;
uint total = 0;
for (uint k = 0; k < gl_NumSubgroups; ++k) {
if (k == gl_SubgroupID) {
subgroup_base = total;
}
total += subgroupBallotBitCount(ballots_sh[k]);
}
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID

View File

@@ -10,10 +10,9 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#if defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
@@ -24,7 +23,10 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
@@ -76,40 +78,31 @@ layout (constant_id = 10) const uint WARP = 32;
#define BK 32
#ifdef COOPMAT
#define SHMEM_STRIDE (BK / 4 + 4)
#else
#define SHMEM_STRIDE (BK / 4 + 1)
#endif
#define MMQ_SHMEM
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
#ifndef COOPMAT
#if QUANT_AUXF == 1
shared FLOAT_TYPE buf_a_dm[BM];
#else
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
#endif
#endif
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
#ifndef COOPMAT
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
#endif
#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_B 16
#include "mul_mmq_shmem_types.glsl"
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
#endif // MUL_MAT_ID
#define BK_STEP 1
#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
#endif
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
shared block_b_cache buf_b[BN * BK_STEP];
// Register cache
block_a_cache cache_a[WMITER * TM];
block_b_cache cache_b;
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_funcs.glsl"
void main() {
@@ -139,26 +132,12 @@ void main() {
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
#ifdef COOPMAT
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
@@ -172,17 +151,27 @@ void main() {
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
uint _ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
#ifdef MUL_MAT_ID_USE_SUBGROUPS
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true, ic);
} else {
load_row_ids(expert_idx, false, ic);
}
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
row_ids[_ne1] = u16vec2(ii0, ii1);
if (_ne1 >= ic * BN) {
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
}
_ne1++;
}
}
}
barrier();
#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;
@@ -209,159 +198,70 @@ void main() {
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif
#ifdef COOPMAT
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
int32_t cache_a_qs[WMITER * TM * BK / 4];
int32_t cache_b_qs[TN * BK / 4];
ACC_TYPE sums[WMITER * TM * WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
#endif
#if QUANT_AUXF == 1
FLOAT_TYPE cache_a_dm[WMITER * TM];
#else
FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
#endif
FLOAT_TYPE_VEC2 cache_b_ds[TN];
for (uint block = start_k; block < end_k; block += BK) {
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
const uint iqs = loadr_a;
if (iqs == 0) {
#if QUANT_AUXF == 1
buf_a_dm[buf_ib] = get_d(ib);
#else
buf_a_dm[buf_ib] = get_dm(ib);
#endif
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
}
#if QUANT_R == 1
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
#else
const i32vec2 vals = repack(ib, iqs);
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
const uint ib = idx / 8;
const uint iqs = idx & 0x7;
#else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
const uint iqs = loadr_b;
#endif
const uint buf_ib = loadc_b + l;
if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[buf_ib];
const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
const uint iqs = loadr_b;
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
}
barrier();
pos_a_ib += 1;
pos_b_ib += 1;
pos_a_ib += BK_STEP;
pos_b_ib += BK_STEP;
#ifdef COOPMAT
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint ib_a = warp_r * WM + cm_row * TM;
for (uint k_step = 0; k_step < BK_STEP; k_step++) {
// Load from shared into cache
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint ib_b = warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
}
cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
}
coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
}
}
#else
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
}
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
cache_b_ds[cc] = buf_b_ds[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
}
}
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
int32_t q_sum = 0;
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
cache_b_qs[cc * (BK / 4) + idx_k]);
}
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint reg_ib = wsir * TM + cr;
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
block_a_to_registers(reg_ib, k_step * BM + buf_ib);
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
block_b_to_registers(ib);
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
sums[sums_idx] += mmq_dot_product(cache_a_idx);
}
}
}
}
}
#endif
barrier();
}
@@ -373,54 +273,6 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
#ifdef COOPMAT
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@@ -431,19 +283,21 @@ void main() {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
if (dr_warp + cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#endif // MUL_MAT_ID
}
}
}
}
#endif // COOPMAT
}

View File

@@ -6,41 +6,89 @@
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// 2-byte loads for Q4_0 blocks (18 bytes)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
#ifdef DATA_A_Q4_0
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#else // DATA_A_Q4_1
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#endif
}
#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q4_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
#else // DATA_A_Q4_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q5_0)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
}
#else // DATA_A_Q4_1
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const uint32_t vui = cache_a[ib_a].qs[iqs];
const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
const int32_t qs_b0 = cache_b.qs[iqs];
const int32_t qs_b1 = cache_b.qs[iqs + 4];
q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// 2-byte loads for Q5_0 blocks (22 bytes)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
#ifdef DATA_A_Q5_0
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
#else // DATA_A_Q5_1
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1);
}
#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q5_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
#else // DATA_A_Q5_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q5_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
}
#else // DATA_A_Q5_1
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
buf_a[buf_ib].qh = data_a_packed32[ib].qh;
}
#endif
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
cache_a[reg_ib].qh = buf_a[buf_ib].qh;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const uint32_t vui = cache_a[ib_a].qs[iqs];
const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
const int32_t qs_b0 = cache_b.qs[iqs];
const int32_t qs_b1 = cache_b.qs[iqs + 4];
q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]));
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
const int32_t qs_b = cache_b.qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, qs_b);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
return i32vec2( quants & 0x0F0F0F0F,
(quants >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * dsb.x * float(q_sum));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
if (iqs == 0) {
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d = buf_a[buf_ib].d;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
int32_t repack(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
// Repack 4x4 quants into one int
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
cache_a[reg_ib].scales = buf_a[buf_ib].scales;
[[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t sum_d = 0;
int32_t sum_m = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
}
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint hm_idx = iqs * QUANT_R_MMQ;
const uint iqs_k = (ib % 8) * 8 + hm_idx;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
const uint hm_shift = iqs_k / 8;
// Repack 2x4 quants into one int
// Add the 3rd bit instead of subtracting it to allow packing the quants
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
if (iqs == 0) {
const uint is = iqs_k / 4;
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
float result = 0.0;
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
// Subtract 4 from the quants to correct the 3rd bit offset
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
q_sum = 0;
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
// Repack 2x4 quants into one int
#if defined(DATA_A_Q4_K)
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
#else // defined(DATA_A_Q5_K)
const uint qh_idx = iqs * QUANT_R_MMQ;
const uint qh_shift = iqs_k / 8;
buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
#endif
if (iqs == 0) {
// Scale index
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
[[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
#if defined(DATA_A_Q4_K)
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
#else // defined(DATA_A_Q5_K)
const int32_t qs_a = cache_a[ib_a].qs[iqs];
#endif
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#ifdef MMQ_SHMEM
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
if (iqs == 0) {
const uint is = iqs_k / 4;
const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
float result = 0.0;
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
q_sum = 0;
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
}
#endif

View File

@@ -0,0 +1,78 @@
#if defined(DATA_A_Q4_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q4_1)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q5_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
uint32_t qh;
FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q5_1)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
uint32_t qh;
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
// #define BK_STEP 1
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_MXFP4)
#define QUANT_R_MMQ 2
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE d;
};
#elif defined(DATA_A_Q2_K)
#define QUANT_R_MMQ 4
struct block_a_cache {
uint32_t qs[2];
u8vec2 scales;
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q3_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 d_scales;
};
#elif defined(DATA_A_Q4_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q5_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q6_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE_VEC2 d_scales;
};
#endif
struct block_b_cache
{
int32_t qs[8];
FLOAT_TYPE_VEC2 ds;
};

View File

@@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_pos[];};
layout (binding = 2) readonly buffer Z {float data_ff[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
layout (push_constant) uniform parameter {
uint ncols;
@@ -27,6 +28,7 @@ layout (push_constant) uniform parameter {
uint s2;
int sections[4];
uint is_back;
uint set_rows_stride;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {

View File

@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
// Fusion optimization: ROPE + VIEW + SET_ROWS..
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
if (p.set_rows_stride != 0) {
idst = row_x*ne0 + i0/2;
idst += data_i[channel_x].x * p.set_rows_stride;
}
if (i0 >= p.n_dims) {
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);
return;
}

View File

@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0;
uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
// Fusion optimization: ROPE + VIEW + SET_ROWS..
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
if (p.set_rows_stride != 0) {
idst = row_x*ne0 + i0;
idst += data_i[channel_x].x * p.set_rows_stride;
}
if (i0 >= p.n_dims) {
data_d[idst + 0] = data_a[ix + 0];
data_d[idst + 1] = data_a[ix + 1];
data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
data_d[idst + 1] = D_TYPE(data_a[ix + 1]);
return;
}

Some files were not shown because too many files have changed in this diff Show More