Merge branch 'ollama:main' into main

.github/workflows/release.yaml
@@ -104,6 +104,13 @@ jobs:
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runner_dir: 'rocm'
- os: windows
arch: amd64
preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
flags: ''
runner_dir: 'vulkan'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -113,13 +120,14 @@ jobs:
run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ')
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
id: cache-install
uses: actions/cache/restore@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: startsWith(matrix.preset, 'CUDA ')
name: Install CUDA ${{ matrix.cuda-version }}
@@ -149,6 +157,18 @@ jobs:
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'Vulkan'
name: Install Vulkan
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
}

$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -159,6 +179,7 @@ jobs:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
@@ -171,7 +192,7 @@ jobs:
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env:
CMAKE_GENERATOR: Ninja
@@ -312,13 +333,13 @@ jobs:
include:
- os: linux
arch: amd64
target: archive_novulkan
target: archive
- os: linux
arch: amd64
target: rocm
- os: linux
arch: arm64
target: archive_novulkan
target: archive
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -374,14 +395,12 @@ jobs:
include:
- os: linux
arch: arm64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
- os: linux
arch: amd64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
@@ -394,14 +413,6 @@ jobs:
CGO_CXXFLAGS
GOFLAGS
FLAVOR=rocm
- os: linux
arch: amd64
suffix: '-vulkan'
target: default
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -419,7 +430,6 @@ jobs:
with:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
target: ${{ matrix.preset }}
build-args: ${{ matrix.build-args }}
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
.github/workflows/test.yaml
@@ -172,6 +172,7 @@ jobs:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
Dockerfile
@@ -159,32 +159,7 @@ ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama

# Temporary opt-out stages for Vulkan
FROM --platform=linux/amd64 scratch AS amd64_novulkan
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
FROM arm64 AS arm64_novulkan
FROM ${FLAVOR}_novulkan AS archive_novulkan
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04 AS novulkan
RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive_novulkan /bin /usr/bin
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
COPY --from=archive_novulkan /lib/ollama /usr/lib/ollama
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV NVIDIA_VISIBLE_DEVICES=all
ENV OLLAMA_HOST=0.0.0.0:11434
EXPOSE 11434
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]

FROM ubuntu:24.04 AS default
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \
README.md
@@ -321,6 +321,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LibreChat](https://github.com/danny-avila/LibreChat)
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [AI-UI](https://github.com/bajahaw/ai-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
@@ -662,5 +663,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.

## Security
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)
api/types.go
@@ -117,6 +117,14 @@ type GenerateRequest struct {
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`

// Logprobs specifies whether to return log probabilities of the output tokens.
Logprobs bool `json:"logprobs,omitempty"`

// TopLogprobs is the number of most likely tokens to return at each token position,
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
}

// ChatRequest describes a request sent by [Client.Chat].
@@ -159,6 +167,14 @@ type ChatRequest struct {
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`

// Logprobs specifies whether to return log probabilities of the output tokens.
Logprobs bool `json:"logprobs,omitempty"`

// TopLogprobs is the number of most likely tokens to return at each token position,
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
}

type Tools []Tool
@@ -343,6 +359,27 @@ func (t *ToolFunction) String() string {
return string(bts)
}

// TokenLogprob represents log probability information for a single token alternative.
type TokenLogprob struct {
// Token is the text representation of the token.
Token string `json:"token"`

// Logprob is the log probability of this token.
Logprob float64 `json:"logprob"`

// Bytes contains the raw byte representation of the token
Bytes []int `json:"bytes,omitempty"`
}

// Logprob contains log probability information for a generated token.
type Logprob struct {
TokenLogprob

// TopLogprobs contains the most likely tokens and their log probabilities
// at this position, if requested via TopLogprobs parameter.
TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}

// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
@@ -369,6 +406,10 @@ type ChatResponse struct {

DebugInfo *DebugInfo `json:"_debug_info,omitempty"`

// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`

Metrics
}

@@ -677,6 +718,10 @@ type GenerateResponse struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`

DebugInfo *DebugInfo `json:"_debug_info,omitempty"`

// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
}

// ModelDetails provides details about a model.
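As a hedged illustration of the new fields from a client's point of view (model name, prompt, and endpoint values below are placeholders, not part of this diff), a TypeScript caller could exercise `logprobs`/`top_logprobs` and read back the per-token data whose JSON shape mirrors the Go struct tags above:

```ts
// Sketch only: field names follow the Go struct tags above
// (logprobs, top_logprobs, token, logprob, bytes).
interface TokenLogprob {
  token: string;
  logprob: number;
  bytes?: number[];
}
interface Logprob extends TokenLogprob {
  top_logprobs?: TokenLogprob[];
}

async function chatWithLogprobs(): Promise<void> {
  const res = await fetch("http://localhost:11434/api/chat", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      model: "llama3.2", // placeholder model name
      messages: [{ role: "user", content: "Say hi" }],
      stream: false,
      logprobs: true,    // ask for per-token log probabilities
      top_logprobs: 5,   // 0-20 alternatives per position
    }),
  });
  const data = await res.json();
  for (const lp of (data.logprobs ?? []) as Logprob[]) {
    console.log(lp.token, lp.logprob, lp.top_logprobs?.length ?? 0);
  }
}
```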
@@ -48,16 +48,6 @@ The `-dev` flag enables:
- CORS headers for cross-origin requests
- Hot-reload support for UI development

#### Run Storybook

Inside the `ui/app` directory, run:

```bash
npm run storybook
```

For now we're writing stories as siblings of the component they're testing. So for example, `src/components/Message.stories.tsx` is the story for `src/components/Message.tsx`.

## Build

app/ui/app/package-lock.json
@@ -34,6 +34,7 @@
"rehype-raw": "^7.0.0",
"rehype-sanitize": "^6.0.0",
"remark-math": "^6.0.0",
"streamdown": "^1.4.0",
"unist-builder": "^4.0.0",
"unist-util-parents": "^3.0.0"
},
@@ -205,6 +205,13 @@ export async function* sendMessage(
data: uint8ArrayToBase64(att.data),
}));

// Only send think parameter when actually requesting thinking
// Don't send false as it causes issues with some providers
const shouldSendThink =
think !== undefined &&
((typeof think === "boolean" && think) ||
(typeof think === "string" && think !== ""));

const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
method: "POST",
headers: {
@@ -222,7 +229,7 @@
web_search: webSearch ?? false,
file_tools: fileTools ?? false,
...(forceUpdate !== undefined ? { forceUpdate } : {}),
...(think !== undefined ? { think } : {}),
...(shouldSendThink ? { think } : {}),
}),
),
signal,
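A minimal standalone sketch of the guard introduced above (the helper name is hypothetical): the conditional spread adds the key only when `shouldSendThink` is true, so `think: false` never reaches the server.

```ts
// Sketch of the conditional-spread guard used in sendMessage above.
function buildThinkField(think?: boolean | string): Record<string, unknown> {
  const shouldSendThink =
    think !== undefined &&
    ((typeof think === "boolean" && think) ||
      (typeof think === "string" && think !== ""));
  return { ...(shouldSendThink ? { think } : {}) };
}

console.log(buildThinkField(false)); // {}               — key omitted entirely
console.log(buildThinkField("low")); // { think: "low" }
```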
@@ -1,522 +0,0 @@
import { expect, test, suite } from "vitest";
import { processStreamingMarkdown } from "@/utils/processStreamingMarkdown";

suite("common llm outputs that cause issues", () => {
test("prefix of bolded list item shouldn't make a horizontal line", () => {
// we're going to go in order of incrementally adding characters. This
// happens really commonly with LLMs that like to make lists like so:
//
// * **point 1**: explanatory text
// * **point 2**: more explanatory text
//
// Partial rendering of `*` (A), followed by `* *` (B), followed by `* **`
// (C) is a total mess. (A) renders as a single bullet point in an
// otherwise empty list, (B) renders as two nested lists (and therefore
// two bullet points, styled differently by default in html), and (C)
// renders as a horizontal line because in markdown apparently `***` or `*
// * *` horizontal rules don't have as strict whitespace rules as I
// expected them to

// these are alone (i.e., they would be the first list item)
expect(processStreamingMarkdown("*")).toBe("");
expect(processStreamingMarkdown("* *")).toBe("");
expect(processStreamingMarkdown("* **")).toBe("");
// expect(processStreamingMarkdown("* **b")).toBe("* **b**");

// with a list item before them
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"*"
].join("\n"),
),
).toBe("* abc");

expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"* *"
].join("\n"),
),
).toBe("* abc");

expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"* **"
].join("\n"),
),
).toBe("* abc");
});

test("bolded list items with text should be rendered properly", () => {
expect(processStreamingMarkdown("* **abc**")).toBe("* **abc**");
});

test("partially bolded list items should be autoclosed", () => {
expect(processStreamingMarkdown("* **abc")).toBe("* **abc**");
});

suite(
"partially bolded list items should be autoclosed, even if the last node isn't a text node",
() => {
test("inline code", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function `async`*"),
).toBe("* **Asynchronous Function `async`**");
});
},
);
});

suite("autoclosing bold", () => {
suite("endings with no asterisks", () => {
test("should autoclose bold", () => {
expect(processStreamingMarkdown("**abc")).toBe("**abc**");
expect(processStreamingMarkdown("abc **abc")).toBe("abc **abc**");
});

suite("should autoclose, even if the last node isn't a text node", () => {
test("inline code", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function `async`"),
).toBe("* **Asynchronous Function `async`**");
});

test("opening ** is at the end of the text", () => {
expect(processStreamingMarkdown("abc **`def` jhk [lmn](opq)")).toBe(
"abc **`def` jhk [lmn](opq)**",
);
});

test("if there's a space after the **, it should NOT be autoclosed", () => {
expect(processStreamingMarkdown("abc ** `def` jhk [lmn](opq)")).toBe(
"abc \\*\\* `def` jhk [lmn](opq)",
);
});
});

test("should autoclose bold, even if the last node isn't a text node", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function ( `async`"),
).toBe("* **Asynchronous Function ( `async`**");
});

test("whitespace fakeouts should not be modified", () => {
expect(processStreamingMarkdown("** abc")).toBe("\\*\\* abc");
});

// TODO(drifkin): arguably this should just be removed entirely, but empty
// isn't so bad
test("should handle empty bolded items", () => {
expect(processStreamingMarkdown("**")).toBe("");
});
});

suite("partially closed bolded items", () => {
test("simple partial", () => {
expect(processStreamingMarkdown("**abc*")).toBe("**abc**");
});

test("partial with non-text node at end", () => {
expect(processStreamingMarkdown("**abc`def`*")).toBe("**abc`def`**");
});

test("partial with multiply nested ending nodes", () => {
expect(processStreamingMarkdown("**abc[abc](`def`)*")).toBe(
"**abc[abc](`def`)**",
);
});

test("normal emphasis should not be affected", () => {
expect(processStreamingMarkdown("*abc*")).toBe("*abc*");
});

test("normal emphasis with nested code should not be affected", () => {
expect(processStreamingMarkdown("*`abc`*")).toBe("*`abc`*");
});
});

test.skip("shouldn't autoclose immediately if there's a space before the closing *", () => {
expect(processStreamingMarkdown("**abc *")).toBe("**abc**");
});

// skipping for now because this requires partial link completion as well
suite.skip("nested blocks that each need autoclosing", () => {
test("emph nested in link nested in strong nested in list item", () => {
expect(processStreamingMarkdown("* **[abc **def")).toBe(
"* **[abc **def**]()**",
);
});

test("* **[ab *`def`", () => {
expect(processStreamingMarkdown("* **[ab *`def`")).toBe(
"* **[ab *`def`*]()**",
);
});
});
});

suite("numbered list items", () => {
test("should remove trailing numbers", () => {
expect(processStreamingMarkdown("1. First\n2")).toBe("1. First");
});

test("should remove trailing numbers with breaks before", () => {
expect(processStreamingMarkdown("1. First \n2")).toBe("1. First");
});

test("should remove trailing numbers that form a new paragraph", () => {
expect(processStreamingMarkdown("1. First\n\n2")).toBe("1. First");
});

test("but should leave list items separated by two newlines", () => {
expect(processStreamingMarkdown("1. First\n\n2. S")).toBe(
"1. First\n\n2. S",
);
});
});

// TODO(drifkin): slop tests ahead, some are decent, but need to manually go
// through them as I implement
/*
describe("StreamingMarkdownContent - processStreamingMarkdown", () => {
describe("Ambiguous endings removal", () => {
it("should remove list markers at the end", () => {
expect(processStreamingMarkdown("Some text\n* ")).toBe("Some text");
expect(processStreamingMarkdown("Some text\n*")).toBe("Some text");
expect(processStreamingMarkdown("* Item 1\n- ")).toBe("* Item 1");
expect(processStreamingMarkdown("* Item 1\n-")).toBe("* Item 1");
expect(processStreamingMarkdown("Text\n+ ")).toBe("Text");
expect(processStreamingMarkdown("Text\n+")).toBe("Text");
expect(processStreamingMarkdown("1. First\n2. ")).toBe("1. First");
});

it("should remove heading markers at the end", () => {
expect(processStreamingMarkdown("Some text\n# ")).toBe("Some text");
expect(processStreamingMarkdown("Some text\n#")).toBe("Some text\n#"); // # without space is not removed
expect(processStreamingMarkdown("# Title\n## ")).toBe("# Title");
expect(processStreamingMarkdown("# Title\n##")).toBe("# Title\n##"); // ## without space is not removed
});

it("should remove ambiguous bold markers at the end", () => {
expect(processStreamingMarkdown("Text **")).toBe("Text ");
expect(processStreamingMarkdown("Some text\n**")).toBe("Some text");
});

it("should remove code block markers at the end", () => {
expect(processStreamingMarkdown("Text\n```")).toBe("Text");
expect(processStreamingMarkdown("```")).toBe("");
});

it("should remove single backtick at the end", () => {
expect(processStreamingMarkdown("Text `")).toBe("Text ");
expect(processStreamingMarkdown("`")).toBe("");
});

it("should remove single asterisk at the end", () => {
expect(processStreamingMarkdown("Text *")).toBe("Text ");
expect(processStreamingMarkdown("*")).toBe("");
});

it("should handle empty content", () => {
expect(processStreamingMarkdown("")).toBe("");
});

it("should handle single line removals correctly", () => {
expect(processStreamingMarkdown("* ")).toBe("");
expect(processStreamingMarkdown("# ")).toBe("");
expect(processStreamingMarkdown("**")).toBe("");
expect(processStreamingMarkdown("`")).toBe("");
});

it("shouldn't have this regexp capture group bug", () => {
expect(
processStreamingMarkdown("Here's a shopping list:\n*"),
).not.toContain("0*");
expect(processStreamingMarkdown("Here's a shopping list:\n*")).toBe(
"Here's a shopping list:",
);
});
});

describe("List markers", () => {
it("should preserve complete list items", () => {
expect(processStreamingMarkdown("* Complete item")).toBe(
"* Complete item",
);
expect(processStreamingMarkdown("- Another item")).toBe("- Another item");
expect(processStreamingMarkdown("+ Plus item")).toBe("+ Plus item");
expect(processStreamingMarkdown("1. Numbered item")).toBe(
"1. Numbered item",
);
});

it("should handle indented list markers", () => {
expect(processStreamingMarkdown(" * ")).toBe(" ");
expect(processStreamingMarkdown(" - ")).toBe(" ");
expect(processStreamingMarkdown("\t+ ")).toBe("\t");
});
});

describe("Heading markers", () => {
it("should preserve complete headings", () => {
expect(processStreamingMarkdown("# Complete Heading")).toBe(
"# Complete Heading",
);
expect(processStreamingMarkdown("## Subheading")).toBe("## Subheading");
expect(processStreamingMarkdown("### H3 Title")).toBe("### H3 Title");
});

it("should not affect # in other contexts", () => {
expect(processStreamingMarkdown("C# programming")).toBe("C# programming");
expect(processStreamingMarkdown("Issue #123")).toBe("Issue #123");
});
});

describe("Bold text", () => {
it("should close incomplete bold text", () => {
expect(processStreamingMarkdown("This is **bold text")).toBe(
"This is **bold text**",
);
expect(processStreamingMarkdown("Start **bold and more")).toBe(
"Start **bold and more**",
);
expect(processStreamingMarkdown("**just bold")).toBe("**just bold**");
});

it("should not affect complete bold text", () => {
expect(processStreamingMarkdown("**complete bold**")).toBe(
"**complete bold**",
);
expect(processStreamingMarkdown("Text **bold** more")).toBe(
"Text **bold** more",
);
});

it("should handle nested bold correctly", () => {
expect(processStreamingMarkdown("**bold** and **another")).toBe(
"**bold** and **another**",
);
});
});

describe("Italic text", () => {
it("should close incomplete italic text", () => {
expect(processStreamingMarkdown("This is *italic text")).toBe(
"This is *italic text*",
);
expect(processStreamingMarkdown("Start *italic and more")).toBe(
"Start *italic and more*",
);
});

it("should differentiate between list markers and italic", () => {
expect(processStreamingMarkdown("* Item\n* ")).toBe("* Item");
expect(processStreamingMarkdown("Some *italic text")).toBe(
"Some *italic text*",
);
expect(processStreamingMarkdown("*just italic")).toBe("*just italic*");
});

it("should not affect complete italic text", () => {
expect(processStreamingMarkdown("*complete italic*")).toBe(
"*complete italic*",
);
expect(processStreamingMarkdown("Text *italic* more")).toBe(
"Text *italic* more",
);
});
});

describe("Code blocks", () => {
it("should close incomplete code blocks", () => {
expect(processStreamingMarkdown("```javascript\nconst x = 42;")).toBe(
"```javascript\nconst x = 42;\n```",
);
expect(processStreamingMarkdown("```\ncode here")).toBe(
"```\ncode here\n```",
);
});

it("should not affect complete code blocks", () => {
expect(processStreamingMarkdown("```\ncode\n```")).toBe("```\ncode\n```");
expect(processStreamingMarkdown("```js\nconst x = 1;\n```")).toBe(
"```js\nconst x = 1;\n```",
);
});

it("should handle nested code blocks correctly", () => {
expect(processStreamingMarkdown("```\ncode\n```\n```python")).toBe(
"```\ncode\n```\n```python\n```",
);
});

it("should not process markdown inside code blocks", () => {
expect(processStreamingMarkdown("```\n* not a list\n**not bold**")).toBe(
"```\n* not a list\n**not bold**\n```",
);
});
});

describe("Inline code", () => {
it("should close incomplete inline code", () => {
expect(processStreamingMarkdown("This is `inline code")).toBe(
"This is `inline code`",
);
expect(processStreamingMarkdown("Use `console.log")).toBe(
"Use `console.log`",
);
});

it("should not affect complete inline code", () => {
expect(processStreamingMarkdown("`complete code`")).toBe(
"`complete code`",
);
expect(processStreamingMarkdown("Use `code` here")).toBe(
"Use `code` here",
);
});

it("should handle multiple inline codes correctly", () => {
expect(processStreamingMarkdown("`code` and `more")).toBe(
"`code` and `more`",
);
});

it("should not confuse inline code with code blocks", () => {
expect(processStreamingMarkdown("```\nblock\n```\n`inline")).toBe(
"```\nblock\n```\n`inline`",
);
});
});

describe("Complex streaming scenarios", () => {
it("should handle progressive streaming of a heading", () => {
const steps = [
{ input: "#", expected: "#" }, // # alone is not removed (needs space)
{ input: "# ", expected: "" },
{ input: "# H", expected: "# H" },
{ input: "# Hello", expected: "# Hello" },
];
steps.forEach(({ input, expected }) => {
expect(processStreamingMarkdown(input)).toBe(expected);
});
});

it("should handle progressive streaming of bold text", () => {
const steps = [
{ input: "*", expected: "" },
{ input: "**", expected: "" },
{ input: "**b", expected: "**b**" },
{ input: "**bold", expected: "**bold**" },
{ input: "**bold**", expected: "**bold**" },
];
steps.forEach(({ input, expected }) => {
expect(processStreamingMarkdown(input)).toBe(expected);
});
});

it("should handle multiline content with various patterns", () => {
const multiline = `# Title

This is a paragraph with **bold text** and *italic text*.

* Item 1
* Item 2
* `;

const expected = `# Title

This is a paragraph with **bold text** and *italic text*.

* Item 1
* Item 2`;

expect(processStreamingMarkdown(multiline)).toBe(expected);
});

it("should only fix the last line", () => {
expect(processStreamingMarkdown("# Complete\n# Another\n# ")).toBe(
"# Complete\n# Another",
);
expect(processStreamingMarkdown("* Item 1\n* Item 2\n* ")).toBe(
"* Item 1\n* Item 2",
);
});

it("should handle mixed content correctly", () => {
const input = `# Header

This has **bold** text and *italic* text.

\`\`\`js
const x = 42;
\`\`\`

Now some \`inline code\` and **unclosed bold`;

const expected = `# Header

This has **bold** text and *italic* text.

\`\`\`js
const x = 42;
\`\`\`

Now some \`inline code\` and **unclosed bold**`;

expect(processStreamingMarkdown(input)).toBe(expected);
});
});

describe("Edge cases with escaping", () => {
it("should handle escaped asterisks (future enhancement)", () => {
// Note: Current implementation doesn't handle escaping
// This is a known limitation - escaped characters still trigger closing
expect(processStreamingMarkdown("Text \\*not italic")).toBe(
"Text \\*not italic*",
);
});

it("should handle escaped backticks (future enhancement)", () => {
// Note: Current implementation doesn't handle escaping
// This is a known limitation - escaped characters still trigger closing
expect(processStreamingMarkdown("Text \\`not code")).toBe(
"Text \\`not code`",
);
});
});

describe("Code block edge cases", () => {
it("should handle triple backticks in the middle of lines", () => {
expect(processStreamingMarkdown("Text ``` in middle")).toBe(
"Text ``` in middle\n```",
);
expect(processStreamingMarkdown("```\nText ``` in code\nmore")).toBe(
"```\nText ``` in code\nmore\n```",
);
});

it("should properly close code blocks with language specifiers", () => {
expect(processStreamingMarkdown("```typescript")).toBe(
"```typescript\n```",
);
expect(processStreamingMarkdown("```typescript\nconst x = 1")).toBe(
"```typescript\nconst x = 1\n```",
);
});

it("should remove a completely empty partial code block", () => {
expect(processStreamingMarkdown("```\n")).toBe("");
});
});
});

*/
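The deleted tests above pin down the contract of processStreamingMarkdown. As a rough sketch of the core idea only (the removed implementation worked on a remark AST; this simplified line-based cut and its helper name are not from the diff), trailing tokens that are ambiguous mid-stream get trimmed until more text arrives:

```ts
// Simplified sketch of the behavior the tests describe, not the real code.
function trimAmbiguousTail(markdown: string): string {
  const lines = markdown.split("\n");
  const last = lines[lines.length - 1];
  // A bare list marker, lone asterisks after it, a dangling list number, or
  // a heading marker with no text yet is ambiguous while streaming: drop it.
  if (/^\s*([*+-]\s*\**|\d+[.)]?\s*|#{1,6}\s+)$/.test(last)) {
    lines.pop();
    // also drop any blank separator left behind
    while (lines.length && lines[lines.length - 1].trim() === "") lines.pop();
  }
  return lines.join("\n");
}

console.log(trimAmbiguousTail("* abc\n* **")); // "* abc"
console.log(trimAmbiguousTail("1. First\n2")); // "1. First"
```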
@@ -1,66 +1,123 @@
import React from "react";
import Markdown from "react-markdown";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
import rehypeRaw from "rehype-raw";
import rehypeSanitize, { defaultSchema } from "rehype-sanitize";
import rehypePrismPlus from "rehype-prism-plus";
import rehypeKatex from "rehype-katex";
import remarkStreamingMarkdown, {
type LastNodeInfo,
} from "@/utils/remarkStreamingMarkdown";
import type { PluggableList } from "unified";
import { Streamdown, defaultRemarkPlugins } from "streamdown";
import remarkCitationParser from "@/utils/remarkCitationParser";
import CopyButton from "./CopyButton";
import type { BundledLanguage } from "shiki";
import { highlighter } from "@/lib/highlighter";

interface StreamingMarkdownContentProps {
content: string;
isStreaming?: boolean;
size?: "sm" | "md" | "lg";
onLastNode?: (info: LastNodeInfo) => void;
browserToolResult?: any; // TODO: proper type
}

const CodeBlock = React.memo(
({ children, className, ...props }: React.HTMLAttributes<HTMLPreElement>) => {
const extractText = React.useCallback((node: React.ReactNode): string => {
// Helper to extract text from React nodes
const extractText = (node: React.ReactNode): string => {
if (typeof node === "string") return node;
if (typeof node === "number") return String(node);
if (!node) return "";

if (React.isValidElement(node)) {
if (
node.props &&
typeof node.props === "object" &&
"children" in node.props
) {
return extractText(node.props.children as React.ReactNode);
const props = node.props as any;
if (props?.children) {
return extractText(props.children as React.ReactNode);
}
}

if (Array.isArray(node)) {
return node.map(extractText).join("");
}

return "";
}, []);
};

const language = className?.replace(/language-/, "") || "";
const CodeBlock = React.memo(
({ children }: React.HTMLAttributes<HTMLPreElement>) => {
// Extract code and language from children
const codeElement = children as React.ReactElement<{
className?: string;
children: React.ReactNode;
}>;
const language =
codeElement.props.className?.replace(/language-/, "") || "";
const codeText = extractText(codeElement.props.children);

// Synchronously highlight code using the pre-loaded highlighter
const tokens = React.useMemo(() => {
if (!highlighter) return null;

try {
return {
light: highlighter.codeToTokensBase(codeText, {
lang: language as BundledLanguage,
theme: "one-light" as any,
}),
dark: highlighter.codeToTokensBase(codeText, {
lang: language as BundledLanguage,
theme: "one-dark" as any,
}),
};
} catch (error) {
console.error("Failed to highlight code:", error);
return null;
}
}, [codeText, language]);

return (
<div className="relative bg-neutral-100 dark:bg-neutral-800 rounded-2xl overflow-hidden my-6">
<div className="flex justify-between select-none">
<div className="flex select-none">
{language && (
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
{language}
</div>
)}
<CopyButton
content={extractText(children)}
content={codeText}
showLabels={true}
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800"
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 ml-auto"
/>
</div>
<pre className={className} {...props}>
{children}
{/* Light mode */}
<pre className="dark:hidden m-0 bg-neutral-100 text-sm overflow-x-auto p-4">
<code className="font-mono text-sm">
{tokens?.light
? tokens.light.map((line: any, i: number) => (
<React.Fragment key={i}>
{line.map((token: any, j: number) => (
<span
key={j}
style={{
color: token.color,
}}
>
{token.content}
</span>
))}
{i < tokens.light.length - 1 && "\n"}
</React.Fragment>
))
: codeText}
</code>
</pre>
{/* Dark mode */}
<pre className="hidden dark:block m-0 bg-neutral-800 text-sm overflow-x-auto p-4">
<code className="font-mono text-sm">
{tokens?.dark
? tokens.dark.map((line: any, i: number) => (
<React.Fragment key={i}>
{line.map((token: any, j: number) => (
<span
key={j}
style={{
color: token.color,
}}
>
{token.content}
</span>
))}
{i < tokens.dark.length - 1 && "\n"}
</React.Fragment>
))
: codeText}
</code>
</pre>
</div>
);
@@ -68,60 +125,14 @@ const CodeBlock = React.memo(
);

const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
React.memo(
({ content, isStreaming = false, size, onLastNode, browserToolResult }) => {
// Build the remark plugins array
React.memo(({ content, isStreaming = false, size, browserToolResult }) => {
// Build the remark plugins array - keep default GFM and Math, add citations
const remarkPlugins = React.useMemo(() => {
const plugins: PluggableList = [
remarkGfm,
[remarkMath, { singleDollarTextMath: false }],
return [
defaultRemarkPlugins.gfm,
defaultRemarkPlugins.math,
remarkCitationParser,
];

// Add streaming plugin when in streaming mode
if (isStreaming) {
plugins.push([remarkStreamingMarkdown, { debug: true, onLastNode }]);
}

return plugins;
}, [isStreaming, onLastNode]);

// Create a custom sanitization schema that allows math elements
const sanitizeSchema = React.useMemo(() => {
return {
...defaultSchema,
attributes: {
...defaultSchema.attributes,
span: [
...(defaultSchema.attributes?.span || []),
["className", /^katex/],
],
div: [
...(defaultSchema.attributes?.div || []),
["className", /^katex/],
],
"ol-citation": ["cursor", "start", "end"],
},
tagNames: [
...(defaultSchema.tagNames || []),
"math",
"mrow",
"mi",
"mo",
"mn",
"msup",
"msub",
"mfrac",
"mover",
"munder",
"msqrt",
"mroot",
"merror",
"mspace",
"mpadded",
"ol-citation",
],
};
}, []);

return (
@@ -144,6 +155,26 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
prose-pre:my-0
prose-pre:max-w-full
prose-pre:pt-1
[&_table]:border-collapse
[&_table]:w-full
[&_table]:border
[&_table]:border-neutral-200
[&_table]:rounded-lg
[&_table]:overflow-hidden
[&_th]:px-3
[&_th]:py-2
[&_th]:text-left
[&_th]:font-semibold
[&_th]:border-b
[&_th]:border-r
[&_th]:border-neutral-200
[&_th:last-child]:border-r-0
[&_td]:px-3
[&_td]:py-2
[&_td]:border-r
[&_td]:border-neutral-200
[&_td:last-child]:border-r-0
[&_tbody_tr:not(:last-child)_td]:border-b
[&_code:not(pre_code)]:text-neutral-700
[&_code:not(pre_code)]:bg-neutral-100
[&_code:not(pre_code)]:font-normal
@@ -160,6 +191,10 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
dark:prose-strong:text-neutral-200
dark:prose-pre:text-neutral-200
dark:prose:pre:text-neutral-200
dark:[&_table]:border-neutral-700
dark:[&_thead]:bg-neutral-800
dark:[&_th]:border-neutral-700
dark:[&_td]:border-neutral-700
dark:[&_code:not(pre_code)]:text-neutral-200
dark:[&_code:not(pre_code)]:bg-neutral-800
dark:[&_code:not(pre_code)]:font-normal
@@ -172,23 +207,11 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
content={content}
isStreaming={isStreaming}
>
<Markdown
<Streamdown
parseIncompleteMarkdown={isStreaming}
isAnimating={isStreaming}
remarkPlugins={remarkPlugins}
rehypePlugins={
[
[rehypeRaw, { allowDangerousHtml: true }],
[rehypeSanitize, sanitizeSchema],
[rehypePrismPlus, { ignoreMissing: true }],
[
rehypeKatex,
{
errorColor: "#000000", // Black instead of red for errors
strict: false, // Be more lenient with parsing
throwOnError: false,
},
],
] as PluggableList
}
controls={false}
components={{
pre: CodeBlock,
table: ({
@@ -196,38 +219,35 @@
...props
}: React.HTMLAttributes<HTMLTableElement>) => (
<div className="overflow-x-auto max-w-full">
<table {...props}>{children}</table>
<table
{...props}
className="border-collapse w-full border border-neutral-200 dark:border-neutral-700 rounded-lg overflow-hidden"
>
{children}
</table>
</div>
),
// @ts-expect-error: custom type
// @ts-expect-error: custom citation type
"ol-citation": ({
cursor,
// start,
// end,
}: {
cursor: number;
start: number;
end: number;
}) => {
// Check if we have a page_stack and if the cursor is valid
const pageStack = browserToolResult?.page_stack;
const hasValidPage = pageStack && cursor < pageStack.length;
const pageUrl = hasValidPage ? pageStack[cursor] : null;

// Extract a readable title from the URL if possible
const getPageTitle = (url: string) => {
if (url.startsWith("search_results_")) {
const searchTerm = url.substring(
"search_results_".length,
);
const searchTerm = url.substring("search_results_".length);
return `Search: ${searchTerm}`;
}
// For regular URLs, try to extract domain or use full URL
try {
const urlObj = new URL(url);
return urlObj.hostname;
} catch {
// If not a valid URL, return as is
return url;
}
};
@@ -238,7 +258,6 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
</span>
);

// If we have a valid page URL, wrap in a link
if (pageUrl && pageUrl.startsWith("http")) {
return (
<a
@@ -253,18 +272,16 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
);
}

// Otherwise, just return the citation without a link
return citationElement;
},
}}
>
{content}
</Markdown>
</Streamdown>
</StreamingMarkdownErrorBoundary>
</div>
);
},
);
});

interface StreamingMarkdownErrorBoundaryProps {
content: string;

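The CodeBlock above leans on a pre-loaded shiki highlighter and tokenizes synchronously per render. A minimal standalone sketch of that pattern follows; note the diff's "one-dark" looks like a custom-registered theme, so this sketch substitutes the bundled shiki themes "one-light" and "one-dark-pro", and the helper names are illustrative only:

```ts
// Sketch: create the highlighter once (async), then tokenize synchronously.
import { createHighlighter, type Highlighter } from "shiki";

let highlighter: Highlighter | null = null;

export async function initHighlighter(): Promise<void> {
  highlighter = await createHighlighter({
    themes: ["one-light", "one-dark-pro"], // bundled themes; the app registers its own
    langs: ["typescript", "python"],       // preload only what you render
  });
}

export function tokenize(code: string, lang: string) {
  if (!highlighter) return null; // fall back to plain text until loaded
  try {
    return {
      light: highlighter.codeToTokensBase(code, { lang: lang as any, theme: "one-light" }),
      dark: highlighter.codeToTokensBase(code, { lang: lang as any, theme: "one-dark-pro" }),
    };
  } catch {
    return null; // unknown language: render unhighlighted
  }
}
```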
@@ -73,8 +73,9 @@ export default function Thinking({
// Calculate max height for smooth animations
const getMaxHeight = () => {
if (isCollapsed) {
return finishedThinking ? "0px" : "12rem"; // 8rem = 128px (same as max-h-32)
return finishedThinking ? "0px" : "12rem";
}
// When expanded, use the content height or grow naturally
return contentHeight ? `${contentHeight}px` : "none";
};

@@ -131,10 +132,11 @@
</div>
<div
ref={wrapperRef}
className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md overflow-hidden
transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2`}
className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md
transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2
${isCollapsed ? "overflow-hidden" : "overflow-y-auto"}`}
style={{
maxHeight: getMaxHeight(),
maxHeight: isCollapsed ? getMaxHeight() : undefined,
opacity: isCollapsed && finishedThinking ? 0 : 1,
}}
>

@@ -16,793 +16,6 @@
|
||||
--text-color: #ffffff;
|
||||
}
|
||||
}
|
||||
@media (prefers-color-scheme: light) {
|
||||
.prose {
|
||||
/**
|
||||
* One Light theme for prism.js
|
||||
* Based on Atom's One Light theme: https://github.com/atom/atom/tree/master/packages/one-light-syntax
|
||||
*/
|
||||
|
||||
/**
|
||||
* One Light colours (accurate as of commit eb064bf on 19 Feb 2021)
|
||||
* From colors.less
|
||||
* --mono-1: hsl(230, 8%, 24%);
|
||||
* --mono-2: hsl(230, 6%, 44%);
|
||||
* --mono-3: hsl(230, 4%, 64%)
|
||||
* --hue-1: hsl(198, 99%, 37%);
|
||||
* --hue-2: hsl(221, 87%, 60%);
|
||||
* --hue-3: hsl(301, 63%, 40%);
|
||||
* --hue-4: hsl(119, 34%, 47%);
|
||||
* --hue-5: hsl(5, 74%, 59%);
|
||||
* --hue-5-2: hsl(344, 84%, 43%);
|
||||
* --hue-6: hsl(35, 99%, 36%);
|
||||
* --hue-6-2: hsl(35, 99%, 40%);
|
||||
* --syntax-fg: hsl(230, 8%, 24%);
|
||||
* --syntax-bg: hsl(230, 1%, 98%);
|
||||
* --syntax-gutter: hsl(230, 1%, 62%);
|
||||
* --syntax-guide: hsla(230, 8%, 24%, 0.2);
|
||||
* --syntax-accent: hsl(230, 100%, 66%);
|
||||
* From syntax-variables.less
|
||||
* --syntax-selection-color: hsl(230, 1%, 90%);
|
||||
* --syntax-gutter-background-color-selected: hsl(230, 1%, 90%);
|
||||
* --syntax-cursor-line: hsla(230, 8%, 24%, 0.05);
|
||||
*/
|
||||
|
||||
.token.comment,
|
||||
.token.prolog,
|
||||
.token.cdata {
|
||||
color: hsl(230, 4%, 64%);
|
||||
}
|
||||
|
||||
.token.doctype,
|
||||
.token.punctuation,
|
||||
.token.entity {
|
||||
color: hsl(230, 8%, 24%);
|
||||
}
|
||||
|
||||
.token.attr-name,
|
||||
.token.class-name,
|
||||
.token.boolean,
|
||||
.token.constant,
|
||||
.token.number,
|
||||
.token.atrule {
|
||||
color: hsl(35, 99%, 36%);
|
||||
}
|
||||
|
||||
.token.keyword {
|
||||
color: hsl(301, 63%, 40%);
|
||||
}
|
||||
|
||||
.token.property,
|
||||
.token.tag,
|
||||
.token.symbol,
|
||||
.token.deleted,
|
||||
.token.important {
|
||||
color: hsl(5, 74%, 59%);
|
||||
}
|
||||
|
||||
.token.selector,
|
||||
.token.string,
|
||||
.token.char,
|
||||
.token.builtin,
|
||||
.token.inserted,
|
||||
.token.regex,
|
||||
.token.attr-value,
|
||||
.token.attr-value > .token.punctuation {
|
||||
color: hsl(119, 34%, 47%);
|
||||
}
|
||||
|
||||
.token.variable,
|
||||
.token.operator,
|
||||
.token.function {
|
||||
color: hsl(221, 87%, 60%);
|
||||
}
|
||||
|
||||
.token.url {
|
||||
color: hsl(198, 99%, 37%);
|
||||
}
|
||||
|
||||
/* HTML overrides */
|
||||
.token.attr-value > .token.punctuation.attr-equals,
|
||||
.token.special-attr > .token.attr-value > .token.value.css {
|
||||
color: hsl(230, 8%, 24%);
|
||||
}
|
||||
|
||||
/* CSS overrides */
|
||||
.language-css .token.selector {
|
||||
color: hsl(5, 74%, 59%);
|
||||
}
|
||||
|
||||
.language-css .token.property {
|
||||
color: hsl(230, 8%, 24%);
|
||||
}
|
||||
|
||||
.language-css .token.function,
|
||||
.language-css .token.url > .token.function {
|
||||
color: hsl(198, 99%, 37%);
|
||||
}
|
||||
|
||||
.language-css .token.url > .token.string.url {
|
||||
color: hsl(119, 34%, 47%);
|
||||
}
|
||||
|
||||
.language-css .token.important,
|
||||
.language-css .token.atrule .token.rule {
|
||||
color: hsl(301, 63%, 40%);
|
||||
}
|
||||
|
||||
/* JS overrides */
|
||||
.language-javascript .token.operator {
|
||||
color: hsl(301, 63%, 40%);
|
||||
}
|
||||
|
||||
.language-javascript
|
||||
.token.template-string
|
||||
> .token.interpolation
|
||||
> .token.interpolation-punctuation.punctuation {
|
||||
color: hsl(344, 84%, 43%);
|
||||
}
|
||||
|
||||
/* JSON overrides */
|
||||
.language-json .token.operator {
|
||||
color: hsl(230, 8%, 24%);
|
||||
}
|
||||
|
||||
.language-json .token.null.keyword {
|
||||
color: hsl(35, 99%, 36%);
|
||||
}
|
||||
|
||||
/* MD overrides */
|
||||
.language-markdown .token.url,
|
||||
.language-markdown .token.url > .token.operator,
|
||||
.language-markdown .token.url-reference.url > .token.string {
|
||||
color: hsl(230, 8%, 24%);
|
||||
}
|
||||
|
||||
.language-markdown .token.url > .token.content {
|
||||
color: hsl(221, 87%, 60%);
|
||||
}
|
||||
|
||||
.language-markdown .token.url > .token.url,
|
||||
.language-markdown .token.url-reference.url {
|
||||
color: hsl(198, 99%, 37%);
|
||||
}
|
||||
|
||||
.language-markdown .token.blockquote.punctuation,
|
||||
.language-markdown .token.hr.punctuation {
|
||||
color: hsl(230, 4%, 64%);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.language-markdown .token.code-snippet {
|
||||
color: hsl(119, 34%, 47%);
|
||||
}
|
||||
|
||||
.language-markdown .token.bold .token.content {
|
||||
color: hsl(35, 99%, 36%);
|
||||
}
|
||||
|
||||
.language-markdown .token.italic .token.content {
|
||||
color: hsl(301, 63%, 40%);
|
||||
}
|
||||
|
||||
.language-markdown .token.strike .token.content,
|
||||
.language-markdown .token.strike .token.punctuation,
|
||||
.language-markdown .token.list.punctuation,
|
||||
.language-markdown .token.title.important > .token.punctuation {
|
||||
color: hsl(5, 74%, 59%);
|
||||
}
|
||||
|
||||
/* General */
|
||||
.token.bold {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.token.comment,
|
||||
.token.italic {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.token.entity {
|
||||
cursor: help;
|
||||
}
|
||||
|
||||
.token.namespace {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
/* Plugin overrides */
|
||||
/* Selectors should have higher specificity than those in the plugins' default stylesheets */
|
||||
|
||||
/* Show Invisibles plugin overrides */
|
||||
.token.token.tab:not(:empty):before,
|
||||
.token.token.cr:before,
|
||||
.token.token.lf:before,
|
||||
.token.token.space:before {
|
||||
color: hsla(230, 8%, 24%, 0.2);
|
||||
}
|
||||
|
||||
/* Toolbar plugin overrides */
|
||||
/* Space out all buttons and move them away from the right edge of the code block */
|
||||
div.code-toolbar > .toolbar.toolbar > .toolbar-item {
|
||||
margin-right: 0.4em;
|
||||
}
|
||||
|
||||
/* Styling the buttons */
|
||||
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button,
|
||||
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a,
|
||||
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span {
|
||||
background: hsl(230, 1%, 90%);
|
||||
color: hsl(230, 6%, 44%);
|
||||
padding: 0.1em 0.4em;
|
||||
border-radius: 0.3em;
|
||||
}
|
||||
|
||||
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:hover,
|
||||
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:focus,
|
||||
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:hover,
|
||||
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:focus,
|
||||
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:hover,
|
||||
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:focus {
|
||||
background: hsl(230, 1%, 78%); /* custom: darken(--syntax-bg, 20%) */
|
||||
color: hsl(230, 8%, 24%);
|
||||
}
|
||||
|
||||
/* Line Highlight plugin overrides */
|
||||
/* The highlighted line itself */
|
||||
.line-highlight.line-highlight {
|
||||
background: hsla(230, 8%, 24%, 0.05);
|
||||
}
|
||||
|
||||
/* Default line numbers in Line Highlight plugin */
|
||||
.line-highlight.line-highlight:before,
|
||||
.line-highlight.line-highlight[data-end]:after {
|
||||
background: hsl(230, 1%, 90%);
|
||||
color: hsl(230, 8%, 24%);
|
||||
padding: 0.1em 0.6em;
|
||||
border-radius: 0.3em;
|
||||
box-shadow: 0 2px 0 0 rgba(0, 0, 0, 0.2); /* same as Toolbar plugin default */
|
||||
}
|
||||
|
||||
/* Hovering over a linkable line number (in the gutter area) */
|
||||
/* Requires Line Numbers plugin as well */
|
||||
pre[id].linkable-line-numbers.linkable-line-numbers
|
||||
span.line-numbers-rows
|
||||
> span:hover:before {
|
||||
background-color: hsla(230, 8%, 24%, 0.05);
|
||||
}
|
||||
|
||||
/* Line Numbers and Command Line plugins overrides */
|
||||
/* Line separating gutter from coding area */
|
||||
.line-numbers.line-numbers .line-numbers-rows,
|
||||
.command-line .command-line-prompt {
|
||||
border-right-color: hsla(230, 8%, 24%, 0.2);
|
||||
}
|
||||
|
||||
/* Stuff in the gutter */
.line-numbers .line-numbers-rows > span:before,
.command-line .command-line-prompt > span:before {
  color: hsl(230, 1%, 62%);
}

/* Match Braces plugin overrides */
/* Note: Outline colour is inherited from the braces */
.rainbow-braces .token.token.punctuation.brace-level-1,
.rainbow-braces .token.token.punctuation.brace-level-5,
.rainbow-braces .token.token.punctuation.brace-level-9 {
  color: hsl(5, 74%, 59%);
}

.rainbow-braces .token.token.punctuation.brace-level-2,
.rainbow-braces .token.token.punctuation.brace-level-6,
.rainbow-braces .token.token.punctuation.brace-level-10 {
  color: hsl(119, 34%, 47%);
}

.rainbow-braces .token.token.punctuation.brace-level-3,
.rainbow-braces .token.token.punctuation.brace-level-7,
.rainbow-braces .token.token.punctuation.brace-level-11 {
  color: hsl(221, 87%, 60%);
}

.rainbow-braces .token.token.punctuation.brace-level-4,
.rainbow-braces .token.token.punctuation.brace-level-8,
.rainbow-braces .token.token.punctuation.brace-level-12 {
  color: hsl(301, 63%, 40%);
}

/* Diff Highlight plugin overrides */
/* Taken from https://github.com/atom/github/blob/master/styles/variables.less */
pre.diff-highlight > code .token.token.deleted:not(.prefix),
pre > code.diff-highlight .token.token.deleted:not(.prefix) {
  background-color: hsla(353, 100%, 66%, 0.15);
}

pre.diff-highlight > code .token.token.deleted:not(.prefix)::-moz-selection,
pre.diff-highlight > code .token.token.deleted:not(.prefix) *::-moz-selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::-moz-selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix) *::-moz-selection {
  background-color: hsla(353, 95%, 66%, 0.25);
}

pre.diff-highlight > code .token.token.deleted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.deleted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix) *::selection {
  background-color: hsla(353, 95%, 66%, 0.25);
}

pre.diff-highlight > code .token.token.inserted:not(.prefix),
pre > code.diff-highlight .token.token.inserted:not(.prefix) {
  background-color: hsla(137, 100%, 55%, 0.15);
}

pre.diff-highlight > code .token.token.inserted:not(.prefix)::-moz-selection,
pre.diff-highlight > code .token.token.inserted:not(.prefix) *::-moz-selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix)::-moz-selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix) *::-moz-selection {
  background-color: hsla(135, 73%, 55%, 0.25);
}

pre.diff-highlight > code .token.token.inserted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.inserted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix) *::selection {
  background-color: hsla(135, 73%, 55%, 0.25);
}

/* Previewers plugin overrides */
/* Based on https://github.com/atom-community/atom-ide-datatip/blob/master/styles/atom-ide-datatips.less and https://github.com/atom/atom/blob/master/packages/one-light-ui */
/* Border around popup */
.prism-previewer.prism-previewer:before,
.prism-previewer-gradient.prism-previewer-gradient div {
  border-color: hsl(0, 0%, 95%);
}

/* Angle and time should remain as circles and are hence not included */
.prism-previewer-color.prism-previewer-color:before,
.prism-previewer-gradient.prism-previewer-gradient div,
.prism-previewer-easing.prism-previewer-easing:before {
  border-radius: 0.3em;
}

/* Triangles pointing to the code */
.prism-previewer.prism-previewer:after {
  border-top-color: hsl(0, 0%, 95%);
}

.prism-previewer-flipped.prism-previewer-flipped:after {
  border-bottom-color: hsl(0, 0%, 95%);
}

/* Background colour within the popup */
.prism-previewer-angle.prism-previewer-angle:before,
.prism-previewer-time.prism-previewer-time:before,
.prism-previewer-easing.prism-previewer-easing {
  background: hsl(0, 0%, 100%);
}

/* For angle, this is the positive area (eg. 90deg will display one quadrant in this colour) */
/* For time, this is the alternate colour */
.prism-previewer-angle.prism-previewer-angle circle,
.prism-previewer-time.prism-previewer-time circle {
  stroke: hsl(230, 8%, 24%);
  stroke-opacity: 1;
}

/* Stroke colours of the handle, direction point, and vector itself */
.prism-previewer-easing.prism-previewer-easing circle,
.prism-previewer-easing.prism-previewer-easing path,
.prism-previewer-easing.prism-previewer-easing line {
  stroke: hsl(230, 8%, 24%);
}

/* Fill colour of the handle */
.prism-previewer-easing.prism-previewer-easing circle {
  fill: transparent;
}
}
}

@media (prefers-color-scheme: dark) {
  .prose {
    .token.comment,
    .token.prolog,
    .token.cdata {
      color: hsl(220, 10%, 40%);
    }

    .token.doctype,
    .token.punctuation,
    .token.entity {
      color: hsl(220, 14%, 71%);
    }

    .token.attr-name,
    .token.class-name,
    .token.boolean,
    .token.constant,
    .token.number,
    .token.atrule {
      color: hsl(29, 54%, 61%);
    }

    .token.keyword {
      color: hsl(286, 60%, 67%);
    }

    .token.property,
    .token.tag,
    .token.symbol,
    .token.deleted,
    .token.important {
      color: hsl(355, 65%, 65%);
    }

    .token.selector,
    .token.string,
    .token.char,
    .token.builtin,
    .token.inserted,
    .token.regex,
    .token.attr-value,
    .token.attr-value > .token.punctuation {
      color: hsl(95, 38%, 62%);
    }

    .token.variable,
    .token.operator,
    .token.function {
      color: hsl(207, 82%, 66%);
    }

    .token.url {
      color: hsl(187, 47%, 55%);
    }

    /* HTML overrides */
    .token.attr-value > .token.punctuation.attr-equals,
    .token.special-attr > .token.attr-value > .token.value.css {
      color: hsl(220, 14%, 71%);
    }

    /* CSS overrides */
    .language-css .token.selector {
      color: hsl(355, 65%, 65%);
    }

    .language-css .token.property {
      color: hsl(220, 14%, 71%);
    }

    .language-css .token.function,
    .language-css .token.url > .token.function {
      color: hsl(187, 47%, 55%);
    }

    .language-css .token.url > .token.string.url {
      color: hsl(95, 38%, 62%);
    }

    .language-css .token.important,
    .language-css .token.atrule .token.rule {
      color: hsl(286, 60%, 67%);
    }

    /* JS overrides */
    .language-javascript .token.operator {
      color: hsl(286, 60%, 67%);
    }

    .language-javascript
      .token.template-string
      > .token.interpolation
      > .token.interpolation-punctuation.punctuation {
      color: hsl(5, 48%, 51%);
    }

    /* JSON overrides */
    .language-json .token.operator {
      color: hsl(220, 14%, 71%);
    }

    .language-json .token.null.keyword {
      color: hsl(29, 54%, 61%);
    }

    /* MD overrides */
    .language-markdown .token.url,
    .language-markdown .token.url > .token.operator,
    .language-markdown .token.url-reference.url > .token.string {
      color: hsl(220, 14%, 71%);
    }

    .language-markdown .token.url > .token.content {
      color: hsl(207, 82%, 66%);
    }

    .language-markdown .token.url > .token.url,
    .language-markdown .token.url-reference.url {
      color: hsl(187, 47%, 55%);
    }

    .language-markdown .token.blockquote.punctuation,
    .language-markdown .token.hr.punctuation {
      color: hsl(220, 10%, 40%);
      font-style: italic;
    }

    .language-markdown .token.code-snippet {
      color: hsl(95, 38%, 62%);
    }

    .language-markdown .token.bold .token.content {
      color: hsl(29, 54%, 61%);
    }

    .language-markdown .token.italic .token.content {
      color: hsl(286, 60%, 67%);
    }

    .language-markdown .token.strike .token.content,
    .language-markdown .token.strike .token.punctuation,
    .language-markdown .token.list.punctuation,
    .language-markdown .token.title.important > .token.punctuation {
      color: hsl(355, 65%, 65%);
    }

    /* General */
    .token.bold {
      font-weight: bold;
    }

    .token.comment,
    .token.italic {
      font-style: italic;
    }

    .token.entity {
      cursor: help;
    }

    .token.namespace {
      opacity: 0.8;
    }

    /* Plugin overrides */
    /* Selectors should have higher specificity than those in the plugins' default stylesheets */

    /* Show Invisibles plugin overrides */
    .token.token.tab:not(:empty):before,
    .token.token.cr:before,
    .token.token.lf:before,
    .token.token.space:before {
      color: hsla(220, 14%, 71%, 0.15);
      text-shadow: none;
    }

    /* Toolbar plugin overrides */
    /* Space out all buttons and move them away from the right edge of the code block */
    div.code-toolbar > .toolbar.toolbar > .toolbar-item {
      margin-right: 0.4em;
    }

    /* Styling the buttons */
    div.code-toolbar > .toolbar.toolbar > .toolbar-item > button,
    div.code-toolbar > .toolbar.toolbar > .toolbar-item > a,
    div.code-toolbar > .toolbar.toolbar > .toolbar-item > span {
      background: hsl(220, 13%, 26%);
      color: hsl(220, 9%, 55%);
      padding: 0.1em 0.4em;
      border-radius: 0.3em;
    }

    div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:hover,
    div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:focus,
    div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:hover,
    div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:focus,
    div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:hover,
    div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:focus {
      background: hsl(220, 13%, 28%);
      color: hsl(220, 14%, 71%);
    }

    /* Line Highlight plugin overrides */
    /* The highlighted line itself */
    .line-highlight.line-highlight {
      background: hsla(220, 100%, 80%, 0.04);
    }

    /* Default line numbers in Line Highlight plugin */
    .line-highlight.line-highlight:before,
    .line-highlight.line-highlight[data-end]:after {
      background: hsl(220, 13%, 26%);
      color: hsl(220, 14%, 71%);
      padding: 0.1em 0.6em;
      border-radius: 0.3em;
      box-shadow: 0 2px 0 0 rgba(0, 0, 0, 0.2); /* same as Toolbar plugin default */
    }

    /* Hovering over a linkable line number (in the gutter area) */
    /* Requires Line Numbers plugin as well */
    pre[id].linkable-line-numbers.linkable-line-numbers
      span.line-numbers-rows
      > span:hover:before {
      background-color: hsla(220, 100%, 80%, 0.04);
    }

    /* Line Numbers and Command Line plugins overrides */
    /* Line separating gutter from coding area */
    .line-numbers.line-numbers .line-numbers-rows,
    .command-line .command-line-prompt {
      border-right-color: hsla(220, 14%, 71%, 0.15);
    }

    /* Stuff in the gutter */
    .line-numbers .line-numbers-rows > span:before,
    .command-line .command-line-prompt > span:before {
      color: hsl(220, 14%, 45%);
    }

    /* Match Braces plugin overrides */
    /* Note: Outline colour is inherited from the braces */
    .rainbow-braces .token.token.punctuation.brace-level-1,
    .rainbow-braces .token.token.punctuation.brace-level-5,
    .rainbow-braces .token.token.punctuation.brace-level-9 {
      color: hsl(355, 65%, 65%);
    }

    .rainbow-braces .token.token.punctuation.brace-level-2,
    .rainbow-braces .token.token.punctuation.brace-level-6,
    .rainbow-braces .token.token.punctuation.brace-level-10 {
      color: hsl(95, 38%, 62%);
    }

    .rainbow-braces .token.token.punctuation.brace-level-3,
    .rainbow-braces .token.token.punctuation.brace-level-7,
    .rainbow-braces .token.token.punctuation.brace-level-11 {
      color: hsl(207, 82%, 66%);
    }

    .rainbow-braces .token.token.punctuation.brace-level-4,
    .rainbow-braces .token.token.punctuation.brace-level-8,
    .rainbow-braces .token.token.punctuation.brace-level-12 {
      color: hsl(286, 60%, 67%);
    }

    /* Diff Highlight plugin overrides */
    /* Taken from https://github.com/atom/github/blob/master/styles/variables.less */
    pre.diff-highlight > code .token.token.deleted:not(.prefix),
    pre > code.diff-highlight .token.token.deleted:not(.prefix) {
      background-color: hsla(353, 100%, 66%, 0.15);
    }

    pre.diff-highlight > code .token.token.deleted:not(.prefix)::-moz-selection,
    pre.diff-highlight > code .token.token.deleted:not(.prefix) *::-moz-selection,
    pre > code.diff-highlight .token.token.deleted:not(.prefix)::-moz-selection,
    pre > code.diff-highlight .token.token.deleted:not(.prefix) *::-moz-selection {
      background-color: hsla(353, 95%, 66%, 0.25);
    }

    pre.diff-highlight > code .token.token.deleted:not(.prefix)::selection,
    pre.diff-highlight > code .token.token.deleted:not(.prefix) *::selection,
    pre > code.diff-highlight .token.token.deleted:not(.prefix)::selection,
    pre > code.diff-highlight .token.token.deleted:not(.prefix) *::selection {
      background-color: hsla(353, 95%, 66%, 0.25);
    }

    pre.diff-highlight > code .token.token.inserted:not(.prefix),
    pre > code.diff-highlight .token.token.inserted:not(.prefix) {
      background-color: hsla(137, 100%, 55%, 0.15);
    }

    pre.diff-highlight > code .token.token.inserted:not(.prefix)::-moz-selection,
    pre.diff-highlight > code .token.token.inserted:not(.prefix) *::-moz-selection,
    pre > code.diff-highlight .token.token.inserted:not(.prefix)::-moz-selection,
    pre > code.diff-highlight .token.token.inserted:not(.prefix) *::-moz-selection {
      background-color: hsla(135, 73%, 55%, 0.25);
    }

    pre.diff-highlight > code .token.token.inserted:not(.prefix)::selection,
    pre.diff-highlight > code .token.token.inserted:not(.prefix) *::selection,
    pre > code.diff-highlight .token.token.inserted:not(.prefix)::selection,
    pre > code.diff-highlight .token.token.inserted:not(.prefix) *::selection {
      background-color: hsla(135, 73%, 55%, 0.25);
    }

    /* Previewers plugin overrides */
    /* Based on https://github.com/atom-community/atom-ide-datatip/blob/master/styles/atom-ide-datatips.less and https://github.com/atom/atom/blob/master/packages/one-dark-ui */
    /* Border around popup */
    .prism-previewer.prism-previewer:before,
    .prism-previewer-gradient.prism-previewer-gradient div {
      border-color: hsl(224, 13%, 17%);
    }

    /* Angle and time should remain as circles and are hence not included */
    .prism-previewer-color.prism-previewer-color:before,
    .prism-previewer-gradient.prism-previewer-gradient div,
    .prism-previewer-easing.prism-previewer-easing:before {
      border-radius: 0.3em;
    }

    /* Triangles pointing to the code */
    .prism-previewer.prism-previewer:after {
      border-top-color: hsl(224, 13%, 17%);
    }

    .prism-previewer-flipped.prism-previewer-flipped:after {
      border-bottom-color: hsl(224, 13%, 17%);
    }

    /* Background colour within the popup */
    .prism-previewer-angle.prism-previewer-angle:before,
    .prism-previewer-time.prism-previewer-time:before,
    .prism-previewer-easing.prism-previewer-easing {
      background: hsl(219, 13%, 22%);
    }

    /* For angle, this is the positive area (eg. 90deg will display one quadrant in this colour) */
    /* For time, this is the alternate colour */
    .prism-previewer-angle.prism-previewer-angle circle,
    .prism-previewer-time.prism-previewer-time circle {
      stroke: hsl(220, 14%, 71%);
      stroke-opacity: 1;
    }

    /* Stroke colours of the handle, direction point, and vector itself */
    .prism-previewer-easing.prism-previewer-easing circle,
    .prism-previewer-easing.prism-previewer-easing path,
    .prism-previewer-easing.prism-previewer-easing line {
      stroke: hsl(220, 14%, 71%);
    }

    /* Fill colour of the handle */
    .prism-previewer-easing.prism-previewer-easing circle {
      fill: transparent;
    }
  }
}

.prose pre {
  contain: layout style;
}

/* Or more aggressively */
.prose pre code {
  contain: layout style paint;
}

/* messaging-style typing indicator animation */
@keyframes typing {

156
app/ui/app/src/lib/highlighter.ts
Normal file
@@ -0,0 +1,156 @@
import { createHighlighter } from "shiki";
import type { ThemeRegistration } from "shiki";

const oneLightTheme: ThemeRegistration = {
  name: "one-light",
  type: "light",
  colors: {
    "editor.background": "#fafafa",
    "editor.foreground": "#383a42",
  },
  tokenColors: [
    {
      scope: ["comment", "punctuation.definition.comment"],
      settings: { foreground: "#a0a1a7" },
    },
    {
      scope: ["keyword", "storage.type", "storage.modifier"],
      settings: { foreground: "#a626a4" },
    },
    { scope: ["string", "string.quoted"], settings: { foreground: "#50a14f" } },
    {
      scope: ["function", "entity.name.function", "support.function"],
      settings: { foreground: "#4078f2" },
    },
    {
      scope: [
        "constant.numeric",
        "constant.language",
        "constant.character",
        "number",
      ],
      settings: { foreground: "#c18401" },
    },
    {
      scope: ["variable", "support.variable"],
      settings: { foreground: "#e45649" },
    },
    {
      scope: ["entity.name.tag", "entity.name.type", "entity.name.class"],
      settings: { foreground: "#e45649" },
    },
    {
      scope: ["entity.other.attribute-name"],
      settings: { foreground: "#c18401" },
    },
    {
      scope: ["keyword.operator", "operator"],
      settings: { foreground: "#a626a4" },
    },
    { scope: ["punctuation"], settings: { foreground: "#383a42" } },
    {
      scope: ["markup.heading"],
      settings: { foreground: "#e45649", fontStyle: "bold" },
    },
    {
      scope: ["markup.bold"],
      settings: { foreground: "#c18401", fontStyle: "bold" },
    },
    {
      scope: ["markup.italic"],
      settings: { foreground: "#a626a4", fontStyle: "italic" },
    },
  ],
};

const oneDarkTheme: ThemeRegistration = {
  name: "one-dark",
  type: "dark",
  colors: {
    "editor.background": "#282c34",
    "editor.foreground": "#abb2bf",
  },
  tokenColors: [
    {
      scope: ["comment", "punctuation.definition.comment"],
      settings: { foreground: "#5c6370" },
    },
    {
      scope: ["keyword", "storage.type", "storage.modifier"],
      settings: { foreground: "#c678dd" },
    },
    { scope: ["string", "string.quoted"], settings: { foreground: "#98c379" } },
    {
      scope: ["function", "entity.name.function", "support.function"],
      settings: { foreground: "#61afef" },
    },
    {
      scope: [
        "constant.numeric",
        "constant.language",
        "constant.character",
        "number",
      ],
      settings: { foreground: "#d19a66" },
    },
    {
      scope: ["variable", "support.variable"],
      settings: { foreground: "#e06c75" },
    },
    {
      scope: ["entity.name.tag", "entity.name.type", "entity.name.class"],
      settings: { foreground: "#e06c75" },
    },
    {
      scope: ["entity.other.attribute-name"],
      settings: { foreground: "#d19a66" },
    },
    {
      scope: ["keyword.operator", "operator"],
      settings: { foreground: "#c678dd" },
    },
    { scope: ["punctuation"], settings: { foreground: "#abb2bf" } },
    {
      scope: ["markup.heading"],
      settings: { foreground: "#e06c75", fontStyle: "bold" },
    },
    {
      scope: ["markup.bold"],
      settings: { foreground: "#d19a66", fontStyle: "bold" },
    },
    {
      scope: ["markup.italic"],
      settings: { foreground: "#c678dd", fontStyle: "italic" },
    },
  ],
};

export let highlighter: Awaited<ReturnType<typeof createHighlighter>> | null =
  null;

export const highlighterPromise = createHighlighter({
  themes: [oneLightTheme, oneDarkTheme],
  langs: [
    "javascript",
    "typescript",
    "python",
    "bash",
    "shell",
    "json",
    "html",
    "css",
    "tsx",
    "jsx",
    "go",
    "rust",
    "java",
    "c",
    "cpp",
    "sql",
    "yaml",
    "markdown",
  ],
}).then((h) => {
  highlighter = h;
  return h;
});
@@ -1,24 +0,0 @@
import { remark } from "remark";
import remarkStringify from "remark-stringify";
import remarkStreamingMarkdown from "./remarkStreamingMarkdown";

/**
 * Process markdown content for streaming display using the remark plugin.
 * This is primarily used for testing the remark plugin with string inputs/outputs.
 */
export function processStreamingMarkdown(content: string): string {
  if (!content) return content;

  const result = remark()
    .use(remarkStreamingMarkdown, { debug: false })
    .use(remarkStringify)
    .processSync(content);

  // remove trailing newline to keep tests cleaner
  let output = result.toString();
  if (output.endsWith("\n")) {
    output = output.slice(0, -1);
  }

  return output;
}
@@ -1,447 +0,0 @@
import { parents, type Proxy } from "unist-util-parents";
import type { Plugin } from "unified";
import type {
  Emphasis,
  Node,
  Parent,
  Root,
  RootContent,
  Text,
  Strong,
  PhrasingContent,
  Paragraph,
} from "mdast";
import { u } from "unist-builder";

declare module "unist" {
  interface Node {
    /** Added by `unist-util-parents` (or your own walk). */
    parent?: Proxy & Parent;
  }
}

// interface SimpleTextRule {
//   pattern: RegExp;
//   transform: (matches: RegExpExecArray[], lastNode: Proxy) => void;
// }

// const simpleTextRules: SimpleTextRule[] = [
//   // TODO(drifkin): generalize this for `__`/`_`/`~~`/`~` etc.
//   {
//     pattern: /(\*\*)(?=\S|$)/g,
//     transform: (matchesIterator, lastNode) => {
//       const textNode = lastNode.node as Text;

//       const matches = [...matchesIterator];
//       const lastMatch = matches[matches.length - 1];
//       const origValue = textNode.value;
//       const start = lastMatch.index;
//       const sep = lastMatch[1];

//       const before = origValue.slice(0, start);
//       const after = origValue.slice(start + sep.length);

//       if (lastNode.parent) {
//         const index = (lastNode.parent.node as Parent).children.indexOf(
//           lastNode.node as RootContent,
//         );
//         const shouldRemove = before.length === 0;
//         if (!shouldRemove) {
//           textNode.value = before;
//         }

//         const newNode = u("strong", {
//           children: [u("text", { value: after })],
//         });
//         (lastNode.parent.node as Parent).children.splice(
//           index + (shouldRemove ? 0 : 1),
//           shouldRemove ? 1 : 0,
//           newNode,
//         );
//       }
//     },
//   },
// ];

interface Options {
  debug?: boolean;
  onLastNode?: (info: LastNodeInfo) => void;
}

export interface LastNodeInfo {
  path: string[];
  type: string;
  value?: string;
  lastChars?: string;
  fullNode: Node;
}

/**
 * Removes `child` from `parent` in-place.
 * @returns `true` if the child was found and removed; `false` otherwise.
 */
export function removeChildFromParent(
  child: RootContent,
  parent: Node,
): boolean {
  if (!isParent(parent)) return false; // parent isn't a Parent → nothing to do

  const idx = parent.children.indexOf(child);
  if (idx < 0) return false; // not a child → nothing to remove

  parent.children.splice(idx, 1);
  return true; // removal successful
}

/** Narrow a generic `Node` to a `Parent` (i.e. one that really has children). */
function isParent(node: Node): node is Parent {
  // A `Parent` always has a `children` array; make sure it's an array first.
  return Array.isArray((node as Partial<Parent>).children);
}

/**
 * Follow "last-child" pointers until you reach a leaf.
 * Returns the right-most, deepest node in source order.
 */
export function findRightmostDeepestNode(root: Node): Node {
  let current: Node = root;

  // While the current node *is* a Parent and has at least one child…
  while (isParent(current) && current.children.length > 0) {
    const lastIndex = current.children.length - 1;
    current = current.children[lastIndex];
  }

  return current; // Leaf: no further children
}

const remarkStreamingMarkdown: Plugin<[Options?], Root> = () => {
  return (tree) => {
    const treeWithParents = parents(tree);
    const lastNode = findRightmostDeepestNode(treeWithParents) as Proxy;

    const parentNode = lastNode.parent;
    const grandparentNode = parentNode?.parent;

    let ruleMatched = false;

    // handling `* *` -> ``
    //
    // if the last node is part of a <list item (otherwise empty)> ->
    // <list (otherwise empty)> -> <list item (last node, empty)>, then we need to
    // remove everything up to and including the first list item. This happens
    // when we have `* *`, which can become a bolded list item OR a horizontal
    // line
    if (
      lastNode.type === "listItem" &&
      parentNode &&
      grandparentNode &&
      parentNode.type === "list" &&
      grandparentNode.type === "listItem" &&
      parentNode.children.length === 1 &&
      grandparentNode.children.length === 1
    ) {
      ruleMatched = true;
      if (grandparentNode.parent) {
        removeChildFromParent(
          grandparentNode.node as RootContent,
          grandparentNode.parent.node,
        );
      }
      // Handle `*` -> ``:
      //
      // if the last node is just an empty list item, we need to remove it
      // because it could become something else (e.g., a horizontal line)
    } else if (
      lastNode.type === "listItem" &&
      parentNode &&
      parentNode.type === "list"
    ) {
      ruleMatched = true;
      removeChildFromParent(lastNode.node as RootContent, parentNode.node);
    } else if (lastNode.type === "thematicBreak") {
      ruleMatched = true;
      const parent = lastNode.parent;
      if (parent) {
        removeChildFromParent(lastNode.node as RootContent, parent.node);
      }
    } else if (lastNode.type === "text") {
      const textNode = lastNode.node as Text;
      if (textNode.value.endsWith("**")) {
        ruleMatched = true;
        textNode.value = textNode.value.slice(0, -2);
        // if there's a newline then a number, this is very very likely a
        // numbered list item. Let's just hide it until the period comes (or
        // other text disambiguates it)
      } else {
        const match = textNode.value.match(/^([0-9]+)$/m);
        if (match) {
          const number = match[1];
          textNode.value = textNode.value.slice(0, -number.length - 1);
          ruleMatched = true;
          // if the text node is now empty, then we might want to remove other
          // elements, like a now-empty containing paragraph, or a break that
          // might disappear once more tokens come in
          if (textNode.value.length === 0) {
            if (
              lastNode.parent?.type === "paragraph" &&
              lastNode.parent.children.length === 1
            ) {
              // remove the whole paragraph if it's now empty (otherwise it'll
              // cause an extra newline that might not last)
              removeChildFromParent(
                lastNode.parent.node as Paragraph,
                lastNode.parent.parent?.node as Node,
              );
            } else {
              const prev = prevSibling(lastNode);
              if (prev?.type === "break") {
                removeChildFromParent(
                  prev.node as RootContent,
                  lastNode.parent?.node as Node,
                );
                removeChildFromParent(
                  lastNode.node as RootContent,
                  lastNode.parent?.node as Node,
                );
              }
            }
          }
        }
      }
    }

    if (ruleMatched) {
      return tree;
    }

    // we need to
    // a case like
    // - *def `abc` [abc **def**](abc)*
    // is pretty tricky, because if we land just after def, then we actually
    // have two separate tags to process at two different parents. Maybe we
    // need to keep iterating up until we find a paragraph, but process each
    // parent on the way up. Hmm, well actually after `def` we won't even be a proper link yet
    // TODO(drifkin): it's really if the last node's parent is a paragraph, for which the following is a sub-case where the lastNode is a text node.
    // And instead of just processing simple text rules, they need to operate on the whole paragraph
    // like `**[abc](def)` needs to become `**[abc](def)**`

    // if we're just text at the end, then we should remove some ambiguous characters

    if (lastNode.parent) {
      const didChange = processParent(lastNode.parent as Parent & Proxy);
      if (didChange) {
        // TODO(drifkin): need to fix up the tree, but not sure lastNode will still exist? Check all the transforms to see if it's safe to find the last node again
        //
        // need to regen the tree w/ parents since reparenting could've happened
        // treeWithParents = parents(tree);
      }
    }

    const grandparent = lastNode.parent?.parent;
    // TODO(drifkin): let's go arbitrarily high up the tree, but limiting it
    // to 2 levels for now until I think more about the stop condition
    if (grandparent) {
      processParent(grandparent as Parent & Proxy);
    }

    // console.log("ruleMatched", ruleMatched);

    // } else if (lastNode.parent?.type === "paragraph") {
    //   console.log("!!! paragraph");
    //   console.log("lastNode.parent", lastNode.parent);

    //   // Handle `**abc*` -> `**abc**`:
    //   // We detect this when the last child is an emphasis node, and it's preceded by a text node that ends with `*`
    //   const paragraph = lastNode.parent as Proxy & Paragraph;
    //   if (paragraph.children.length >= 2) {
    //     const lastChild = paragraph.children[paragraph.children.length - 1];
    //     if (lastChild.type === "emphasis") {
    //       const sibling = paragraph.children[paragraph.children.length - 2];
    //       if (sibling.type === "text") {
    //         const siblingText = sibling as Text & Proxy;
    //         if (siblingText.value.endsWith("*")) {
    //           ruleMatched = true;
    //           const textNode = (lastNode as Proxy).node as Text;
    //           textNode.value = textNode.value.slice(0, -1);
    //           paragraph.node.type = "strong";
    //         }
    //       }
    //     }
    //   }
    // } else if (lastNode.type === "text") {
    //   // Handle `**abc*` -> `**abc**`:
    //   //
    //   // this gets parsed as a text node ending in `*` followed by an emphasis
    //   // node. So if we're in text, we need to check if our parent is emphasis,
    //   // and then get our parent's sibling before it and check if it ends with
    //   // `*`
    //   const parent = lastNode.parent;
    //   if (parent && parent.type === "emphasis") {
    //     const grandparent = parent.parent;
    //     if (grandparent) {
    //       const index = (grandparent.node as Parent).children.indexOf(
    //         parent.node as RootContent,
    //       );
    //       if (index > 0) {
    //         const prevNode = grandparent.children[index - 1];
    //         if (
    //           prevNode.type === "text" &&
    //           (prevNode as Text).value.endsWith("*")
    //         ) {
    //           ruleMatched = true;
    //           const textNode = (prevNode as Proxy).node as Text;
    //           textNode.value = textNode.value.slice(0, -1);
    //           parent.node.type = "strong";
    //         }
    //       }
    //     }
    //   }

    //   if (!ruleMatched) {
    //     // if the last node is just text, then we process it in order to fix up certain unclosed items
    //     // e.g., `**abc` -> `**abc**`
    //     const textNode = lastNode.node as Text;
    //     for (const rule of simpleTextRules) {
    //       const matchesIterator = textNode.value.matchAll(rule.pattern);
    //       const matches = [...matchesIterator];
    //       if (matches.length > 0) {
    //         rule.transform(matches, lastNode);
    //         ruleMatched = true;
    //         break;
    //       }
    //     }
    //   }
    // } else if (!ruleMatched) {
    //   // console.log("no rule matched", lastNode);
    // }

    return tree;
  };
};

function processParent(parent: Parent & Proxy): boolean {
  if (parent.type === "emphasis") {
    // Handle `**abc*` -> `**abc**`:
    // We detect this when we end with an emphasis node, and it's preceded by
    // a text node that ends with `*`
    // TODO(drifkin): the last node can be more deeply nested (e.g., a code
    // literal in a link), so we probably need to walk up the tree until we
    // find an emphasis node or a block? For now we'll just go up one layer to
    // catch the most common cases
    const emphasisNode = parent as Emphasis & Proxy;
    const grandparent = emphasisNode.parent;
    if (grandparent) {
      const indexOfEmphasisNode = (grandparent.node as Parent).children.indexOf(
        emphasisNode.node as RootContent,
      );
      if (indexOfEmphasisNode >= 0) {
        const nodeBefore = grandparent.children[indexOfEmphasisNode - 1] as
          | (Node & Proxy)
          | undefined;
        if (nodeBefore?.type === "text") {
          const textNode = nodeBefore.node as Text;
          if (textNode.value.endsWith("*")) {
            const strBefore = textNode.value.slice(0, -1);
            textNode.value = strBefore;
            const strongNode = u("strong", {
              children: emphasisNode.children,
            });
            (grandparent.node as Parent).children.splice(
              indexOfEmphasisNode,
              1,
              strongNode,
            );
            return true;
          }
        }
      }
    }
  }

  // Let's check if we have any bold items to close
  for (let i = parent.children.length - 1; i >= 0; i--) {
    const child = parent.children[i];
    if (child.type === "text") {
      const textNode = child as Text & Proxy;
      const sep = "**";
      const index = textNode.value.lastIndexOf(sep);
      if (index >= 0) {
        let isValidOpening = false;
        if (index + sep.length < textNode.value.length) {
          const charAfter = textNode.value[index + sep.length];
          if (!isWhitespace(charAfter)) {
            isValidOpening = true;
          }
        } else {
          if (i < parent.children.length - 1) {
            // TODO(drifkin): I'm not sure that this check is strict enough.
            // We're trying to detect cases like `**[abc]()` where the char
            // after the opening ** is indeed a non-whitespace character. We're
            // using the heuristic that there's another item after the current
            // one, but I'm not sure if that is good enough. In a well
            // constructed tree, there aren't two text nodes in a row, so this
            // _seems_ good, but I should think through it more
            isValidOpening = true;
          }
        }

        if (isValidOpening) {
          // TODO(drifkin): close the bold
          const strBefore = textNode.value.slice(0, index);
          const strAfter = textNode.value.slice(index + sep.length);
          (textNode.node as Text).value = strBefore;
          // TODO(drifkin): the node above could be empty in which case we probably want to delete it
          const children: PhrasingContent[] = [
            ...(strAfter.length > 0 ? [u("text", { value: strAfter })] : []),
          ];
          const strongNode: Strong = u("strong", {
            children,
          });
          const nodesAfter = (parent.node as Parent).children.splice(
            i + 1,
            parent.children.length - i - 1,
            strongNode,
          );
          // TODO(drifkin): this cast seems iffy, should see if we can cast the
          // parent instead, which would also help us check some of our
          // assumptions
          strongNode.children.push(...(nodesAfter as PhrasingContent[]));
          return true;
        }
      }
    }
  }

  return false;
}

function prevSibling(node: Node & Proxy): (Node & Proxy) | null {
  const parent = node.parent;
  if (parent) {
    const index = parent.children.indexOf(node);
    return parent.children[index - 1] as Node & Proxy;
  }
  return null;
}

function isWhitespace(str: string) {
  return str.trim() === "";
}

// function debugPrintTreeNoPos(tree: Node) {
//   console.log(
//     JSON.stringify(
//       tree,
//       (key, value) => {
//         if (key === "position") {
//           return undefined;
//         }
//         return value;
//       },
//       2,
//     ),
//   );
// }

export default remarkStreamingMarkdown;
@@ -1794,13 +1794,14 @@ func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, ava

	var thinkValue *api.ThinkValue
	if think != nil {
		// Only set Think if it's actually requesting thinking
		if boolValue, ok := think.(bool); ok {
			thinkValue = &api.ThinkValue{
				Value: boolValue,
			if boolValue {
				thinkValue = &api.ThinkValue{Value: boolValue}
			}
		} else if stringValue, ok := think.(string); ok {
			thinkValue = &api.ThinkValue{
				Value: stringValue,
			if stringValue != "" && stringValue != "none" {
				thinkValue = &api.ThinkValue{Value: stringValue}
			}
		}
	}

@@ -110,9 +110,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {

	for name, mxfp4 := range mxfp4s {
		dims := mxfp4.blocks.Shape()
		if !strings.HasSuffix(name, ".weight") {
			name = name + ".weight"
		}
		if strings.Contains(name, "ffn_down_exps") {
			out = append(out, &ggml.Tensor{
				Name: name + ".weight",
				Name: name,
				Kind: uint32(ggml.TensorTypeMXFP4),
				Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
				WriterTo: mxfp4,
@@ -121,12 +124,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
			// gate_up_exps is interleaved, need to split into gate_exps and up_exps
			// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
			out = append(out, &ggml.Tensor{
				Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
				Name: strings.Replace(name, "gate_up", "gate", 1),
				Kind: uint32(ggml.TensorTypeMXFP4),
				Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
				WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
			}, &ggml.Tensor{
				Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
				Name: strings.Replace(name, "gate_up", "up", 1),
				Kind: uint32(ggml.TensorTypeMXFP4),
				Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
				WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),

@@ -2,10 +2,12 @@ package convert

import (
	"cmp"
	"errors"
	"io"
	"iter"
	"path"
	"slices"
	"strconv"
	"strings"

	"github.com/pdevine/tensor"
@@ -94,6 +96,26 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
		return matched
	})

	slices.SortStableFunc(matched, func(a, b Tensor) int {
		x := strings.Split(a.Name(), ".")
		y := strings.Split(b.Name(), ".")
		if len(x) != len(y) {
			return cmp.Compare(len(x), len(y))
		}

		vals := make([]int, len(x))
		for i := range x {
			vals[i] = strings.Compare(x[i], y[i])
			m, err := strconv.ParseInt(x[i], 0, 0)
			n, err2 := strconv.ParseInt(y[i], 0, 0)
			if errors.Join(err, err2) == nil {
				vals[i] = cmp.Compare(m, n)
			}
		}

		return cmp.Or(vals...)
	})

	if len(matched) > 0 {
		out = append(out, &ggml.Tensor{
			Name: merges[i].name,

@@ -3,8 +3,10 @@ package convert

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"iter"
	"math/rand/v2"
	"slices"
	"strings"
	"testing"
@@ -951,3 +953,45 @@ func TestMerge(t *testing.T) {
		}
	})
}

func TestMergeOrder(t *testing.T) {
	for range 8 {
		t.Run("", func(t *testing.T) {
			tensors := make([]Tensor, 16)
			for i := range tensors {
				tensors[i] = &fakeTensor{
					name:  fmt.Sprintf("layer.%d.weight", i),
					shape: []uint64{1},
					data:  []float32{float32(i)},
				}
			}

			rand.Shuffle(len(tensors), func(i, j int) {
				tensors[i], tensors[j] = tensors[j], tensors[i]
			})

			matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
			if len(unmatched) != 0 {
				t.Error("expected no remaining tensors, got", len(unmatched))
			}

			if len(matched) != 1 {
				t.Error("expected 1 merged tensor, got", len(matched))
			}

			var b bytes.Buffer
			if _, err := matched[0].WriteTo(&b); err != nil {
				t.Fatal(err)
			}

			var f32s [16]float32
			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
				t.Fatal(err)
			}

			if !slices.IsSorted(f32s[:]) {
				t.Errorf("merged tensor data is not in order: %+v", f32s)
			}
		})
	}
}

@@ -94,6 +94,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
				continue
			} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
				continue
			} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
				slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
				continue
			}
			dirs = []string{ml.LibOllamaPath, dir}
		} else {

@@ -13,9 +13,23 @@ Embeddings turn text into numeric vectors you can store in a vector database, se

## Generate embeddings

Use `/api/embed` with a single string.

<Tabs>
  <Tab title="CLI">
    Generate embeddings directly from the command line:

    ```shell
    ollama run embeddinggemma "Hello world"
    ```

    You can also pipe text to generate embeddings:

    ```shell
    echo "Hello world" | ollama run embeddinggemma
    ```

    Output is a JSON array.
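
    For instance (the numbers below are invented for illustration; only the
    shape matters, and the real vector has one float per embedding dimension):

    ```shell
    echo "Hello world" | ollama run embeddinggemma
    # [0.010103928, -0.0017536593, 0.050749168, ...]
    ```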

  </Tab>
  <Tab title="cURL">
    ```shell
    curl -X POST http://localhost:11434/api/embed \
@@ -68,6 +68,15 @@ To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following c
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
```

## Vulkan Support

Vulkan is bundled into the `ollama/ollama` image.

```shell
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 -e OLLAMA_VULKAN=1 --name ollama ollama/ollama
```
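
If the container starts but the GPU is not being used, the server log is the
first place to check; device discovery is reported at startup. A quick sanity
check, assuming the container is named `ollama` as in the command above:

```shell
docker logs ollama
```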

## Run model locally

Now you can run a model:

@@ -79,3 +88,4 @@ docker exec -it ollama ollama run llama3.2

## Try different models

More models can be found on the [Ollama library](https://ollama.com/library).

@@ -63,6 +63,10 @@
    {
      "source": "/api/openai",
      "destination": "/api/openai-compatibility"
    },
    {
      "source": "/api",
      "destination": "/api/introduction"
    }
  ],
  "navigation": {
@@ -130,7 +134,7 @@
    {
      "group": "API Reference",
      "pages": [
        "/api/index",
        "/api/introduction",
        "/api/authentication",
        "/api/streaming",
        "/api/usage",

44
docs/gpu.mdx
@@ -52,7 +52,11 @@ sudo modprobe nvidia_uvm`

## AMD Radeon

Ollama supports the following AMD GPUs:
Ollama supports the following AMD GPUs via the ROCm library:

> [!NOTE]
> Additional AMD GPU support is provided by the Vulkan library - see below.

### Linux Support

@@ -121,6 +125,42 @@ In some Linux distributions, SELinux can prevent containers from
accessing the AMD GPU devices. On the host system you can run
`sudo setsebool container_use_devices=1` to allow containers to use devices.

### Metal (Apple GPUs)
## Metal (Apple GPUs)

Ollama supports GPU acceleration on Apple devices via the Metal API.

## Vulkan GPU Support

> [!NOTE]
> Vulkan is currently an experimental feature. To enable it, set OLLAMA_VULKAN=1 for the Ollama server as
> described in the [FAQ](faq.md#how-do-i-configure-ollama-server).
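
For example, when starting the server manually from a shell (systemd and other
setups configure environment variables differently; see the FAQ):

```bash
OLLAMA_VULKAN=1 ollama serve
```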

Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows, most GPU vendors' drivers come
bundled with Vulkan support and require no additional setup steps. Most Linux
distributions require installing additional components, and you may have a
choice of Vulkan drivers between Mesa and GPU-vendor-specific packages.

- Linux Intel GPU instructions - https://dgpu-docs.intel.com/driver/client/overview.html
- Linux AMD GPU instructions - https://amdgpu-install.readthedocs.io/en/latest/install-script.html#specifying-a-vulkan-implementation

For AMD GPUs on some Linux distributions, you may need to add the `ollama` user to the `render` group.

The Ollama scheduler leverages available VRAM data reported by the GPU libraries to
make optimal scheduling decisions. Vulkan requires additional capabilities or
running as root to expose this available VRAM data. If neither root access nor this
capability is granted, Ollama will use approximate model sizes to make
best-effort scheduling decisions. To grant the required capability to the
Ollama binary:

```bash
sudo setcap cap_perfmon+ep /usr/local/bin/ollama
```

### GPU Selection

To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan-based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`.

@@ -2,12 +2,15 @@ openapi: 3.1.0
info:
  title: Ollama API
  version: 0.1.0
  license:
    name: MIT
    url: https://opensource.org/licenses/MIT
  description: |
    OpenAPI specification for the Ollama HTTP API

servers:
  - url: http://localhost:11434
    description: Local Ollama instance
    description: Ollama
security: []
components:
  securitySchemes:
    bearerAuth:
@@ -93,8 +96,11 @@ components:
        type: boolean
        default: true
      think:
        type: boolean
        description: When true, returns separate thinking output in addition to content
        oneOf:
          - type: boolean
          - type: string
            enum: [high, medium, low]
        description: When true, returns separate thinking output in addition to content. Can be a boolean (true/false) or a string ("high", "medium", "low") for supported models.
      raw:
        type: boolean
        description: When true, returns the raw response from the model without any prompt templating
@@ -271,8 +277,11 @@ components:
        type: boolean
        default: true
      think:
        type: boolean
        description: When true, returns separate thinking output in addition to content
        oneOf:
          - type: boolean
          - type: string
            enum: [high, medium, low]
        description: When true, returns separate thinking output in addition to content. Can be a boolean (true/false) or a string ("high", "medium", "low") for supported models.
      keep_alive:
        oneOf:
          - type: string
@@ -310,7 +319,6 @@ components:
        type: array
        items:
          type: string
        nullable: true
        description: Optional base64-encoded images in the response
      done:
        type: boolean
@@ -367,7 +375,6 @@ components:
        type: array
        items:
          type: string
        nullable: true
        description: Partial base64-encoded images, when present
      done:
        type: boolean
@@ -543,6 +550,9 @@ components:
      license:
        type: string
        description: The license of the model
      modified_at:
        type: string
        description: Last modified timestamp in ISO 8601 format
      details:
        type: object
        description: High-level model details
@@ -622,6 +632,9 @@ components:
      size_vram:
        type: integer
        description: VRAM usage in bytes
      context_length:
        type: integer
        description: Context length for the running model
  PsResponse:
    type: object
    properties:
@@ -1275,6 +1288,9 @@ paths:
          example:
            source: gemma3
            destination: gemma3-backup
      responses:
        "200":
          description: Model successfully copied
  /api/pull:
    post:
      summary: Pull a model
@@ -1382,16 +1398,7 @@ paths:
            model: gemma3
      responses:
        "200":
          description: Deletion status updates.
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/StatusResponse"
              example:
                status: "success"
            application/x-ndjson:
              schema:
                $ref: "#/components/schemas/StatusEvent"
          description: Model successfully deleted
  /api/version:
    get:
      summary: Get version

@@ -196,8 +196,6 @@ var (
	NoPrune = Bool("OLLAMA_NOPRUNE")
	// SchedSpread allows scheduling models across all GPUs.
	SchedSpread = Bool("OLLAMA_SCHED_SPREAD")
	// IntelGPU enables experimental Intel GPU detection.
	IntelGPU = Bool("OLLAMA_INTEL_GPU")
	// MultiUserCache optimizes prompt caching for multi-user scenarios
	MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
	// Enable the new Ollama engine
@@ -206,6 +204,8 @@ var (
	ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
	// Auth enables authentication between the Ollama client and server
	UseAuth = Bool("OLLAMA_AUTH")
	// Enable Vulkan backend
	EnableVulkan = Bool("OLLAMA_VULKAN")
)

func String(s string) func() string {
@@ -314,7 +314,7 @@ func AsMap() map[string]EnvVar {
		ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"}
		ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"}
		ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
		ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}
		ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
	}

	return ret

@@ -797,73 +797,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
	return
}

func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
	if llm.KV().Uint("vision.block_count") == 0 {
		return
	}

	for name, layer := range llm.Tensors().GroupLayers() {
		if name == "v" || strings.HasPrefix(name, "v.") {
			for _, tensor := range layer {
				weights += tensor.Size()
			}
		}
	}

	imageSize := uint64(llm.KV().Uint("vision.image_size"))
	patchSize := uint64(llm.KV().Uint("vision.patch_size"))
	if patchSize == 0 {
		slog.Warn("unknown patch size for vision model")
		return
	}

	numChannels := uint64(llm.KV().Uint("vision.num_channels"))

	numPatches := (imageSize / patchSize) * (imageSize / patchSize)
	if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
		numPatches++
	}

	headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
	embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))

	switch llm.KV().Architecture() {
	case "mllama":
		numPaddedPatches := numPatches + 8 - (numPatches%8)%8

		maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))

		graphSize = 4 * (8 +
			imageSize*imageSize*numChannels*maxNumTiles +
			embeddingLength*numPatches*maxNumTiles +
			9*embeddingLength*numPaddedPatches*maxNumTiles +
			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
	case "gemma3", "mistral3":
		graphSize = 4 * (imageSize*imageSize*numChannels +
			embeddingLength*patchSize +
			numPatches*numPatches*headCount)
	case "qwen25vl":
		maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))

		numPatches := maxPixels / (patchSize * patchSize)

		graphSize = 4 * (maxPixels*numChannels + // Original image storage
			// Normalized pixels
			maxPixels*numChannels +
			// Patches storage (numPatches * channels * patchSize^2)
			numPatches*numChannels*patchSize*patchSize +
			// Self-attention calculations
			numPatches*numPatches*headCount +
			// Additional buffer for processing
			embeddingLength*numPatches)
	case "llama4":
		// vision graph is computed independently in the same schedule
		// and is negligible compared to the worst case text graph
	}

	return weights, graphSize
}

// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
	if cacheType == "" || cacheType == "f16" {

@@ -14,6 +14,23 @@ import (
	"github.com/ollama/ollama/api"
)

func assertBytesMatchToken(t *testing.T, label, token string, ints []int) {
	t.Helper()

	raw := []byte(token)
	if len(ints) != len(raw) {
		t.Errorf("%s expected %d bytes for token %q, got %d (%v)", label, len(raw), token, len(ints), ints)
		return
	}

	for i, b := range raw {
		if ints[i] != int(b) {
			t.Errorf("%s byte[%d] mismatch for token %q: got %d want %d", label, i, token, ints[i], int(b))
			return
		}
	}
}

func TestAPIGenerate(t *testing.T) {
	initialTimeout := 60 * time.Second
	streamTimeout := 30 * time.Second
@@ -381,3 +398,182 @@ func TestAPIShowModel(t *testing.T) {
		t.Errorf("%s missing modified_at: %#v", modelName, resp)
	}
}

func TestAPIGenerateLogprobs(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	if err := PullIfMissing(ctx, client, smol); err != nil {
		t.Fatalf("pull failed %s", err)
	}

	enableLogprobs := true
	noStream := false

	tests := []struct {
		name        string
		logprobs    *bool
		topLogprobs int
		expectCount int
	}{
		{
			name:        "no_logprobs",
			logprobs:    nil,
			topLogprobs: 0,
			expectCount: 0,
		},
		{
			name:        "logprobs_only",
			logprobs:    &enableLogprobs,
			topLogprobs: 0,
			expectCount: 1,
		},
		{
			name:        "logprobs_with_top_5",
			logprobs:    &enableLogprobs,
			topLogprobs: 5,
			expectCount: 1,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			req := api.GenerateRequest{
				Model:       smol,
				Prompt:      "Why is the sky blue?",
				Stream:      &noStream,
				Logprobs:    test.logprobs != nil && *test.logprobs,
				TopLogprobs: test.topLogprobs,
				Options: map[string]interface{}{
					"temperature": 0,
					"seed":        123,
					"num_predict": 10,
				},
			}

			var response api.GenerateResponse
			err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
				if resp.Done {
					response = resp
				}
				return nil
			})
			if err != nil {
				t.Fatalf("generate failed: %s", err)
			}

			// Check logprobs based on expectation
			if test.expectCount == 0 {
				if len(response.Logprobs) > 0 {
					t.Errorf("expected no logprobs but got %d", len(response.Logprobs))
				}
			} else {
				if len(response.Logprobs) == 0 {
					t.Errorf("expected logprobs but got none")
				}

				// Validate each logprob entry
				for i, lp := range response.Logprobs {
					if lp.Token == "" {
						t.Errorf("logprob[%d] has empty token", i)
					}
					if lp.Logprob > 0 {
						t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob)
					}
					assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d]", i), lp.Token, lp.Bytes)

					// Check top_logprobs if requested
					if test.topLogprobs > 0 {
						if len(lp.TopLogprobs) == 0 {
							t.Errorf("logprob[%d] expected top_logprobs but got none", i)
						}
						if len(lp.TopLogprobs) > test.topLogprobs {
							t.Errorf("logprob[%d] has %d top_logprobs, expected max %d", i, len(lp.TopLogprobs), test.topLogprobs)
						}

						// Verify top_logprobs are sorted by probability (descending)
						for j := 1; j < len(lp.TopLogprobs); j++ {
							if lp.TopLogprobs[j-1].Logprob < lp.TopLogprobs[j].Logprob {
								t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob)
							}
						}
						for j, top := range lp.TopLogprobs {
							assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
						}
					} else if len(lp.TopLogprobs) > 0 {
						t.Errorf("logprob[%d] has top_logprobs but none were requested", i)
					}
				}
			}
		})
	}
}

func TestAPIChatLogprobs(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	if err := PullIfMissing(ctx, client, smol); err != nil {
		t.Fatalf("pull failed %s", err)
	}

	enableLogprobs := true
	noStream := false

	req := api.ChatRequest{
		Model: smol,
		Messages: []api.Message{
			{Role: "user", Content: "Say hello in one word"},
		},
		Stream:      &noStream,
		Logprobs:    enableLogprobs,
		TopLogprobs: 3,
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
			"num_predict": 5,
		},
	}

	var response api.ChatResponse
	err := client.Chat(ctx, &req, func(resp api.ChatResponse) error {
		if resp.Done {
			response = resp
		}
		return nil
	})
	if err != nil {
		t.Fatalf("chat failed: %s", err)
	}

	if len(response.Logprobs) == 0 {
		t.Fatal("expected logprobs in response but got none")
	}

	t.Logf("received %d logprobs for chat response", len(response.Logprobs))

	for i, lp := range response.Logprobs {
		if lp.Token == "" {
			t.Errorf("logprob[%d] has empty token", i)
		}
		if lp.Logprob > 0 {
			t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob)
		}
		assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d]", i), lp.Token, lp.Bytes)
		if len(lp.TopLogprobs) == 0 {
			t.Errorf("logprob[%d] expected top_logprobs but got none", i)
		}
		if len(lp.TopLogprobs) > 3 {
			t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs))
		}
		for j, top := range lp.TopLogprobs {
			assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
		}
	}
}

@@ -63,8 +63,13 @@ func BackendInit() {
	C.llama_backend_init()
}

func EnumerateGPUs() []ml.DeviceID {
	var ids []ml.DeviceID
type Devices struct {
	ml.DeviceID
	LlamaID uint64
}

func EnumerateGPUs() []Devices {
	var ids []Devices

	for i := range C.ggml_backend_dev_count() {
		device := C.ggml_backend_dev_get(i)
@@ -74,9 +79,12 @@ func EnumerateGPUs() []ml.DeviceID {
			C.GGML_BACKEND_DEVICE_TYPE_IGPU:
			var props C.struct_ggml_backend_dev_props
			C.ggml_backend_dev_get_props(device, &props)
			ids = append(ids, ml.DeviceID{
			ids = append(ids, Devices{
				DeviceID: ml.DeviceID{
					ID:      C.GoString(props.id),
					Library: C.GoString(props.library),
				},
				LlamaID: uint64(i),
			})
		}
	}
@@ -217,7 +225,21 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
	return embeddings
}

// GetLogitsIth gets the logits for the ith token
func (c *Context) GetLogitsIth(i int) []float32 {
	logits := unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int32_t(i)))
	if logits == nil {
		return nil
	}

	vocabSize := c.Model().NumVocab()
	result := make([]float32, vocabSize)
	_ = copy(result, unsafe.Slice((*float32)(logits), vocabSize))
	return result
}

type ModelParams struct {
	Devices      []uint64
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
@@ -241,6 +263,21 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.vocab_only = C.bool(params.VocabOnly)

	var devices []C.ggml_backend_dev_t
	for _, llamaID := range params.Devices {
		devices = append(devices, C.ggml_backend_dev_get(C.size_t(llamaID)))
	}
	if len(devices) > 0 {
		devices = append(devices, C.ggml_backend_dev_t(C.NULL))
		devicesData := &devices[0]

		var devicesPin runtime.Pinner
		devicesPin.Pin(devicesData)
		defer devicesPin.Unpin()

		cparams.devices = devicesData
	}

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

@@ -0,0 +1,32 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 03:53:04 -0500
Subject: [PATCH] vulkan: Call ggml_vk_buffer_write_2d from ggml_vk_buffer_copy
 (#16793)

This lets the copy to the destination device use the host-visible
vidmem optimization.
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 221e29509..18b7cbccf 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5654,14 +5654,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
         VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
         // Copy device to device
         ggml_vk_ensure_sync_staging_buffer(src->device, size);
-        ggml_vk_ensure_sync_staging_buffer(dst->device, size);

         // Copy to src staging buffer
         ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
-        // memcpy to dst staging buffer
-        memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
         // Copy to dst buffer
-        ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
+        ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
     }
 }

@@ -0,0 +1,657 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Wed, 29 Oct 2025 08:44:29 -0500
Subject: [PATCH] vulkan: Update topk_moe fusion to handle gpt's late softmax
 (#16656)

* vulkan: Update topk_moe fusion to handle gpt's late softmax

Based on #16649.

* Add ggml_check_edges

* Add sync logging to show fusion effects

* handle clamp added in #16655

* Update ggml/src/ggml-impl.h

Co-authored-by: Diego Devesa <slarengh@gmail.com>
---
 ggml/src/ggml-impl.h                          |  16 +
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 304 +++++++++++-------
 .../ggml-vulkan/vulkan-shaders/topk_moe.comp  |  90 ++++--
 3 files changed, 272 insertions(+), 138 deletions(-)

diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 639d551a2..e5c446d1d 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release();
 #endif

 #ifdef __cplusplus
+#include <array>
 #include <initializer_list>
 #include <vector>

@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
     return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
 }

+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
+                             int start_idx,
+                             std::initializer_list<std::array<int, 3>> edges) {
+    for (const auto & edge : edges) {
+        int dst_node = edge[0];
+        int src_idx = edge[1];
+        int src_node = edge[2];
+        if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 53b57c179..b2855b078 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -387,12 +387,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
 static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
 static constexpr uint32_t num_topk_moe_pipelines = 10;

-static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
-                                           GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                           GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
-static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
-                                       GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+                                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                             GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+                                                                             GGML_OP_RESHAPE };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+                                                                         GGML_OP_VIEW, GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
+                                                                        GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                        GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
+//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
+//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
+//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
+//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
+//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
+//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
+//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
+//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
+    { 1, 0, 0 }, // reshape->src[0] == softmax
+    { 2, 0, 0 }, // argsort->src[0] == softmax
+    { 3, 0, 2 }, // view->src[0] == argsort
+    { 4, 0, 1 }, // get_rows->src[0] == reshape
+    { 4, 1, 3 }, // get_rows->src[1] == view
+    { 5, 0, 4 }, // reshape->src[0] == get_rows
+    { 6, 0, 5 }, // sum_rows->src[0] == reshape
+    { 7, 0, 6 }, // clamp->src[0] == sum_rows
+    { 8, 0, 5 }, // div->src[0] == reshape
+    { 8, 1, 7 }, // div->src[1] == clamp
+    { 9, 0, 8 }, // reshape->src[0] == div
+};
+
+// same as early_softmax_norm but ending after the get_rows
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
+    { 1, 0, 0 }, // reshape->src[0] == softmax
+    { 2, 0, 0 }, // argsort->src[0] == softmax
+    { 3, 0, 2 }, // view->src[0] == argsort
+    { 4, 0, 1 }, // get_rows->src[0] == reshape
+    { 4, 1, 3 }, // get_rows->src[1] == view
+};

+//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
+//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
+//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
+//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
+//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
+//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
+    { 1, 0, 0 }, // view->src[0] == argsort
+    { 2, 1, 1 }, // get_rows->src[1] == view
+    { 3, 0, 2 }, // reshape->src[0] == get_rows
+    { 4, 0, 3 }, // soft_max->src[0] == reshape
+    { 5, 0, 4 }, // reshape->src[0] == soft_max
+};
+
+enum topk_moe_mode {
+    TOPK_MOE_EARLY_SOFTMAX,
+    TOPK_MOE_EARLY_SOFTMAX_NORM,
+    TOPK_MOE_LATE_SOFTMAX,
+    TOPK_MOE_COUNT,
+};
+
+static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
+    topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
+                         num == topk_moe_early_softmax.size() - 1      ? TOPK_MOE_EARLY_SOFTMAX :
+                                                                         TOPK_MOE_LATE_SOFTMAX;
+    return mode;
+}

 struct vk_device_struct {
     std::recursive_mutex mutex;
@@ -607,8 +671,7 @@ struct vk_device_struct {

     vk_pipeline pipeline_flash_attn_split_k_reduce;

-    // [2] is {!norm, norm}
-    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
+    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];

     std::vector<vk_pipeline_ref> all_pipelines;

@@ -956,6 +1019,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
 struct vk_op_topk_moe_push_constants {
     uint32_t n_rows;
     uint32_t n_expert_used;
+    float clamp_min;
+    float clamp_max;
 };

 struct vk_op_add_id_push_constants {
@@ -3806,8 +3871,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);

     for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
     }

     for (auto &c : compiles) {
@@ -8085,8 +8151,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         if (ctx->num_additional_fused_ops) {
             uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
             GGML_ASSERT(idx < num_topk_moe_pipelines);
-            bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
-            return ctx->device->pipeline_topk_moe[idx][with_norm];
+            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+            return ctx->device->pipeline_topk_moe[idx][mode];
         }

         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8141,6 +8207,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     }
     case GGML_OP_ARGSORT:
+        if (ctx->num_additional_fused_ops) {
+            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+            GGML_ASSERT(idx < num_topk_moe_pipelines);
+            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+            return ctx->device->pipeline_topk_moe[idx][mode];
+        }
+
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
             uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
             return ctx->device->pipeline_argsort_f32[idx];
@@ -9676,10 +9749,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub

 static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {

-    bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+    topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
     ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
-    ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
-    ggml_tensor * ids = cgraph->nodes[node_idx + 3];
+    ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
+                            (mode == TOPK_MOE_EARLY_SOFTMAX)      ? cgraph->nodes[node_idx + 4] :
+                                                                    cgraph->nodes[node_idx + 5];
+    ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];

     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
@@ -9738,9 +9813,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
         GGML_ASSERT(d_ids != nullptr);
     }

-    vk_op_topk_moe_push_constants pc;
+    vk_op_topk_moe_push_constants pc {};
     pc.n_rows = n_rows;
     pc.n_expert_used = n_expert_used;
+    if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
+        ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+    }

     GGML_ASSERT(n_expert_used <= n_experts);

@@ -11335,7 +11415,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             }
         }
     }
+
+#define ENABLE_SYNC_LOGGING 0
+
     if (need_sync) {
+#if ENABLE_SYNC_LOGGING
+        std::cerr << "sync" << std::endl;
+#endif
         ctx->unsynced_nodes_written.clear();
         ctx->unsynced_nodes_read.clear();
         ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -11353,6 +11439,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             }
         }
     }
+#if ENABLE_SYNC_LOGGING
+    if (!dryrun) {
+        for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+            auto *n = cgraph->nodes[node_idx + i];
+            std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
+            if (n->op == GGML_OP_GLU) {
+                std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
+            }
+            std::cerr << std::endl;
+        }
+    }
+#endif

     switch (node->op) {
     case GGML_OP_REPEAT:
@@ -11531,7 +11629,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

         break;
     case GGML_OP_ARGSORT:
-        ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+        if (ctx->num_additional_fused_ops) {
+            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
+        } else {
+            ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+        }

         break;
     case GGML_OP_SUM:
@@ -12329,30 +12431,27 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
 }

 static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
-                                      int node_idx, bool with_norm) {
+                                      int node_idx, topk_moe_mode mode) {

-    if (with_norm) {
-        if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
-                return false;
-            }
-        }
-    } else {
-        if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe.size(); ++i) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
-                return false;
-            }
-        }
-    }
+    const ggml_tensor * softmax;
+    const ggml_tensor * weights;

-    const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
-    const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+    switch (mode) {
+    case TOPK_MOE_EARLY_SOFTMAX_NORM:
+        softmax = cgraph->nodes[node_idx + 0];
+        weights = cgraph->nodes[node_idx + 9];
+        break;
+    case TOPK_MOE_EARLY_SOFTMAX:
+        softmax = cgraph->nodes[node_idx + 0];
+        weights = cgraph->nodes[node_idx + 4];
+        break;
+    case TOPK_MOE_LATE_SOFTMAX:
+        softmax = cgraph->nodes[node_idx + 4];
+        weights = cgraph->nodes[node_idx + 5];
+        break;
+    default:
+        return false;
+    }

     const float * op_params = (const float *)softmax->op_params;

@@ -12378,60 +12477,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
         return false;
     }

-    // Check that the nodes don't have any unexpected uses
-    const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
-    const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
-    const ggml_tensor * view = cgraph->nodes[node_idx + 3];
-    const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
-    const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
-    const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
-    const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
-    const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
-
-    // softmax is used by reshape and argsort
-    if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
-        reshape1->src[0] != softmax ||
-        argsort->src[0] != softmax) {
-        return false;
-    }
-    // reshape is used by get_rows
-    if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
-        get_rows->src[0] != reshape1) {
-        return false;
-    }
-    // argsort is used by view
-    if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
-        view->src[0] != argsort) {
-        return false;
-    }
-    // view is written (via argsort), we can skip checking it
-
-    if (with_norm) {
-        // get_rows is used by reshape
-        if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
-            reshape5->src[0] != get_rows) {
-            return false;
-        }
-
-        // reshape is used by sum_rows and div
-        if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
-            sum_rows->src[0] != reshape5 ||
-            div->src[0] != reshape5) {
-            return false;
-        }
-
-        // sum_rows is used by div
-        if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
-            div->src[1] != sum_rows) {
-            return false;
-        }
-
-        // div/reshape are written
-        if (reshape8->src[0] != div) {
-            return false;
-        }
-    }
-
     if (!ctx->device->subgroup_arithmetic ||
         !ctx->device->subgroup_shuffle ||
         !ctx->device->subgroup_require_full_support ||
@@ -12517,10 +12562,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
-                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
-                ctx->num_additional_fused_ops = topk_moe.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
             }
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12618,10 +12671,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
-                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
-                ctx->num_additional_fused_ops = topk_moe.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
             }
         }

@@ -12754,25 +12815,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
     while (first_unused < graph->n_nodes) {
         std::vector<int> current_set;

-        // Avoid reordering topk_moe_norm
-        if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
-            bool is_topk_moe_norm = true;
-            for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
-                if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
-                    is_topk_moe_norm = false;
+        // Check for fusion patterns and avoid reordering them
+        auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
+            if (start + (int)pattern.size() <= graph->n_nodes) {
+                bool is_pattern = true;
+                for (size_t j = 0; j < pattern.size(); ++j) {
+                    if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
+                        is_pattern = false;
+                    }
                 }
+                return is_pattern;
             }
-            if (is_topk_moe_norm) {
-                for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+            return false;
+        };
+
+        auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
+            if (match_pattern(pattern, first_unused)) {
+                for (size_t j = 0; j < pattern.size(); ++j) {
                     new_order.push_back(graph->nodes[first_unused + j]);
                     used[first_unused + j] = true;
                 }
                 while (first_unused < graph->n_nodes && used[first_unused]) {
                     first_unused++;
                 }
-                continue;
+                return true;
             }
+            return false;
+        };
+
+        if (keep_pattern(topk_moe_early_softmax_norm)) {
+            continue;
+        }
+        if (keep_pattern(topk_moe_early_softmax)) {
+            continue;
         }
+        if (keep_pattern(topk_moe_late_softmax)) {
+            continue;
+        }
+
         // First, grab the next unused node.
         current_set.push_back(first_unused);

@@ -12790,6 +12870,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
             if (is_empty(graph->nodes[j])) {
                 continue;
             }
+            // Don't pull forward nodes from fusion patterns
+            if (match_pattern(topk_moe_early_softmax_norm, j) ||
+                match_pattern(topk_moe_early_softmax, j) ||
+                match_pattern(topk_moe_late_softmax, j)) {
+                continue;
+            }
             bool ok = true;
             for (int c = first_unused; c < j; ++c) {
                 if (!used[c] &&
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index 9e56d5f8a..bc1c278bf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
 {
     uint n_rows;
     uint n_expert_used;
+    float clamp_min;
+    float clamp_max;
 };

 layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
 layout(constant_id = 0) const uint WARP_SIZE = 32;
 layout(constant_id = 1) const uint n_experts = 512;
 layout(constant_id = 2) const bool with_norm = true;
+layout(constant_id = 3) const bool late_softmax = false;

 const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

@@ -25,53 +28,72 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
 layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
 layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};

-void main() {
-    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
-    if (row >= n_rows) {
-        return;
-    }
+const float INFINITY = 1.0 / 0.0;

-    const uint logits_offset = n_experts * row;
-    const uint weights_offset = n_expert_used * row;
-    const uint ids_offset = n_experts * row;
-
-    float logits_r[experts_per_thread];
-
-    const float INFINITY = 1.0 / 0.0;
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
+    float max_val = -INFINITY;

     [[unroll]]
-    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
-        const uint expert = i + gl_LocalInvocationID.x;
-        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
+    for (int i = 0; i < experts_per_thread; i++) {
+        const uint idx = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            max_val = max(max_val, vals[i]);
+        }
     }

-    float max_val = logits_r[0];
+    max_val = subgroupMax(max_val);
+
+    float sum = 0.f;

     [[unroll]]
-    for (int i = 1; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        max_val = max(val, max_val);
+    for (int i = 0; i < experts_per_thread; i++) {
+        const uint idx = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            const float val = exp(vals[i] - max_val);
+            vals[i] = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
     }

-    max_val = subgroupMax(max_val);
+    sum = subgroupAdd(sum);

-    float wt[experts_per_thread];
-    float tmp = 0.f;
+    const float inv_sum = 1.0f / sum;

     [[unroll]]
     for (int i = 0; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        wt[i] = exp(val - max_val);
-        tmp += wt[i];
+        const uint idx = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            vals[i] *= inv_sum;
+        }
     }
+}

-    tmp = subgroupAdd(tmp);
+void main() {
+    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+    if (row >= n_rows) {
+        return;
+    }

-    const float inv_sum = 1.0f / tmp;
+    const uint logits_offset = n_experts * row;
+    const uint weights_offset = n_expert_used * row;
+    const uint ids_offset = n_experts * row;
+
+    float wt[experts_per_thread];

     [[unroll]]
-    for (int i = 0; i < experts_per_thread; i++) {
-        wt[i] = wt[i] * inv_sum;
+    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+        const uint expert = i + gl_LocalInvocationID.x;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+    }
+
+    if (!late_softmax) {
+        softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
     }

     // at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@ void main() {

     float output_weights[experts_per_thread];

+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        output_weights[i] = 0.f;
+    }
+
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
         uint max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@ void main() {

     if (with_norm) {
         wt_sum = subgroupAdd(wt_sum);
+        wt_sum = clamp(wt_sum, clamp_min, clamp_max);
         const float inv_sum = 1.0f / wt_sum;

         [[unroll]]
@@ -129,6 +157,10 @@ void main() {
         }
     }

+    if (late_softmax) {
+        softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+    }
+
     [[unroll]]
     for (uint i = 0; i < experts_per_thread; ++i) {
         uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
1242
llama/patches/0032-vulkan-Fuse-rope-set_rows-16769.patch
Normal file
@@ -0,0 +1,85 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Thu, 30 Oct 2025 01:27:41 -0500
Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp             |  4 ++++
 ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp | 16 ++++++++++++----
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index aaf4334b5..3604ceb04 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {

 struct vk_op_argsort_push_constants {
     uint32_t ncols;
+    uint32_t nrows;
     int32_t order;
 };

@@ -8710,6 +8711,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         break;
     case GGML_OP_ARGSORT:
         elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+        elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
         break;
     case GGML_OP_IM2COL:
         {
@@ -9952,9 +9954,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
     int32_t * op_params = (int32_t *)dst->op_params;

     uint32_t ncols = src0->ne[0];
+    uint32_t nrows = ggml_nrows(src0);

     ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
         ncols,
+        nrows,
         op_params[0],
     }, dryrun);
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
index c81b84452..c4e68bc02 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};

 layout (push_constant) uniform parameter {
     uint ncols;
+    uint nrows;
     uint order;
 } p;

@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
     dst_row[idx1] = tmp;
 }

-void argsort(bool needs_bounds_check) {
+void argsort(bool needs_bounds_check, const uint row) {
     // bitonic sort
     const int col = int(gl_LocalInvocationID.x);
-    const uint row = gl_WorkGroupID.y;

     const uint row_offset = row * p.ncols;

@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {

 void main() {
     if (p.ncols == BLOCK_SIZE) {
-        argsort(false);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(false, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     } else {
-        argsort(true);
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            argsort(true, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
     }
 }
@@ -0,0 +1,77 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam <picard12@live.de>
Date: Fri, 31 Oct 2025 08:14:49 +0100
Subject: [PATCH] vulkan: fix shmem overrun in mmq id shader (#16873)

* vulkan: fix shmem overrun in mmq id shader

* metal : fix mul_mm_id

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---
 ggml/src/ggml-metal/ggml-metal-device.cpp                    | 2 +-
 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp             | 4 ++++
 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl | 2 +-
 tests/test-backend-ops.cpp                                   | 3 +++
 4 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 758116342..c78082ac3 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
     char name[256];

     snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_ne02=%d", base, ne02);

     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 8b238ac4b..d955b4fc7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;

 #include "mul_mmq_shmem_types.glsl"

+#ifdef MUL_MAT_ID
+#define BK_STEP 1
+#else
 #ifndef BK_STEP
 #define BK_STEP 4
 #endif
+#endif

 // Shared memory cache
 shared block_a_cache buf_a[BM * BK_STEP];
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
index 72fec4404..1c0f5306f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -27,7 +27,7 @@ struct block_a_cache {
 #elif defined(DATA_A_Q8_0)
 #define QUANT_R_MMQ 1
 // AMD likes 4, Intel likes 1 and Nvidia likes 2
-#define BK_STEP 1
+// #define BK_STEP 1
 struct block_a_cache {
     int32_t qs[32/4];
     FLOAT_TYPE dm;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 657b6cc2f..1f8dda383 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6722,6 +6722,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
     test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));

+    // gpt-oss issue with Vulkan mmq_id
+    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
+
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
             for (int n_mats : {4, 8}) {
@@ -0,0 +1,80 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Masato Nakasaka <masato.nakasaka@intel.com>
Date: Fri, 31 Oct 2025 16:18:59 +0900
Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not
 supported (#16796)

* Experimenting crash fix

* added assert for aborting and fixed comment

* changed to check if a pipeline is empty or not

* Moved function in class definition

* replaced with is_empty

* Modified is_empty to check only unaligned pipelines
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3604ceb04..80185d9f0 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
 struct vk_matmul_pipeline_struct {
     vk_pipeline l, m, s;
     vk_pipeline a_l, a_m, a_s;
+    // Returns true when all unaligned pipelines are null.
+    // We only check for unaligned variants since one of the unaligned pipelines must exist
+    // while aligned pipelines are optional
+    bool is_empty() const {
+        return l == nullptr && m == nullptr && s == nullptr;
+    }
 };
-
 typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;

 struct vk_matmul_pipeline2 {
@@ -5080,7 +5085,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     if (src1_type == GGML_TYPE_Q8_1) {
         vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;

-        if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+        if (pipelines->is_empty()) {
             return nullptr;
         }

@@ -5229,7 +5234,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
     if (src1_type == GGML_TYPE_Q8_1) {
         vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;

-        if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+        if (pipelines->is_empty()) {
             return nullptr;
         }

@@ -5264,16 +5269,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         return nullptr;
     }

+    vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
     // XXX TODO 'prec' is not actually allowed in mul_mat_id.
     bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
-    bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
-    bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
+    bool support_fp16acc = !mmp.f16acc->is_empty();
+    bool support_fp32acc = !mmp.f32acc->is_empty();

     if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
-        return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
+        return mmp.f16acc;
     } else {
         GGML_ASSERT(support_fp32acc);
-        return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
+        return mmp.f32acc;
     }
 }

516
llm/memory.go
@@ -1,516 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits
|
||||
// The list of GPUs returned will always be the same brand (library)
|
||||
// If the model can not be fit fully within the available GPU(s) nil is returned
|
||||
func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, gpus []ml.DeviceInfo, numParallel int) []ml.DeviceInfo {
|
||||
for _, gl := range ml.ByLibrary(gpus) {
|
||||
sgl := append(make([]ml.DeviceInfo, 0, len(gl)), gl...)
|
||||
|
||||
// TODO - potentially sort by performance capability, existing models loaded, etc.
|
||||
// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
|
||||
// Note: at present, this will favor most current available VRAM descending and ignoring faster GPU speed in mixed setups
|
||||
sort.Sort(sort.Reverse(ml.ByFreeMemory(sgl)))
|
||||
|
||||
if !envconfig.SchedSpread() {
|
||||
// Try to pack into as few GPUs as possible, starting from 1 GPU
|
||||
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
|
||||
gpuSubset := sgl[:numGPUs]
|
||||
ok, estimatedVRAM := predictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
|
||||
|
||||
if ok {
|
||||
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
|
||||
"model", modelPath,
|
||||
"library", sgl[0].Library,
|
||||
"parallel", numParallel,
|
||||
"required", format.HumanBytes2(estimatedVRAM),
|
||||
"gpus", numGPUs)
|
||||
return gpuSubset
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO future refinements
|
||||
// - if multiple Libraries, see if any single GPU in any Library will fit
|
||||
// - try subsets of GPUs instead of just falling back to 1 or all in a family
|
||||
|
||||
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
|
||||
if ok, estimatedVRAM := predictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
|
||||
slog.Info("new model will fit in available VRAM, loading",
|
||||
"model", modelPath,
|
||||
"library", sgl[0].Library,
|
||||
"parallel", numParallel,
|
||||
"required", format.HumanBytes2(estimatedVRAM),
|
||||
"gpus", len(sgl))
|
||||
return sgl
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If multiple Libraries are detected, pick the Library which loads the most layers for the model
|
||||
func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []string, opts api.Options, gpus []ml.DeviceInfo, numParallel int) []ml.DeviceInfo {
|
||||
byLibrary := ml.ByLibrary(gpus)
|
||||
if len(byLibrary) <= 1 {
|
||||
return gpus
|
||||
}
|
||||
var bestEstimate uint64
|
||||
var bestFit int
|
||||
for i, gl := range byLibrary {
|
||||
_, estimatedVRAM := predictServerFit(gl, f, adapters, projectors, opts, numParallel)
|
||||
if estimatedVRAM > bestEstimate {
|
||||
bestEstimate = estimatedVRAM
|
||||
bestFit = i
|
||||
}
|
||||
}
|
||||
return byLibrary[bestFit]
|
||||
}
|
||||
|
||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||
func predictServerFit(allGpus []ml.DeviceInfo, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||
// Split up the GPUs by type and try them
|
||||
var estimatedVRAM uint64
|
||||
for _, gpus := range ml.ByLibrary(allGpus) {
|
||||
var layerCount int
|
||||
estimate := estimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
|
||||
if opts.NumGPU < 0 {
|
||||
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
|
||||
return true, estimatedVRAM
|
||||
}
|
||||
} else {
|
||||
if layerCount > 0 && layerCount >= opts.NumGPU {
|
||||
return true, estimatedVRAM
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, estimatedVRAM
|
||||
}
|
||||
|
||||
func verifyCPUFit(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, systemInfo ml.SystemInfo, numParallel int) bool {
|
||||
estimate := estimateGPULayers(nil, f, projectors, opts, numParallel)
|
||||
if estimate.TotalSize > systemInfo.FreeMemory {
|
||||
return false
|
||||
}
|
||||
slog.Info("new model will fit in available system memory for CPU inference, loading",
|
||||
"model", modelPath,
|
||||
"parallel", numParallel,
|
||||
"required", format.HumanBytes2(estimate.TotalSize),
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
type MemoryEstimate struct {
|
||||
// How many layers we predict we can load
|
||||
Layers int
|
||||
|
||||
// The size of the graph which occupies the main GPU
|
||||
Graph uint64
|
||||
|
||||
// How much VRAM will be allocated given the number of layers we predict
|
||||
VRAMSize uint64
|
||||
|
||||
// The total size of the model if loaded into VRAM. If all layers are loaded, VRAMSize == TotalSize
|
||||
TotalSize uint64
|
||||
|
||||
// For multi-GPU scenarios, this provides the tensor split parameter
|
||||
TensorSplit []int
|
||||
|
||||
// For multi-GPU scenarios, this is the size in bytes per GPU
|
||||
GPUSizes []uint64
|
||||
|
||||
// internal fields for logging purposes
|
||||
inferenceLibrary string
|
||||
layersRequested int
|
||||
layersModel int
|
||||
availableList []string
|
||||
kv uint64
|
||||
allocationsList []string
|
||||
memoryWeights uint64
|
||||
memoryLayerOutput uint64
|
||||
graphFullOffload uint64
|
||||
graphPartialOffload uint64
|
||||
|
||||
projectorWeights, projectorGraph uint64
|
||||
}
|
||||
|
||||
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
// The GPUs provided must all be the same Library
func estimateGPULayers(gpus []ml.DeviceInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
    // Graph size for a partial offload, applies to all GPUs
    var graphPartialOffload uint64

    // Graph size when all layers are offloaded, applies to all GPUs
    var graphFullOffload uint64

    // Final graph offload once we know full or partial
    var graphOffload uint64

    // Projectors loaded into GPU0 only
    var llamaEngineProjectorWeights uint64

    // Projectors loaded with output layer
    var ollamaEngineProjectorWeights uint64
    var ollamaEngineProjectorGraph uint64

    // Conditional output size on GPU 0
    var memoryLayerOutput uint64

    // The sizes of a layer
    var layerSize uint64

    // The sum of all the layer sizes (just for logging)
    var memoryWeights uint64

    // True if all the layers are loaded
    var fullyLoaded bool

    // Overflow that didn't fit into the GPU
    var overflow uint64

    overhead := envconfig.GpuOverhead()
    availableList := make([]string, len(gpus))
    libraries := []string{}
    for i, gpu := range gpus {
        availableList[i] = format.HumanBytes2(gpu.FreeMemory)
        if !slices.Contains(libraries, gpu.Library) {
            libraries = append(libraries, gpu.Library)
        }
    }
    if len(libraries) == 0 {
        libraries = []string{"cpu"}
    }
    slog.Debug("evaluating", "library", strings.Join(libraries, ","), "gpu_count", len(gpus), "available", availableList)

    for _, projector := range projectors {
        llamaEngineProjectorWeights += projectorMemoryRequirements(projector)
    }
    if llamaEngineProjectorWeights == 0 {
        ollamaEngineProjectorWeights, ollamaEngineProjectorGraph = f.VisionGraphSize()
    }

    layers := f.Tensors().GroupLayers()
    // add one layer worth of memory as a buffer
    if blk0, ok := layers["blk.0"]; ok {
        layerSize = blk0.Size()
    } else {
        slog.Warn("model missing blk.0 layer size")
    }

    useFlashAttention := envconfig.FlashAttention(f.FlashAttention()) &&
        ml.FlashAttentionSupported(gpus) &&
        f.SupportsFlashAttention()

    var kvct string
    if useFlashAttention {
        requested := strings.ToLower(envconfig.KvCacheType())
        if f.SupportsKVCacheType(requested) {
            kvct = requested
        }
    }

    kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)

    if len(kv) > 0 {
        layerSize += kv[0]
    }

    var kvTotal uint64
    for _, kvLayer := range kv {
        kvTotal += kvLayer
    }

    if graphPartialOffload == 0 {
        headsKV := f.KV().HeadCountKVMin()
        if headsKV == 0 {
            headsKV = 1
        }
        gqa := f.KV().HeadCountMax() / headsKV
        graphPartialOffload = gqa * kvTotal / 6
    }
    if graphFullOffload == 0 {
        graphFullOffload = graphPartialOffload
    }

    // on metal there's no partial offload overhead
    if len(gpus) > 0 && gpus[0].Library == "Metal" {
        graphPartialOffload = graphFullOffload
    } else if len(gpus) > 1 {
        // multigpu should always use the partial graph size
        graphFullOffload = graphPartialOffload
    }

    // Output layer handled at the end if we have space
    if layer, ok := layers["output_norm"]; ok {
        memoryLayerOutput += layer.Size()
    }
    if layer, ok := layers["output"]; ok {
        memoryLayerOutput += layer.Size()
    } else if layer, ok := layers["token_embd"]; ok {
        memoryLayerOutput += layer.Size()
    }

    gpuZeroOverhead := llamaEngineProjectorWeights

    // Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
    var layerCount int
    tensorSplit := make([]int, len(gpus))
    gpuAllocations := make([]uint64, len(gpus))
    type gs struct {
        i int
        g *ml.DeviceInfo
    }
    gpusWithSpace := []gs{}
    for i := range gpus {
        var gzo uint64
        if len(gpusWithSpace) == 0 {
            gzo = gpuZeroOverhead
        }
        // Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least one more layer
        if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory()+2*layerSize {
            slog.Debug("gpu has too little memory to allocate any layers",
                "id", gpus[i].ID,
                "library", gpus[i].Library,
                "compute", gpus[i].Compute(),
                "driver", fmt.Sprintf("%d.%d", gpus[i].DriverMajor, gpus[i].DriverMinor),
                "name", gpus[i].Name,
                "total", format.HumanBytes2(gpus[i].TotalMemory),
                "available", format.HumanBytes2(gpus[i].FreeMemory),
                "minimum_memory", gpus[i].MinimumMemory(),
                "layer_size", format.HumanBytes2(layerSize),
                "gpu_zero_overhead", format.HumanBytes2(gzo),
                "partial_offload", format.HumanBytes2(graphPartialOffload),
                "full_offload", format.HumanBytes2(graphFullOffload),
            )
            continue
        }
        gpusWithSpace = append(gpusWithSpace, gs{i, &gpus[i]})
        gpuAllocations[i] += gpus[i].MinimumMemory() + layerSize // We hold off on graph until we know partial vs. full
    }

    var gpuZeroID int
    if len(gpusWithSpace) > 0 {
        gpuZeroID = gpusWithSpace[0].i
        gpuAllocations[gpuZeroID] += gpuZeroOverhead
    } else {
        overflow += gpuZeroOverhead
    }

    // For all the layers, find where they can fit on the GPU(s)
    for i := int(f.KV().BlockCount()) - 1; i >= 0; i-- {
        // Some models have inconsistent layer sizes
        if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
            layerSize = blk.Size()
            layerSize += kv[i]
            memoryWeights += blk.Size()
        }

        if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
            // Stop allocating on GPU(s) once we hit the user's target NumGPU
            overflow += layerSize
            continue
        }

        // distribute the layers across the GPU(s) that have space
        for j := len(gpusWithSpace); j > 0; j-- {
            g := gpusWithSpace[i%j]
            used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
            if g.g.FreeMemory > overhead+used+layerSize {
                gpuAllocations[g.i] += layerSize
                tensorSplit[g.i]++
                layerCount++
                break
            } else {
                gpusWithSpace = append(gpusWithSpace[:i%j], gpusWithSpace[i%j+1:]...)
            }
        }

        if len(gpusWithSpace) == 0 {
            overflow += layerSize
        }
    }
    if layerCount >= int(f.KV().BlockCount()) {
        fullyLoaded = true
    }

    // Determine if we need to consider output then find where it fits
    memoryLastLayer := memoryLayerOutput + ollamaEngineProjectorWeights + ollamaEngineProjectorGraph
    if memoryLastLayer > 0 {
        if opts.NumGPU < 0 || layerCount < opts.NumGPU {
            for j := len(gpusWithSpace); j > 0; j-- {
                g := gpusWithSpace[layerCount%j]
                used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
                if g.g.FreeMemory > overhead+used+memoryLastLayer {
                    gpuAllocations[g.i] += memoryLastLayer
                    tensorSplit[g.i]++
                    layerCount++
                    break
                }
            }
        }

        if layerCount < int(f.KV().BlockCount())+1 {
            fullyLoaded = false
            overflow += memoryLastLayer
        }
    }

    // Add the applicable (full or partial) graph allocations
    for i := range gpus {
        if tensorSplit[i] <= 0 {
            continue
        }
        if fullyLoaded {
            gpuAllocations[i] += graphFullOffload
        } else {
            gpuAllocations[i] += graphPartialOffload
        }
    }
    if fullyLoaded {
        graphOffload = graphFullOffload
    } else {
        graphOffload = graphPartialOffload
    }

    // Summaries for the log
    var memoryRequiredPartial, memoryRequiredTotal uint64
    for i := range gpuAllocations {
        memoryRequiredPartial += gpuAllocations[i]
    }
    memoryRequiredTotal = memoryRequiredPartial + overflow

    allocationsList := []string{}
    for _, a := range gpuAllocations {
        allocationsList = append(allocationsList, format.HumanBytes2(a))
    }

    estimate := MemoryEstimate{
        TotalSize: memoryRequiredTotal,
        Layers:    0,
        Graph:     0,
        VRAMSize:  0,
        GPUSizes:  []uint64{},

        inferenceLibrary:    strings.Join(libraries, ","),
        layersRequested:     opts.NumGPU,
        layersModel:         int(f.KV().BlockCount()) + 1,
        availableList:       availableList,
        kv:                  kvTotal,
        allocationsList:     allocationsList,
        memoryWeights:       memoryWeights,
        memoryLayerOutput:   memoryLayerOutput,
        graphFullOffload:    graphFullOffload,
        graphPartialOffload: graphPartialOffload,
        projectorWeights:    llamaEngineProjectorWeights + ollamaEngineProjectorWeights,
        projectorGraph:      ollamaEngineProjectorGraph,
    }

    if len(gpus) == 0 {
        return estimate
    }
    if layerCount == 0 {
        slog.Debug("insufficient VRAM to load any model layers")
        return estimate
    }
    estimate.Layers = layerCount
    estimate.Graph = graphOffload
    estimate.VRAMSize = memoryRequiredPartial
    estimate.TotalSize = memoryRequiredTotal
    estimate.TensorSplit = tensorSplit
    estimate.GPUSizes = gpuAllocations

    return estimate
}

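The allocation loop above walks the model from its last block down, offering each layer to one of the GPUs that still has room (indexed by `i%j`) and evicting a GPU from the candidate list once it can no longer hold a layer. The following standalone sketch isolates that round-robin-with-eviction idea; the `distribute` helper and all sizes are invented for illustration and are not part of the codebase:

```go
package main

import "fmt"

// distribute is a minimal sketch of the round-robin assignment: walk layers
// from the top of the model down, offer each layer to one GPU (indexed by
// layer%remaining), and evict any GPU that can no longer fit a layer.
func distribute(layerSizes []uint64, free []uint64) (split []int, overflow uint64) {
    split = make([]int, len(free))
    active := make([]int, len(free)) // indices of GPUs still accepting layers
    for i := range free {
        active[i] = i
    }
    for i := len(layerSizes) - 1; i >= 0; i-- {
        placed := false
        for j := len(active); j > 0; j-- {
            g := active[i%j]
            if free[g] >= layerSizes[i] {
                free[g] -= layerSizes[i]
                split[g]++
                placed = true
                break
            }
            // This GPU is full: drop it and retry with the rest
            active = append(active[:i%j], active[i%j+1:]...)
        }
        if !placed {
            overflow += layerSizes[i]
        }
    }
    return split, overflow
}

func main() {
    layers := []uint64{100, 100, 100, 100} // four equally sized layers
    free := []uint64{250, 150}             // two GPUs with different headroom
    split, overflow := distribute(layers, free)
    fmt.Println("split:", split, "overflow:", overflow)
}
```

Running it prints `split: [2 1] overflow: 100`, mirroring how a layer that fits nowhere is counted against TotalSize but not VRAMSize.
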
func (m MemoryEstimate) LogValue() slog.Value {
    attrs := []slog.Attr{
        slog.String("library", m.inferenceLibrary),
        slog.Group(
            "layers",
            // requested number of layers to offload
            "requested", m.layersRequested,
            // The number of layers the model has (including output)
            "model", m.layersModel,
            // estimated number of layers that can be offloaded
            "offload", m.Layers,
            // multi-gpu split for tensors
            "split", m.TensorSplit,
        ),
        slog.Group(
            "memory",
            // memory available by GPU for offloading
            "available", m.availableList,
            "gpu_overhead", format.HumanBytes2(envconfig.GpuOverhead()),
            slog.Group(
                "required",
                // memory required for full offloading
                "full", format.HumanBytes2(m.TotalSize),
                // memory required to offload the estimated number of layers
                "partial", format.HumanBytes2(m.VRAMSize),
                // memory of KV cache
                "kv", format.HumanBytes2(m.kv),
                // Allocations across the GPUs
                "allocations", m.allocationsList,
            ),
            slog.Group(
                "weights",
                // memory of the weights
                "total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
                // memory of repeating layers
                "repeating", format.HumanBytes2(m.memoryWeights),
                // memory of non-repeating layers
                "nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
            ),
            slog.Group(
                "graph",
                // memory of graph when fully offloaded
                "full", format.HumanBytes2(m.graphFullOffload),
                // memory of graph when not fully offloaded
                "partial", format.HumanBytes2(m.graphPartialOffload),
            ),
        ),
    }

    if m.projectorWeights > 0 {
        attrs = append(attrs, slog.Group(
            "projector",
            "weights", format.HumanBytes2(m.projectorWeights),
            "graph", format.HumanBytes2(m.projectorGraph),
        ))
    }

    return slog.GroupValue(attrs...)
}

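MemoryEstimate implements slog.LogValuer, so a single `slog.Info("offload", "", s.estimate)` call expands into the grouped attributes above. A minimal sketch of the same pattern; `estimateStub` is a hypothetical stand-in, not a type in the codebase:

```go
package main

import (
    "log/slog"
    "os"
)

// estimateStub stands in for MemoryEstimate to show how a slog.LogValuer
// flattens into grouped keys; the fields are illustrative.
type estimateStub struct {
    library string
    offload int
}

func (e estimateStub) LogValue() slog.Value {
    return slog.GroupValue(
        slog.String("library", e.library),
        slog.Group("layers", "offload", e.offload),
    )
}

func main() {
    logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
    // Mirrors the `slog.Info("offload", "", s.estimate)` call site: an empty
    // key inlines the grouped attributes into the record.
    logger.Info("offload", "", estimateStub{library: "CUDA", offload: 28})
}
```

The empty key matters: slog inlines a group whose key is empty, which is why the log line reads `library=... layers.offload=...` rather than nesting under a named key.
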
func projectorMemoryRequirements(filename string) (weights uint64) {
    file, err := os.Open(filename)
    if err != nil {
        return 0
    }
    defer file.Close()

    ggml, err := ggml.Decode(file, 1024)
    if err != nil {
        return 0
    }

    for _, layer := range ggml.Tensors().GroupLayers() {
        weights += layer.Size()
    }

    return weights
}

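Both this function and the estimator above lean on `Tensors().GroupLayers()`, which buckets tensors by block. A guess at the grouping semantics in a self-contained sketch (the `groupLayers` helper and the sizes are invented; the real implementation lives in the fs/ggml package):

```go
package main

import (
    "fmt"
    "strings"
)

// groupLayers mimics what Tensors().GroupLayers() provides to the estimator:
// tensors named "blk.N.xxx" are grouped per block, everything else (output,
// token_embd, ...) keyed by its first path element.
func groupLayers(sizes map[string]uint64) map[string]uint64 {
    groups := make(map[string]uint64)
    for name, size := range sizes {
        parts := strings.Split(name, ".")
        key := parts[0]
        if key == "blk" && len(parts) > 1 {
            key = "blk." + parts[1]
        }
        groups[key] += size
    }
    return groups
}

func main() {
    tensors := map[string]uint64{
        "blk.0.attn.weight": 32,
        "blk.0.ffn.weight":  64,
        "blk.1.attn.weight": 32,
        "output.weight":     16,
    }
    fmt.Println(groupLayers(tensors)) // map[blk.0:96 blk.1:32 output:16]
}
```
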
@@ -1,130 +0,0 @@
package llm

import (
    "bytes"
    "fmt"
    "os"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/format"
    "github.com/ollama/ollama/fs/ggml"
    "github.com/ollama/ollama/ml"
)

func TestEstimateGPULayers(t *testing.T) {
    t.Setenv("OLLAMA_DEBUG", "1")
    t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
    t.Setenv("OLLAMA_CONTEXT_LENGTH", "2048")

    modelName := "dummy"
    f, err := os.CreateTemp(t.TempDir(), modelName)
    require.NoError(t, err)
    defer f.Close()
    inputLayerCount := 5

    tensors := []*ggml.Tensor{
        {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
        {Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
        {Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
        {Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
        {Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
        {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
    }
    assert.Len(t, tensors, inputLayerCount+1)
    err = ggml.WriteGGUF(f, ggml.KV{
        "general.architecture":          "llama",
        "llama.context_length":          uint32(32),
        "llama.embedding_length":        uint32(4096),
        "llama.block_count":             uint32(inputLayerCount),
        "llama.attention.head_count":    uint32(32),
        "llama.attention.head_count_kv": uint32(32),
        "tokenizer.ggml.tokens":         []string{" "},
        "tokenizer.ggml.scores":         []float32{0},
        "tokenizer.ggml.token_type":     []int32{0},
    }, tensors)
    require.NoError(t, err)

    ggml, err := LoadModel(f.Name(), 0)
    if err != nil {
        t.Fatal(err)
    }

    // Simple CPU scenario
    gpus := []ml.DeviceInfo{}
    projectors := []string{}
    opts := api.DefaultOptions()
    t.Run("cpu", func(t *testing.T) {
        estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1)
        assert.Equal(t, 0, estimate.Layers)
        assert.Equal(t, uint64(0), estimate.Graph)
    })

    // derived from the dummy ggml file above
    graphPartialOffload := uint64(202377216)
    graphFullOffload := uint64(171968512)
    layerSize := uint64(33554436)
    projectorSize := uint64(0)
    memoryLayerOutput := uint64(4)

    // Dual CUDA scenario with asymmetry
    gpuMinimumMemory := uint64(457 * format.MebiByte)
    gpus = []ml.DeviceInfo{
        {
            DeviceID: ml.DeviceID{
                Library: "CUDA",
            },
        },
        {
            DeviceID: ml.DeviceID{
                Library: "CUDA",
            },
        },
    }
    // Scenario table: GPU0 layer space, GPU1 layer space, expected gpu0, expected gpu1
    for i, s := range []struct {
        layer0, layer1   uint64
        expect0, expect1 int
    }{
        {1, 1, 1, 1},
        {2, 1, 2, 1},
        {2, 2, 2, 2},
        {1, 2, 1, 2},
        {3, 3, 3, 3},
        {4, 4, 3, 3},
        {6, 6, 3, 3},
        {0, 3, 0, 3},
    } {
        t.Run(fmt.Sprintf("%v", s), func(t *testing.T) {
            gpus[0].FreeMemory = 0
            gpus[1].FreeMemory = 0
            gpus[0].FreeMemory += projectorSize
            if s.layer0 > 0 {
                gpus[0].FreeMemory += memoryLayerOutput
            } else {
                gpus[1].FreeMemory += memoryLayerOutput
            }
            gpus[0].FreeMemory += gpuMinimumMemory + layerSize + s.layer0*layerSize + 1
            gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
            gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
            gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
            estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1)
            assert.Equal(t, s.expect0+s.expect1, estimate.Layers, "scenario %d: %v", i, s)
            assert.Equal(t, []int{s.expect0, s.expect1}, estimate.TensorSplit, "scenario %d: %v", i, s)
            var layerSums uint64
            for _, b := range estimate.GPUSizes {
                layerSums += b
            }
            if estimate.Layers < inputLayerCount+1 {
                assert.Less(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
                assert.Equal(t, estimate.VRAMSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
            } else {
                assert.Equal(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
                assert.Equal(t, estimate.TotalSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
            }
        })
    }
}
439
llm/server.go
@@ -89,20 +89,16 @@ type llmServer struct {
    done    chan error // Channel to signal when the process exits
    status  *StatusWriter
    options api.Options
    numParallel int
    modelPath   string

    loadRequest LoadRequest // Parameters used to initialize the runner
    mem         *ml.BackendMemory // Memory allocations for this model

    // llamaModel is an instance of the cgo llama.cpp model definition
    // nil if this server is running the new engine
    llamaModel     *llama.Model
    llamaModelLock *sync.Mutex

    // textProcessor handles text encoding/decoding for the model in the Ollama engine
    // nil if this server is running the llama.cpp based engine
    textProcessor model.TextProcessor

    totalLayers  uint64
    loadStart    time.Time // Record how long it took the model to load
    loadProgress float32
@@ -114,14 +110,12 @@ type llamaServer struct {
    llmServer

    ggml     *ggml.GGML
    gpus     []ml.DeviceInfo // The set of GPUs covered by the memory estimate
    estimate MemoryEstimate
}

type ollamaServer struct {
    llmServer

    mem           *ml.BackendMemory
    textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
}

// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -245,8 +239,6 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
    loadRequest:    loadRequest,
    llamaModel:     llamaModel,
    llamaModelLock: &sync.Mutex{},
    textProcessor:  textProcessor,
    numParallel:    numParallel,
    sem:            semaphore.NewWeighted(int64(numParallel)),
    totalLayers:    f.KV().BlockCount() + 1,
    loadStart:      time.Now(),
@@ -281,7 +273,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
    }()

    if textProcessor != nil {
        return &ollamaServer{llmServer: s}, nil
        return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
    } else {
        return &llamaServer{llmServer: s, ggml: f}, nil
    }
@@ -463,76 +455,173 @@ type LoadResponse struct {

var ErrLoadRequiredFull = errors.New("unable to load full model on GPU")

func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
    systemTotalMemory := systemInfo.TotalMemory
    systemFreeMemory := systemInfo.FreeMemory
    systemSwapFreeMemory := systemInfo.FreeSwap
    slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
    slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)

    if len(gpus) == 0 || s.options.NumGPU == 0 {
        if !verifyCPUFit(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, systemInfo, s.numParallel) {
            slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
            return nil, fmt.Errorf("model requires more system memory than is currently available %w", ErrLoadRequiredFull)
    gpus := append(make([]ml.DeviceInfo, 0, len(systemGPUs)), systemGPUs...)

    // Synthesize memory allocation information based on our estimates
    s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{
        Name:    "CPU",
        Weights: make([]uint64, s.totalLayers),
        Cache:   make([]uint64, s.totalLayers),
    }, GPUs: make([]ml.DeviceMemory, len(gpus))}

    for i := range s.mem.GPUs {
        s.mem.GPUs[i].Name = gpus[i].Name
        s.mem.GPUs[i].DeviceID = gpus[i].DeviceID
        s.mem.GPUs[i].Weights = make([]uint64, s.totalLayers)
        s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
    }

    kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
        s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)

    // Use the size of one layer as a buffer
    layers := s.ggml.Tensors().GroupLayers()
    if blk0, ok := layers["blk.0"]; ok {
        for i := range gpus {
            gpus[i].FreeMemory -= blk0.Size() + kv[0]
        }
    } else {
        g := pickBestFullFitByLibrary(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
        if g == nil {
            if !requireFull {
                g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
            } else {
                slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
                return nil, ErrLoadRequiredFull
            }
        }
        gpus = g
        slog.Warn("model missing blk.0 layer size")
    }

    s.estimate = estimateGPULayers(gpus, s.ggml, []string{s.loadRequest.ProjectorPath}, s.options, s.numParallel)
    // Assign all the layers to the CPU for now, they will get reassigned later
    for i := range s.ggml.KV().BlockCount() {
        if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
            s.mem.CPU.Weights[i] = blk.Size()
            s.mem.CPU.Cache[i] += kv[i]
        }
    }

    if len(gpus) >= 1 {
        switch {
        case s.options.NumGPU == 0:
            gpus = []ml.DeviceInfo{}
        case gpus[0].Library == "Metal" && s.estimate.VRAMSize > systemInfo.TotalMemory:
            // disable partial offloading when model is greater than total system memory as this
            // can lead to locking up the system
            s.options.NumGPU = 0
            gpus = []ml.DeviceInfo{}
        case gpus[0].Library != "Metal" && s.estimate.Layers == 0:
            // Don't bother loading into the GPU if no layers can fit
            gpus = []ml.DeviceInfo{}
        case s.options.NumGPU < 0 && s.estimate.Layers > 0:
            s.options.NumGPU = s.estimate.Layers
    // We historically haven't included InputWeights in the model size
    var outputWeights uint64
    if layer, ok := layers["output_norm"]; ok {
        outputWeights += layer.Size()
    }
    if layer, ok := layers["output"]; ok {
        outputWeights += layer.Size()
    } else if layer, ok := layers["token_embd"]; ok {
        outputWeights += layer.Size()
    }
    s.mem.CPU.Weights[s.totalLayers-1] = outputWeights

    // The vision projector is always loaded on the first GPU if available.
    // This can't be assigned by us, so just subtract it from free space
    projectorGPU := -1
    var projectorWeights uint64
    if len(gpus) > 0 {
        for _, projector := range s.loadRequest.LoraPath {
            projectorWeights += projectorMemoryRequirements(projector)
        }

        // llama.cpp uses the first discrete GPU if available, otherwise the first iGPU
        firstIntegrated := -1
        for i := range gpus {
            if !gpus[i].Integrated {
                projectorGPU = i
                break
            }
            if firstIntegrated == -1 {
                firstIntegrated = i
            }
        }
        if projectorGPU == -1 {
            projectorGPU = firstIntegrated
        }

        gpus[projectorGPU].FreeMemory -= projectorWeights
    }

    var kvTotal uint64
    for _, kvLayer := range kv {
        kvTotal += kvLayer
    }

    if graphPartialOffload == 0 {
        headsKV := s.ggml.KV().HeadCountKVMin()
        if headsKV == 0 {
            headsKV = 1
        }
        gqa := s.ggml.KV().HeadCountMax() / headsKV
        graphPartialOffload = gqa * kvTotal / 6
    }
    if graphFullOffload == 0 {
        graphFullOffload = graphPartialOffload
    }

    // On Metal there's no partial offload overhead
    if len(gpus) > 0 && gpus[0].Library == "Metal" {
        graphPartialOffload = graphFullOffload
    }

    // Create a layout based on the memory data that we've built. The compute graph
    // for GPUs is iteratively assigned based on the number of GPUs that are required.
    var gpuLayers ml.GPULayersList
    for {
        prevGPULayers := gpuLayers

        var err error
        gpuLayers, err = s.createLayout(systemInfo, gpus, s.mem, requireFull, 0)
        if err != nil {
            return nil, err
        }

        if len(gpuLayers) > len(prevGPULayers) {
            for _, gl := range gpuLayers {
                for i := range s.mem.GPUs {
                    if gl.DeviceID == s.mem.GPUs[i].DeviceID {
                        s.mem.GPUs[i].Graph = max(graphPartialOffload, graphFullOffload)
                        break
                    }
                }
            }
        } else {
            s.options.NumGPU = 0
        }

    // On linux and windows, over-allocating CPU memory will almost always result in an error
    // Darwin has fully dynamic swap so has no direct concept of free swap space
    if runtime.GOOS != "darwin" {
        systemMemoryRequired := s.estimate.TotalSize - s.estimate.VRAMSize
        available := systemInfo.FreeMemory + systemInfo.FreeSwap
        if systemMemoryRequired > available {
            slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap))
            return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
            break
        }
    }

    slog.Info("offload", "", s.estimate)
    // This maintains the historical assignment of graph sizes, though it isn't fully accurate
    graphSize := graphFullOffload
    if gpuLayers.Sum() < int(s.totalLayers) {
        graphSize = graphPartialOffload
    }

    s.gpus = gpus
    s.loadRequest.GPULayers = createGPULayers(s.estimate, s.ggml, gpus, s.options.NumGPU)
    // For all layers that we have assigned to GPUs, move them in the memory data so
    // that it is reported accurately
    for _, gl := range gpuLayers {
        for i := range s.mem.GPUs {
            if gl.DeviceID == s.mem.GPUs[i].DeviceID {
                for _, l := range gl.Layers {
                    s.mem.GPUs[i].Weights[l] = s.mem.CPU.Weights[l]
                    s.mem.GPUs[i].Cache[l] = s.mem.CPU.Cache[l]

                    // Mmap is only supported on the llama engine
                    if s.textProcessor == nil {
                        s.mem.CPU.Weights[l] = 0
                        s.mem.CPU.Cache[l] = 0
                    }

                s.mem.GPUs[i].Graph = graphSize
                break
            }
        }
    }

    if projectorGPU > 0 && len(s.mem.GPUs[projectorGPU].Weights) > 0 {
        s.mem.GPUs[projectorGPU].Weights[s.totalLayers-1] += projectorWeights
    }

    slog.Debug("memory", "estimate", s.mem)
    s.mem.Log(slog.LevelInfo)

    // The llama engine uses mmap by default
    s.loadRequest.UseMmap = true

    // mmap has issues with partial offloading on metal
    for _, g := range gpus {
        if g.Library == "Metal" &&
            uint64(s.options.NumGPU) > 0 &&
            uint64(s.options.NumGPU) < s.ggml.KV().BlockCount()+1 {
            uint64(s.options.NumGPU) < s.totalLayers {
            s.options.UseMMap = new(bool)
            *s.options.UseMMap = false
        }
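The `for {}` layout loop in the hunk above is a fixed-point iteration: each pass asks createLayout for an assignment, and whenever the assignment spans more GPUs than before, graph memory is reserved on the newly used devices and the layout is recomputed; once the GPU count stops growing, it settles. A compact sketch of that shape, with an invented `layout` stand-in and made-up sizes:

```go
package main

import "fmt"

func main() {
    free := []uint64{900, 400}             // invented free VRAM per GPU
    const graph, model uint64 = 300, 1000  // invented graph + model sizes

    // layout: how many GPUs the model spans under the current free memory
    layout := func() int {
        var capacity uint64
        for n, f := range free {
            capacity += f
            if capacity >= model {
                return n + 1
            }
        }
        return len(free)
    }

    prev := 0
    for {
        n := layout()
        if n <= prev {
            break // stable: no new GPUs joined, reservations are consistent
        }
        for i := prev; i < n; i++ {
            free[i] -= graph // a newly used GPU must also hold the compute graph
        }
        prev = n
        fmt.Println("gpus:", n, "free after graph:", free)
    }
}
```

The re-check matters because charging the graph against a GPU can shrink its remaining space enough that the next layout pass would spill onto an additional device.
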
@@ -542,90 +631,50 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus [
    // Linux with a model larger than free space, mmap leads to thrashing
    // For CPU loads we want the memory to be allocated, not FS cache
    if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
        (runtime.GOOS == "linux" && systemInfo.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) ||
        (runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
        (len(gpus) == 0 && s.options.UseMMap == nil) ||
        (len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
        (s.options.UseMMap != nil && !*s.options.UseMMap) {
        s.loadRequest.UseMmap = false
    }
}

if err := s.waitUntilRunnerLaunched(ctx); err != nil {
    return nil, err
}

s.loadRequest.GPULayers = gpuLayers
resp, err := s.initModel(ctx, s.loadRequest, LoadOperationCommit)
if err != nil {
    return nil, err
}

// On the Ollama engine, we can print out a summary of the memory allocations.
// We don't have this for the llama engine but it does something similar itself.
if s.textProcessor != nil {
    resp.Memory.Log(slog.LevelInfo)
}

if !resp.Success {
    slog.Warn("failed to allocate memory for model", "memory", resp.Memory)
    return nil, errors.New("failed to allocate memory for model")
}

// The llama engine does its memory allocations together with model loading, so we
// need to wait until it is done to ensure that we have accurate memory data before
// loading the next model
if s.textProcessor == nil {
    return uniqueDeviceIDs(s.loadRequest.GPULayers), s.WaitUntilRunning(ctx)
} else {
    return uniqueDeviceIDs(s.loadRequest.GPULayers), nil
}
}

// createGPULayers maps from the tensor splits assigned by the memory estimates to explicit assignment
// of particular layers onto GPUs
func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus []ml.DeviceInfo, numGPU int) ml.GPULayersList {
    if numGPU <= 0 || len(gpus) == 0 {
        return nil
func projectorMemoryRequirements(filename string) (weights uint64) {
    file, err := os.Open(filename)
    if err != nil {
        return 0
    }
    defer file.Close()

    ggml, err := ggml.Decode(file, 1024)
    if err != nil {
        return 0
    }

    gpuLayers := make(ml.GPULayersList, len(gpus))
    for i := range gpuLayers {
        gpuLayers[i].DeviceID = gpus[i].DeviceID
    for _, layer := range ggml.Tensors().GroupLayers() {
        weights += layer.Size()
    }

    var sum float32
    splits := make([]float32, len(estimate.TensorSplit))
    // cumulative sum of all splits
    for i := range splits {
        sum += float32(estimate.TensorSplit[i])
        splits[i] = sum
    }

    if sum <= 0 {
        return nil
    }

    // normalize splits
    for i := range splits {
        splits[i] /= sum
    }

    blocks := int(ggml.KV().BlockCount())
    gpuRangeStart := max(0, blocks-numGPU)
    gpuRangeStop := min(gpuRangeStart+numGPU, blocks+1)
    for i := range blocks + 1 {
        if i < gpuRangeStart || i >= gpuRangeStop {
            continue
        }

        index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
        if index < 0 || index >= len(gpus) {
            continue
        }

        gpuLayers[index].Layers = append(gpuLayers[index].Layers, i)
    }

    return gpuLayers
    return weights
}

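createGPULayers turns a tensor split such as `[3 1]` into concrete layer indices by normalizing the cumulative split and bucketing each offloaded layer's fractional position. A self-contained sketch of just that mapping (the `splitLayers` helper and the numbers are invented for illustration):

```go
package main

import (
    "fmt"
    "slices"
)

// splitLayers mirrors the cumulative-split bucketing: layer i's fractional
// position within the offloaded range selects the first GPU whose cumulative
// share exceeds it.
func splitLayers(tensorSplit []int, blocks, numGPU int) [][]int {
    var sum float32
    splits := make([]float32, len(tensorSplit))
    for i, s := range tensorSplit {
        sum += float32(s)
        splits[i] = sum
    }
    if sum <= 0 {
        return nil
    }
    for i := range splits {
        splits[i] /= sum
    }

    out := make([][]int, len(tensorSplit))
    start := max(0, blocks-numGPU)
    stop := min(start+numGPU, blocks+1)
    for i := start; i < stop; i++ {
        frac := float32(i-start) / float32(stop-start)
        idx := slices.IndexFunc(splits, func(f float32) bool { return frac < f })
        if idx >= 0 {
            out[idx] = append(out[idx], i)
        }
    }
    return out
}

func main() {
    // Two GPUs with a 3:1 split over a 4-block model; numGPU=5 covers the
    // output layer too, matching the blocks+1 range used above.
    fmt.Println(splitLayers([]int{3, 1}, 4, 5)) // [[0 1 2 3] [4]]
}
```
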
// Load finds the optimal layout of layers to offload on GPUs based on no initial information about the size of the model
@@ -652,23 +701,6 @@ func (s *ollamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus

    slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)

    systemTotalMemory := systemInfo.TotalMemory
    systemFreeMemory := systemInfo.FreeMemory
    systemSwapFreeMemory := systemInfo.FreeSwap
    slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))

    for _, gpu := range gpus {
        available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory()
        if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory() {
            available = 0
        }
        slog.Info("gpu memory", "id", gpu.ID, "library", gpu.Library,
            "available", format.HumanBytes2(available),
            "free", format.HumanBytes2(gpu.FreeMemory),
            "minimum", format.HumanBytes2(gpu.MinimumMemory()),
            "overhead", format.HumanBytes2(envconfig.GpuOverhead()))
    }

    pastAllocations := make(map[uint64]struct{})
    var backoff float32

@@ -834,25 +866,22 @@ func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID {
// - Calculating how much space each GPU has available for layers, based on free memory and space occupied by the graph
// - Assigning layers
// - Ensuring that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *ollamaServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) {
func (s *llmServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) {
    if memory == nil {
        memory = &ml.BackendMemory{CPU: ml.DeviceMemory{
            Weights: make([]uint64, s.totalLayers),
            Cache:   make([]uint64, s.totalLayers),
        }}
    }
    gpuLayers, layers, err := s.buildLayout(systemGPUs, memory, requireFull, backoff)
    if err != nil {
        return nil, err
    }
    err = s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
    gpuLayers, layers := s.buildLayout(systemGPUs, memory, requireFull, backoff)
    err := s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
    if err != nil {
        return nil, err
    }
    return gpuLayers, nil
}

func (s *ollamaServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, []uint64, error) {
func (s *llmServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, []uint64) {
    gpus := append(make([]ml.DeviceInfo, 0, len(systemGPUs)), systemGPUs...)
    sort.Sort(sort.Reverse(ml.ByFreeMemory(gpus)))

@@ -910,11 +939,11 @@ func (s *ollamaServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.Backen
            gpuLayers = libraryGpuLayers
        }
    }
    return gpuLayers, layers, nil
    return gpuLayers, layers
}

// verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *ollamaServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
    // These sizes will only increase as we go through additional iterations and get additional information.
    cpuSize := memory.InputWeights + memory.CPU.Graph
    var vramSize uint64
@@ -942,11 +971,13 @@ nextLayer:

    if requireFull {
        if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
            slog.Info("model requires more memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
            return ErrLoadRequiredFull
        }

        if cpuSize > systemInfo.FreeMemory {
            return ErrLoadRequiredFull
            slog.Info("model requires more system memory than is currently available, evicting a model to make space", "required", cpuSize, "free", systemInfo.FreeMemory)
            return fmt.Errorf("model requires more system memory than is currently available %w", ErrLoadRequiredFull)
        }
    }

@@ -976,6 +1007,13 @@ nextLayer:

// assignLayers packs the maximum number of layers onto the smallest set of GPUs and comes up with a layer assignment
func assignLayers(layers []uint64, gpus []ml.DeviceInfo, requireFull bool, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) {
    // If the user is manually overriding parameters, treat all GPUs equally so they split according to VRAM
    if requestedLayers >= 0 || envconfig.SchedSpread() {
        for i := range gpus {
            gpus[i].Integrated = false
        }
    }

    // If we can't fit everything then prefer offloading layers other than the output layer
    for range 2 {
        // requestedLayers may be -1 if nothing was requested
@@ -1008,8 +1046,10 @@ func assignLayers(layers []uint64, gpus []ml.DeviceInfo, requireFull bool, reque

// findBestFit binary searches to find the smallest capacity factor that can fit
// the max number of layers. The capacity factor is multiplied by the free space on
// each GPU and a small one will force even balancing.
// each GPU and a small one will force even balancing. Higher performance GPUs are
// used first.
func findBestFit(layers []uint64, gpus []ml.DeviceInfo, requestedLayers int, forceRequest bool) (gpuLayers ml.GPULayersList) {
    for _, gl := range ml.ByPerformance(gpus) {
        var high float32 = 1
        var low float32 = 0

@@ -1018,15 +1058,12 @@ func findBestFit(layers []uint64, gpus []ml.DeviceInfo, requestedLayers int, for
            high = 1000
        }

        bestAssignments := greedyFit(layers, gpus, high, requestedLayers)
        bestAssignments := greedyFit(layers, gl, high, requestedLayers)
        maxNumGPU := bestAssignments.Sum()
        if maxNumGPU == 0 {
            return bestAssignments
        }

        for high-low > 1e-6 {
            mid := (low + high) / 2
            assignments := greedyFit(layers, gpus, mid, requestedLayers)
            assignments := greedyFit(layers, gl, mid, requestedLayers)
            if assignments.Sum() == maxNumGPU {
                high = mid
                bestAssignments = assignments
@@ -1034,7 +1071,13 @@ func findBestFit(layers []uint64, gpus []ml.DeviceInfo, requestedLayers int, for
                low = mid
            }
        }
        return bestAssignments

        layers = layers[:len(layers)-bestAssignments.Sum()]
        requestedLayers -= bestAssignments.Sum()
        gpuLayers = append(bestAssignments, gpuLayers...)
    }

    return gpuLayers
}

// greedyFit assigns layers incrementally to GPUs, spilling over as each runs out of free space
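findBestFit's inner loop is a plain binary search over a scalar capacity factor: scale every GPU's free memory by `mid`, greedily pack layers, and keep shrinking the factor while the packed-layer count stays at its maximum, which yields the tightest (most evenly balanced) split that still fits the most layers. A standalone sketch with an invented `pack` function and made-up sizes:

```go
package main

import "fmt"

// pack greedily counts how many fixed-size layers fit when each GPU's free
// memory is scaled by factor (all numbers invented for the sketch).
func pack(free []float32, factor, layerSize float32, nLayers int) int {
    fit := 0
    for _, f := range free {
        budget := f * factor
        for fit < nLayers && budget >= layerSize {
            budget -= layerSize
            fit++
        }
    }
    return fit
}

func main() {
    free := []float32{1000, 600}
    const layer, nLayers = 100, 12

    high, low := float32(1), float32(0)
    best := pack(free, high, layer, nLayers)
    for high-low > 1e-6 {
        mid := (low + high) / 2
        if pack(free, mid, layer, nLayers) == best {
            high = mid // still fits the max: try a tighter factor
        } else {
            low = mid
        }
    }
    fmt.Printf("smallest factor fitting %d layers: %.4f\n", best, high)
}
```

At factor 1 the first GPU would take 10 of the 12 layers; the search converges to roughly 0.8, where the split becomes 8/4 — the even balancing the comment describes.
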
@@ -1362,6 +1405,12 @@ type CompletionRequest struct {
    Grammar  string // set before sending the request to the subprocess
    Shift    bool
    Truncate bool

    // Logprobs specifies whether to include log probabilities in the response
    Logprobs bool

    // TopLogprobs specifies the number of most likely alternative tokens to return (0-20)
    TopLogprobs int
}

// DoneReason represents the reason why a completion response is done
@@ -1387,6 +1436,18 @@ func (d DoneReason) String() string {
    }
}

// TokenLogprob represents log probability information for a single token alternative.
type TokenLogprob struct {
    Token   string  `json:"token"`
    Logprob float64 `json:"logprob"`
}

// Logprob contains log probability information for a generated token.
type Logprob struct {
    TokenLogprob
    TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}

type CompletionResponse struct {
    Content    string     `json:"content"`
    DoneReason DoneReason `json:"done_reason"`
@@ -1395,6 +1456,9 @@ type CompletionResponse struct {
    PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
    EvalCount          int           `json:"eval_count"`
    EvalDuration       time.Duration `json:"eval_duration"`

    // Logprobs contains log probability information if requested
    Logprobs []Logprob `json:"logprobs,omitempty"`
}

func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -1531,6 +1595,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
    if c.Content != "" {
        fn(CompletionResponse{
            Content:  c.Content,
            Logprobs: c.Logprobs,
        })
    }

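Because Logprob embeds TokenLogprob, the embedded fields flatten into the same JSON object, with `top_logprobs` nested alongside. A quick sketch of the wire shape these types produce (token values invented):

```go
package main

import (
    "encoding/json"
    "fmt"
)

type TokenLogprob struct {
    Token   string  `json:"token"`
    Logprob float64 `json:"logprob"`
}

type Logprob struct {
    TokenLogprob
    TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}

func main() {
    b, _ := json.Marshal(Logprob{
        TokenLogprob: TokenLogprob{Token: "Hello", Logprob: -0.12},
        TopLogprobs: []TokenLogprob{
            {Token: "Hello", Logprob: -0.12},
            {Token: "Hi", Logprob: -2.31},
        },
    })
    fmt.Println(string(b))
    // {"token":"Hello","logprob":-0.12,"top_logprobs":[{"token":"Hello","logprob":-0.12},{"token":"Hi","logprob":-2.31}]}
}
```
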
@@ -1623,69 +1688,60 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
    return e.Embedding, nil
}

type TokenizeRequest struct {
    Content string `json:"content"`
}

type TokenizeResponse struct {
    Tokens []int `json:"tokens"`
}

func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
    s.llamaModelLock.Lock()
    defer s.llamaModelLock.Unlock()

    if s.llamaModel != nil {
    if s.llamaModel == nil {
        return nil, fmt.Errorf("no tokenizer configured")
    }

    return s.llamaModel.Tokenize(content, false, true)
}
    if s.textProcessor != nil {

func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
    tokens, err := s.textProcessor.Encode(content, false)
    if err != nil {
        return nil, err
    }

    toks := make([]int, len(tokens))
    for i, t := range tokens {
        toks[i] = int(t)
    }

    return toks, nil
}
    // not reached
    return nil, fmt.Errorf("no tokenizer configured")
}

type DetokenizeRequest struct {
    Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
    Content string `json:"content"`
}

func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
func (s *llamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
    s.llamaModelLock.Lock()
    defer s.llamaModelLock.Unlock()

    if s.llamaModel != nil {
    if s.llamaModel == nil {
        return "", fmt.Errorf("no tokenizer configured")
    }

    var resp string
    for _, token := range tokens {
        resp += s.llamaModel.TokenToPiece(token)
    }

    return resp, nil
}
    if s.textProcessor != nil {

func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
    toks := make([]int32, len(tokens))
    for i, t := range tokens {
        toks[i] = int32(t)
    }

    content, err := s.textProcessor.Decode(toks)
    if err != nil {
        return "", err
    }

    return content, nil
}
    // not reached
    return "", fmt.Errorf("no tokenizer configured")
}

func (s *llmServer) Close() error {
    s.llamaModelLock.Lock()
@@ -1712,31 +1768,12 @@ func (s *llmServer) Close() error {
    return nil
}

func (s *llamaServer) VRAMSize() uint64 {
    return s.estimate.VRAMSize
}

func (s *llamaServer) TotalSize() uint64 {
    return s.estimate.TotalSize
}

func (s *llamaServer) VRAMByGPU(id ml.DeviceID) uint64 {
    for i, gpu := range s.gpus {
        if gpu.DeviceID == id {
            if i < len(s.estimate.GPUSizes) {
                return s.estimate.GPUSizes[i]
            }
        }
    }
    return 0
}

func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
    slog.Debug("llamarunner free vram reporting not supported")
    return nil
}

func (s *ollamaServer) VRAMSize() uint64 {
func (s *llmServer) VRAMSize() uint64 {
    if s.mem == nil {
        return 0
    }
@@ -1764,7 +1801,7 @@ func (s *ollamaServer) VRAMSize() uint64 {
    return mem
}

func (s *ollamaServer) TotalSize() uint64 {
func (s *llmServer) TotalSize() uint64 {
    if s.mem == nil {
        return 0
    }
@@ -1778,7 +1815,7 @@ func (s *ollamaServer) TotalSize() uint64 {
    return mem
}

func (s *ollamaServer) VRAMByGPU(id ml.DeviceID) uint64 {
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
    if s.mem == nil {
        return 0
    }

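With the split above, each engine owns its own Tokenize/Detokenize pair instead of branching inside llmServer, and the ollamaServer side bridges between the server's `[]int` API and the processor's `[]int32` tokens. A hedged sketch of that bridging; `toyProcessor` is invented (the real TextProcessor interface lives in the model package and is vocab-driven):

```go
package main

import "fmt"

// toyProcessor stands in for model.TextProcessor: Encode returns int32 token
// IDs, Decode reverses it. Here tokens are just runes for demonstration.
type toyProcessor struct{}

func (toyProcessor) Encode(s string, _ bool) ([]int32, error) {
    toks := make([]int32, 0, len(s))
    for _, r := range s {
        toks = append(toks, int32(r))
    }
    return toks, nil
}

func (toyProcessor) Decode(toks []int32) (string, error) {
    rs := make([]rune, len(toks))
    for i, t := range toks {
        rs[i] = rune(t)
    }
    return string(rs), nil
}

func main() {
    p := toyProcessor{}
    tokens, _ := p.Encode("hi", false)

    // The server API speaks []int, so widen on the way out...
    ints := make([]int, len(tokens))
    for i, t := range tokens {
        ints[i] = int(t)
    }

    // ...and narrow again on the way back in, exactly as Detokenize does.
    back := make([]int32, len(ints))
    for i, t := range ints {
        back[i] = int32(t)
    }
    s, _ := p.Decode(back)
    fmt.Println(ints, s) // [104 105] hi
}
```
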
@@ -14,16 +14,11 @@ import (
)

func TestLLMServerFitGPU(t *testing.T) {
    type gpu struct {
        id   ml.DeviceID
        free int
    }

    minMemory := 457 * format.MebiByte

    tests := []struct {
        name        string
        gpus        []gpu
        gpus        []ml.DeviceInfo
        layers      []int
        numGPU      int
        requireFull bool
@@ -38,91 +33,91 @@ func TestLLMServerFitGPU(t *testing.T) {
        },
        {
            name:     "Full single GPU",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}},
        },
        {
            name:     "Partial single GPU",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
        },
        {
            name:     "Single GPU with numGPU 1",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}},
        },
        {
            name:     "Single GPU with numGPU 0",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   0,
            expected: ml.GPULayersList{},
        },
        {
            name:     "Single GPU with numGPU 999",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
            numGPU:   999,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2, 3}}},
        },
        {
            name:     "Multi GPU fits on one",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1, 2}}},
        },
        {
            name:     "Multi GPU split",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
        },
        {
            name:     "Multi GPU partial",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}},
        },
        {
            name:     "Multi GPU numGPU 1",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}},
        },
        {
            name:     "Multi GPU numGPU 2",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   2,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}},
        },
        {
            name:     "Multi GPU numGPU 999",
            gpus:     []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   999,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
        },
        {
            name:     "Multi GPU different libraries",
            gpus:     []gpu{{id: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{128 * format.MebiByte, 128 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1", Library: "ROCm"}, Layers: []int{0, 1}}},
        },
        {
            name:        "requireFull",
            gpus:        []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
            gpus:        []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:      []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
            numGPU:      -1,
            requireFull: true,
@@ -130,12 +125,54 @@ func TestLLMServerFitGPU(t *testing.T) {
        },
        {
            name:        "requireFull numGPU",
            gpus:        []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
            gpus:        []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(256 * format.MebiByte)}},
            layers:      []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
            numGPU:      4,
            requireFull: true,
            expectedErr: ErrLoadRequiredFull,
        },
        {
            name:     "iGPU",
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}},
        },
        {
            name:     "iGPU + dGPU",
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
        },
        {
            name:     "iGPU + dGPU fits on one",
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{50 * format.MebiByte, 50 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1}}},
        },
        {
            name:     "iGPU + dGPU partial",
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
            numGPU:   -1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
        },
        {
            name:     "iGPU + dGPU numGPU 1",
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
            numGPU:   1,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
        },
        {
            name:     "iGPU + dGPU numGPU 999",
            gpus:     []ml.DeviceInfo{{DeviceID: ml.DeviceID{ID: "gpu0"}, FreeMemory: uint64(128*format.MebiByte + minMemory)}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Integrated: true, FreeMemory: uint64(256*format.MebiByte + minMemory)}},
            layers:   []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
            numGPU:   999,
            expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1, 2, 3}}},
        },
    }

    for _, tt := range tests {
@@ -145,12 +182,6 @@ func TestLLMServerFitGPU(t *testing.T) {
        systemInfo.FreeMemory = 512 * format.MebiByte
        systemInfo.FreeSwap = 256 * format.MebiByte

        gpus := make([]ml.DeviceInfo, len(tt.gpus))
        for i := range tt.gpus {
            gpus[i].DeviceID = tt.gpus[i].id
            gpus[i].FreeMemory = uint64(tt.gpus[i].free)
        }

        s := &ollamaServer{
            llmServer: llmServer{
                totalLayers: uint64(len(tt.layers)),
@@ -165,19 +196,19 @@ func TestLLMServerFitGPU(t *testing.T) {
        s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{
            Weights: make([]uint64, s.totalLayers),
            Cache:   make([]uint64, s.totalLayers),
        }, GPUs: make([]ml.DeviceMemory, len(gpus))}
        }, GPUs: make([]ml.DeviceMemory, len(tt.gpus))}

        for i := range tt.layers {
            s.mem.CPU.Weights[i] = uint64(tt.layers[i])
        }

        for i := range s.mem.GPUs {
            s.mem.GPUs[i].DeviceID = gpus[i].DeviceID
            s.mem.GPUs[i].DeviceID = tt.gpus[i].DeviceID
            s.mem.GPUs[i].Weights = make([]uint64, s.totalLayers)
            s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
        }

        gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, tt.requireFull, 0)
        gpuLayers, err := s.createLayout(systemInfo, tt.gpus, s.mem, tt.requireFull, 0)
        if err != tt.expectedErr {
            t.Fatalf("fitGPU returned error: %v", err)
        }

@@ -1,16 +0,0 @@
{
    "env": {
        "browser": true,
        "es6": true,
        "node": true
    },
    "extends": [
        "eslint:recommended",
        "plugin:@typescript-eslint/eslint-recommended",
        "plugin:@typescript-eslint/recommended",
        "plugin:import/recommended",
        "plugin:import/electron",
        "plugin:import/typescript"
    ],
    "parser": "@typescript-eslint/parser"
}
92
macapp/.gitignore
vendored
@@ -1,92 +0,0 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*

# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json

# Runtime data
pids
*.pid
*.seed
*.pid.lock
.DS_Store

# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov

# Coverage directory used by tools like istanbul
coverage
*.lcov

# nyc test coverage
.nyc_output

# node-waf configuration
.lock-wscript

# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release

# Dependency directories
node_modules/
jspm_packages/

# TypeScript v1 declaration files
typings/

# TypeScript cache
*.tsbuildinfo

# Optional npm cache directory
.npm

# Optional eslint cache
.eslintcache

# Optional REPL history
.node_repl_history

# Output of 'npm pack'
*.tgz

# Yarn Integrity file
.yarn-integrity

# dotenv environment variables file
.env
.env.test

# parcel-bundler cache (https://parceljs.org/)
.cache

# next.js build output
.next

# nuxt.js build output
.nuxt

# vuepress build output
.vuepress/dist

# Serverless directories
.serverless/

# FuseBox cache
.fusebox/

# DynamoDB Local files
.dynamodb/

# Webpack
.webpack/

# Vite
.vite/

# Electron-Forge
out/
@@ -1,21 +0,0 @@
# Desktop

This app builds upon Ollama to provide a desktop experience for running models.

## Developing

First, build the `ollama` binary:

```shell
cd ..
go build .
```

Then run the desktop app with `npm start`:

```shell
cd macapp
npm install
npm start
```

(8 binary image assets deleted: 402 B, 741 B, 440 B, 763 B, 447 B, 891 B, 443 B, 844 B)
@@ -1,79 +0,0 @@
import type { ForgeConfig } from '@electron-forge/shared-types'
import { MakerSquirrel } from '@electron-forge/maker-squirrel'
import { MakerZIP } from '@electron-forge/maker-zip'
import { PublisherGithub } from '@electron-forge/publisher-github'
import { AutoUnpackNativesPlugin } from '@electron-forge/plugin-auto-unpack-natives'
import { WebpackPlugin } from '@electron-forge/plugin-webpack'
import * as path from 'path'
import * as fs from 'fs'

import { mainConfig } from './webpack.main.config'
import { rendererConfig } from './webpack.renderer.config'

const packageJson = JSON.parse(fs.readFileSync(path.resolve(__dirname, './package.json'), 'utf8'))

const config: ForgeConfig = {
  packagerConfig: {
    appVersion: process.env.VERSION || packageJson.version,
    asar: true,
    icon: './assets/icon.icns',
    extraResource: [
      path.join(__dirname, '../dist/darwin/ollama'),
      ...fs.readdirSync(path.join(__dirname, '../dist/darwin-amd64/lib/ollama')).map(f => path.join(__dirname, '../dist/darwin-amd64/lib/ollama', f)),
      path.join(__dirname, './assets/iconTemplate.png'),
      path.join(__dirname, './assets/iconTemplate@2x.png'),
      path.join(__dirname, './assets/iconUpdateTemplate.png'),
      path.join(__dirname, './assets/iconUpdateTemplate@2x.png'),
      path.join(__dirname, './assets/iconDarkTemplate.png'),
      path.join(__dirname, './assets/iconDarkTemplate@2x.png'),
      path.join(__dirname, './assets/iconDarkUpdateTemplate.png'),
      path.join(__dirname, './assets/iconDarkUpdateTemplate@2x.png'),
    ],
    ...(process.env.SIGN
      ? {
          osxSign: {
            identity: process.env.APPLE_IDENTITY,
          },
          osxNotarize: {
            tool: 'notarytool',
            appleId: process.env.APPLE_ID || '',
            appleIdPassword: process.env.APPLE_PASSWORD || '',
            teamId: process.env.APPLE_TEAM_ID || '',
          },
        }
      : {}),
    osxUniversal: {
      x64ArchFiles: '*',
    },
  },
  rebuildConfig: {},
  makers: [new MakerSquirrel({}), new MakerZIP({}, ['darwin'])],
  hooks: {
    readPackageJson: async (_, packageJson) => {
      return { ...packageJson, version: process.env.VERSION || packageJson.version }
    },
  },
  plugins: [
    new AutoUnpackNativesPlugin({}),
    new WebpackPlugin({
      mainConfig,
      devContentSecurityPolicy: `default-src * 'unsafe-eval' 'unsafe-inline'; img-src data: 'self'`,
      renderer: {
        config: rendererConfig,
        nodeIntegration: true,
        entryPoints: [
          {
            html: './src/index.html',
            js: './src/renderer.tsx',
            name: 'main_window',
            preload: {
              js: './src/preload.ts',
            },
          },
        ],
      },
    }),
  ],
}

export default config
16604
macapp/package-lock.json
generated
@@ -1,80 +0,0 @@
{
  "name": "ollama",
  "productName": "Ollama",
  "version": "0.0.0",
  "description": "ollama",
  "main": ".webpack/main",
  "scripts": {
    "start": "electron-forge start",
    "package": "electron-forge package --arch universal",
    "package:sign": "SIGN=1 electron-forge package --arch universal",
    "make": "electron-forge make --arch universal",
    "make:sign": "SIGN=1 electron-forge make --arch universal",
    "publish": "SIGN=1 electron-forge publish",
    "lint": "eslint --ext .ts,.tsx ."
  },
  "keywords": [],
  "author": {
    "name": "Jeffrey Morgan",
    "email": "jmorganca@gmail.com"
  },
  "license": "MIT",
  "devDependencies": {
    "@babel/core": "^7.22.5",
    "@babel/preset-react": "^7.22.5",
    "@electron-forge/cli": "^6.2.1",
    "@electron-forge/maker-deb": "^6.2.1",
    "@electron-forge/maker-rpm": "^6.2.1",
    "@electron-forge/maker-squirrel": "^6.2.1",
    "@electron-forge/maker-zip": "^6.2.1",
    "@electron-forge/plugin-auto-unpack-natives": "^6.2.1",
    "@electron-forge/plugin-webpack": "^6.2.1",
    "@electron-forge/publisher-github": "^6.2.1",
    "@electron/universal": "^1.4.1",
    "@svgr/webpack": "^8.0.1",
    "@types/chmodr": "^1.0.0",
    "@types/node": "^20.4.0",
    "@types/react": "^18.2.14",
    "@types/react-dom": "^18.2.6",
    "@types/uuid": "^9.0.2",
    "@typescript-eslint/eslint-plugin": "^5.60.0",
    "@typescript-eslint/parser": "^5.60.0",
    "@vercel/webpack-asset-relocator-loader": "^1.7.3",
    "babel-loader": "^9.1.2",
    "chmodr": "^1.2.0",
    "copy-webpack-plugin": "^11.0.0",
    "css-loader": "^6.8.1",
    "electron": "25.9.2",
    "eslint": "^8.43.0",
    "eslint-plugin-import": "^2.27.5",
    "fork-ts-checker-webpack-plugin": "^7.3.0",
    "node-loader": "^2.0.0",
    "postcss": "^8.4.24",
    "postcss-import": "^15.1.0",
    "postcss-loader": "^7.3.3",
    "postcss-preset-env": "^8.5.1",
    "style-loader": "^3.3.3",
    "svg-inline-loader": "^0.8.2",
    "tailwindcss": "^3.3.2",
    "ts-loader": "^9.4.3",
    "ts-node": "^10.9.1",
    "typescript": "~4.5.4",
    "url-loader": "^4.1.1",
    "webpack": "^5.88.0",
    "webpack-cli": "^5.1.4",
    "webpack-dev-server": "^4.15.1"
  },
  "dependencies": {
    "@electron/remote": "^2.0.10",
    "@heroicons/react": "^2.0.18",
    "@segment/analytics-node": "^1.0.0",
    "copy-to-clipboard": "^3.3.3",
    "electron-squirrel-startup": "^1.0.0",
    "electron-store": "^8.1.0",
    "react": "^18.2.0",
    "react-dom": "^18.2.0",
    "uuid": "^9.0.0",
    "winston": "^3.10.0",
    "winston-daily-rotate-file": "^4.7.1"
  }
}
@@ -1,7 +0,0 @@
module.exports = {
  plugins: {
    'postcss-import': {},
    tailwindcss: {},
    autoprefixer: {},
  },
}
@@ -1,34 +0,0 @@
@tailwind base;
@tailwind components;
@tailwind utilities;

html,
body {
  background: transparent;
}

.drag {
  -webkit-app-region: drag;
}

.no-drag {
  -webkit-app-region: no-drag;
}

.blink {
  -webkit-animation: 1s blink step-end infinite;
  -moz-animation: 1s blink step-end infinite;
  -ms-animation: 1s blink step-end infinite;
  -o-animation: 1s blink step-end infinite;
  animation: 1s blink step-end infinite;
}

@keyframes blink {
  from,
  to {
    color: transparent;
  }
  50% {
    color: black;
  }
}
@@ -1,122 +0,0 @@
import { useState } from 'react'
import copy from 'copy-to-clipboard'
import { CheckIcon, DocumentDuplicateIcon } from '@heroicons/react/24/outline'
import Store from 'electron-store'
import { getCurrentWindow, app } from '@electron/remote'

import { install } from './install'
import OllamaIcon from './ollama.svg'

const store = new Store()

enum Step {
  WELCOME = 0,
  CLI,
  FINISH,
}

export default function () {
  const [step, setStep] = useState<Step>(Step.WELCOME)
  const [commandCopied, setCommandCopied] = useState<boolean>(false)

  const command = 'ollama run llama3.2'

  return (
    <div className='drag'>
      <div className='mx-auto flex min-h-screen w-full flex-col justify-between bg-white px-4 pt-16'>
        {step === Step.WELCOME && (
          <>
            <div className='mx-auto text-center'>
              <h1 className='mb-6 mt-4 text-2xl tracking-tight text-gray-900'>Welcome to Ollama</h1>
              <p className='mx-auto w-[65%] text-sm text-gray-400'>
                Let's get you up and running with your own large language models.
              </p>
              <button
                onClick={() => setStep(Step.CLI)}
                className='no-drag rounded-dm mx-auto my-8 w-[40%] rounded-md bg-black px-4 py-2 text-sm text-white hover:brightness-110'
              >
                Next
              </button>
            </div>
            <div className='mx-auto'>
              <OllamaIcon />
            </div>
          </>
        )}
        {step === Step.CLI && (
          <>
            <div className='mx-auto flex flex-col space-y-28 text-center'>
              <h1 className='mt-4 text-2xl tracking-tight text-gray-900'>Install the command line</h1>
              <pre className='mx-auto text-4xl text-gray-400'>> ollama</pre>
              <div className='mx-auto'>
                <button
                  onClick={async () => {
                    try {
                      await install()
                      setStep(Step.FINISH)
                    } catch (e) {
                      console.error('could not install: ', e)
                    } finally {
                      getCurrentWindow().show()
                      getCurrentWindow().focus()
                    }
                  }}
                  className='no-drag rounded-dm mx-auto w-[60%] rounded-md bg-black px-4 py-2 text-sm text-white hover:brightness-110'
                >
                  Install
                </button>
                <p className='mx-auto my-4 w-[70%] text-xs text-gray-400'>
                  You will be prompted for administrator access
                </p>
              </div>
            </div>
          </>
        )}
        {step === Step.FINISH && (
          <>
            <div className='mx-auto flex flex-col space-y-20 text-center'>
              <h1 className='mt-4 text-2xl tracking-tight text-gray-900'>Run your first model</h1>
              <div className='flex flex-col'>
                <div className='group relative flex items-center'>
                  <pre className='language-none text-2xs w-full rounded-md bg-gray-100 px-4 py-3 text-start leading-normal'>
                    {command}
                  </pre>
                  <button
                    className={`no-drag absolute right-[5px] px-2 py-2 ${
                      commandCopied
                        ? 'text-gray-900 opacity-100 hover:cursor-auto'
                        : 'text-gray-200 opacity-50 hover:cursor-pointer'
                    } hover:font-bold hover:text-gray-900 group-hover:opacity-100`}
                    onClick={() => {
                      copy(command)
                      setCommandCopied(true)
                      setTimeout(() => setCommandCopied(false), 3000)
                    }}
                  >
                    {commandCopied ? (
                      <CheckIcon className='h-4 w-4 font-bold text-gray-500' />
                    ) : (
                      <DocumentDuplicateIcon className='h-4 w-4 text-gray-500' />
                    )}
                  </button>
                </div>
                <p className='mx-auto my-4 w-[70%] text-xs text-gray-400'>
                  Run this command in your favorite terminal.
                </p>
              </div>
              <button
                onClick={() => {
                  store.set('first-time-run', true)
                  window.close()
                }}
                className='no-drag rounded-dm mx-auto w-[60%] rounded-md bg-black px-4 py-2 text-sm text-white hover:brightness-110'
              >
                Finish
              </button>
            </div>
          </>
        )}
      </div>
    </div>
  )
}
4
macapp/src/declarations.d.ts
vendored
@@ -1,4 +0,0 @@
declare module '*.svg' {
  const content: string
  export default content
}
@@ -1,9 +0,0 @@
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8" />
  </head>
  <body>
    <div id="app"></div>
  </body>
</html>
@@ -1,302 +0,0 @@
import { spawn, ChildProcess } from 'child_process'
import { app, autoUpdater, dialog, Tray, Menu, BrowserWindow, MenuItemConstructorOptions, nativeTheme } from 'electron'
import Store from 'electron-store'
import winston from 'winston'
import 'winston-daily-rotate-file'
import * as path from 'path'

import { v4 as uuidv4 } from 'uuid'
import { installed } from './install'

require('@electron/remote/main').initialize()

if (require('electron-squirrel-startup')) {
  app.quit()
}

const store = new Store()

let welcomeWindow: BrowserWindow | null = null

declare const MAIN_WINDOW_WEBPACK_ENTRY: string

const logger = winston.createLogger({
  transports: [
    new winston.transports.Console(),
    new winston.transports.File({
      filename: path.join(app.getPath('home'), '.ollama', 'logs', 'server.log'),
      maxsize: 1024 * 1024 * 20,
      maxFiles: 5,
    }),
  ],
  format: winston.format.printf(info => info.message),
})

app.on('ready', () => {
  const gotTheLock = app.requestSingleInstanceLock()
  if (!gotTheLock) {
    app.exit(0)
    return
  }

  app.on('second-instance', () => {
    if (app.hasSingleInstanceLock()) {
      app.releaseSingleInstanceLock()
    }

    if (proc) {
      proc.off('exit', restart)
      proc.kill()
    }

    app.exit(0)
  })

  app.focus({ steal: true })

  init()
})

function firstRunWindow() {
  // Create the browser window.
  welcomeWindow = new BrowserWindow({
    width: 400,
    height: 500,
    frame: false,
    fullscreenable: false,
    resizable: false,
    movable: true,
    show: false,
    webPreferences: {
      nodeIntegration: true,
      contextIsolation: false,
    },
  })

  require('@electron/remote/main').enable(welcomeWindow.webContents)

  welcomeWindow.loadURL(MAIN_WINDOW_WEBPACK_ENTRY)
  welcomeWindow.on('ready-to-show', () => welcomeWindow.show())
  welcomeWindow.on('closed', () => {
    if (process.platform === 'darwin') {
      app.dock.hide()
    }
  })
}

let tray: Tray | null = null
let updateAvailable = false
const assetPath = app.isPackaged ? process.resourcesPath : path.join(__dirname, '..', '..', 'assets')

function trayIconPath() {
  return nativeTheme.shouldUseDarkColors
    ? updateAvailable
      ? path.join(assetPath, 'iconDarkUpdateTemplate.png')
      : path.join(assetPath, 'iconDarkTemplate.png')
    : updateAvailable
      ? path.join(assetPath, 'iconUpdateTemplate.png')
      : path.join(assetPath, 'iconTemplate.png')
}

function updateTrayIcon() {
  if (tray) {
    tray.setImage(trayIconPath())
  }
}

function updateTray() {
  const updateItems: MenuItemConstructorOptions[] = [
    { label: 'An update is available', enabled: false },
    {
      label: 'Restart to update',
      click: () => autoUpdater.quitAndInstall(),
    },
    { type: 'separator' },
  ]

  const menu = Menu.buildFromTemplate([
    ...(updateAvailable ? updateItems : []),
    { role: 'quit', label: 'Quit Ollama', accelerator: 'Command+Q' },
  ])

  if (!tray) {
    tray = new Tray(trayIconPath())
  }

  tray.setToolTip(updateAvailable ? 'An update is available' : 'Ollama')
  tray.setContextMenu(menu)
  tray.setImage(trayIconPath())

  nativeTheme.off('updated', updateTrayIcon)
  nativeTheme.on('updated', updateTrayIcon)
}

let proc: ChildProcess = null

function server() {
  const binary = app.isPackaged
    ? path.join(process.resourcesPath, 'ollama')
    : path.resolve(process.cwd(), '..', 'ollama')

  proc = spawn(binary, ['serve'])

  proc.stdout.on('data', data => {
    logger.info(data.toString().trim())
  })

  proc.stderr.on('data', data => {
    logger.error(data.toString().trim())
  })

  proc.on('exit', restart)
}

function restart() {
  setTimeout(server, 1000)
}

app.on('before-quit', () => {
  if (proc) {
    proc.off('exit', restart)
    proc.kill('SIGINT') // send SIGINT signal to the server, which also stops any loaded llms
  }
})

const updateURL = `https://ollama.com/api/update?os=${process.platform}&arch=${
  process.arch
}&version=${app.getVersion()}&id=${id()}`

let latest = ''
async function isNewReleaseAvailable() {
  try {
    const response = await fetch(updateURL)

    if (!response.ok) {
      return false
    }

    if (response.status === 204) {
      return false
    }

    const data = await response.json()

    const url = data?.url
    if (!url) {
      return false
    }

    if (latest === url) {
      return false
    }

    latest = url

    return true
  } catch (error) {
    logger.error(`update check failed - ${error}`)
    return false
  }
}

async function checkUpdate() {
  const available = await isNewReleaseAvailable()
  if (available) {
    logger.info('checking for update')
    autoUpdater.checkForUpdates()
  }
}

function init() {
  if (app.isPackaged) {
    checkUpdate()
    setInterval(() => {
      checkUpdate()
    }, 60 * 60 * 1000)
  }

  updateTray()

  if (process.platform === 'darwin') {
    if (app.isPackaged) {
      if (!app.isInApplicationsFolder()) {
        const chosen = dialog.showMessageBoxSync({
          type: 'question',
          buttons: ['Move to Applications', 'Do Not Move'],
          message: 'Ollama works best when run from the Applications directory.',
          defaultId: 0,
          cancelId: 1,
        })

        if (chosen === 0) {
          try {
            app.moveToApplicationsFolder({
              conflictHandler: conflictType => {
                if (conflictType === 'existsAndRunning') {
                  dialog.showMessageBoxSync({
                    type: 'info',
                    message: 'Cannot move to Applications directory',
                    detail:
                      'Another version of Ollama is currently running from your Applications directory. Close it first and try again.',
                  })
                }
                return true
              },
            })
            return
          } catch (e) {
            logger.error(`[Move to Applications] Failed to move to applications folder - ${e.message}}`)
          }
        }
      }
    }
  }

  server()

  if (store.get('first-time-run') && installed()) {
    if (process.platform === 'darwin') {
      app.dock.hide()
    }

    app.setLoginItemSettings({ openAtLogin: app.getLoginItemSettings().openAtLogin })
    return
  }

  // This is the first run or the CLI is no longer installed
  app.setLoginItemSettings({ openAtLogin: true })
  firstRunWindow()
}

// Quit when all windows are closed, except on macOS. There, it's common
// for applications and their menu bar to stay active until the user quits
// explicitly with Cmd + Q.
app.on('window-all-closed', () => {
  if (process.platform !== 'darwin') {
    app.quit()
  }
})

function id(): string {
  const id = store.get('id') as string

  if (id) {
    return id
  }

  const uuid = uuidv4()
  store.set('id', uuid)
  return uuid
}

autoUpdater.setFeedURL({ url: updateURL })

autoUpdater.on('error', e => {
  logger.error(`update check failed - ${e.message}`)
  console.error(`update check failed - ${e.message}`)
})

autoUpdater.on('update-downloaded', () => {
  updateAvailable = true
  updateTray()
})
@@ -1,21 +0,0 @@
import * as fs from 'fs'
import { exec as cbExec } from 'child_process'
import * as path from 'path'
import { promisify } from 'util'

const app = process && process.type === 'renderer' ? require('@electron/remote').app : require('electron').app
const ollama = app.isPackaged ? path.join(process.resourcesPath, 'ollama') : path.resolve(process.cwd(), '..', 'ollama')
const exec = promisify(cbExec)
const symlinkPath = '/usr/local/bin/ollama'

export function installed() {
  return fs.existsSync(symlinkPath) && fs.readlinkSync(symlinkPath) === ollama
}

export async function install() {
  const command = `do shell script "mkdir -p ${path.dirname(
    symlinkPath
  )} && ln -F -s \\"${ollama}\\" \\"${symlinkPath}\\"" with administrator privileges`

  await exec(`osascript -e '${command}'`)
}
(1 binary image asset deleted, 17 KiB)
@@ -1,7 +0,0 @@
import App from './app'
import './app.css'
import { createRoot } from 'react-dom/client'

const container = document.getElementById('app')
const root = createRoot(container)
root.render(<App />)
@@ -1,6 +0,0 @@
/** @type {import('tailwindcss').Config} */
module.exports = {
  content: ['./src/**/*.{js,ts,jsx,tsx,mdx}'],
  theme: {},
  plugins: [],
}
@@ -1,20 +0,0 @@
{
  "compilerOptions": {
    "target": "ES6",
    "allowJs": true,
    "module": "commonjs",
    "skipLibCheck": true,
    "esModuleInterop": true,
    "noImplicitAny": true,
    "sourceMap": true,
    "baseUrl": ".",
    "outDir": "dist",
    "moduleResolution": "node",
    "resolveJsonModule": true,
    "paths": {
      "*": ["node_modules/*"]
    },
    "jsx": "react-jsx"
  },
  "include": ["src/**/*"]
}
@@ -1,20 +0,0 @@
import type { Configuration } from 'webpack'

import { rules } from './webpack.rules'
import { plugins } from './webpack.plugins'

export const mainConfig: Configuration = {
  /**
   * This is the main entry point for your application, it's the first file
   * that runs in the main process.
   */
  entry: './src/index.ts',
  // Put your normal webpack config below here
  module: {
    rules,
  },
  plugins,
  resolve: {
    extensions: ['.js', '.ts', '.jsx', '.tsx', '.css', '.json'],
  },
}
@@ -1,14 +0,0 @@
import type IForkTsCheckerWebpackPlugin from 'fork-ts-checker-webpack-plugin'
import { DefinePlugin } from 'webpack'

// eslint-disable-next-line @typescript-eslint/no-var-requires
const ForkTsCheckerWebpackPlugin: typeof IForkTsCheckerWebpackPlugin = require('fork-ts-checker-webpack-plugin')

export const plugins = [
  new ForkTsCheckerWebpackPlugin({
    logger: 'webpack-infrastructure',
  }),
  new DefinePlugin({
    'process.env.TELEMETRY_WRITE_KEY': JSON.stringify(process.env.TELEMETRY_WRITE_KEY),
  }),
]
@@ -1,19 +0,0 @@
import type { Configuration } from 'webpack'

import { rules } from './webpack.rules'
import { plugins } from './webpack.plugins'

rules.push({
  test: /\.css$/,
  use: [{ loader: 'style-loader' }, { loader: 'css-loader' }, { loader: 'postcss-loader' }],
})

export const rendererConfig: Configuration = {
  module: {
    rules,
  },
  plugins,
  resolve: {
    extensions: ['.js', '.ts', '.jsx', '.tsx', '.css'],
  },
}
@@ -1,35 +0,0 @@
import type { ModuleOptions } from 'webpack'

export const rules: Required<ModuleOptions>['rules'] = [
  // Add support for native node modules
  {
    // We're specifying native_modules in the test because the asset relocator loader generates a
    // "fake" .node file which is really a cjs file.
    test: /native_modules[/\\].+\.node$/,
    use: 'node-loader',
  },
  {
    test: /[/\\]node_modules[/\\].+\.(m?js|node)$/,
    parser: { amd: false },
    use: {
      loader: '@vercel/webpack-asset-relocator-loader',
      options: {
        outputAssetBase: 'native_modules',
      },
    },
  },
  {
    test: /\.tsx?$/,
    exclude: /(node_modules|\.webpack)/,
    use: {
      loader: 'ts-loader',
      options: {
        transpileOnly: true,
      },
    },
  },
  {
    test: /\.svg$/,
    use: ['@svgr/webpack'],
  },
]
@@ -146,7 +146,6 @@ type Tensor interface {
	FromFloats([]float32)
	FromInts([]int32)

	Neg(ctx Context) Tensor
	Add(ctx Context, t2 Tensor) Tensor
	Sub(ctx Context, t2 Tensor) Tensor
	Mul(ctx Context, t2 Tensor) Tensor
@@ -185,7 +184,6 @@ type Tensor interface {
	View(ctx Context, offset int, shape ...int) Tensor
	Permute(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context, shape ...int) Tensor
	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor

	Pad(ctx Context, shape ...int) Tensor

@@ -198,6 +196,10 @@ type Tensor interface {
	Copy(ctx Context, t2 Tensor) Tensor
	Duplicate(ctx Context) Tensor

	Slice(ctx Context, dim, low, high, step int) Tensor
	Chunk(ctx Context, dim int, size int) []Tensor
	ChunkSections(ctx Context, dim int, sections ...int) []Tensor

	TopK(ctx Context, k int) Tensor
	Argsort(ctx Context) Tensor
	Mean(ctx Context) Tensor
@@ -205,7 +207,6 @@ type Tensor interface {
	Stddev(ctx Context) Tensor
	Sqr(ctx Context) Tensor
	Sqrt(ctx Context) Tensor
	Clamp(ctx Context, min, max float32) Tensor
}

// ScaledDotProductAttention implements a fused attention

@@ -1137,13 +1137,6 @@ func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
	}
}

func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_neg(ctx.(*Context).ctx, t.t),
	}
}

func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
@@ -1632,20 +1625,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
	}
}

func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
	var tt *C.struct_ggml_tensor
	switch len(strides) {
	case 0:
		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
	case 1:
		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
	default:
		panic("unsupported number of dimensions")
	}

	return &Tensor{b: t.b, t: tt}
}

func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
	var kqMask *C.struct_ggml_tensor
	if mask != nil {
@@ -1732,9 +1711,65 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
	}
}

func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
// Slice returns a view of the tensor sliced along dim from low to high in steps of step.
// Slice panics if the dimension is invalid or the slice parameters are out of range.
// If dim=0 and step>1, the result is a copy rather than a view to ensure proper shape.
func (t *Tensor) Slice(ctx ml.Context, dim int, low, high, step int) ml.Tensor {
	if dim < 0 || dim >= C.GGML_MAX_DIMS {
		panic("invalid dimension")
	} else if low < 0 || high > t.Dim(dim) || low >= high || step < 1 {
		panic("invalid slice parameters")
	}

	if dim == 0 && step > 1 {
		// dim=0,step>1 is a special case so handle it here first
		return t.View(ctx,
			low*t.Stride(0), 1,
			step*t.Stride(0), (high-low+1)/step,
			t.Stride(1), t.Dim(1),
			// preserve dim 3 by merging it into dim 2
			t.Stride(2), t.Dim(2)*t.Dim(3),
		).Contiguous(ctx, (high-low+1)/step, t.Dim(1), t.Dim(2), t.Dim(3))
	}

	args := []int{
		low * t.Stride(dim), t.Dim(0),
		t.Stride(1), t.Dim(1),
		t.Stride(2), t.Dim(2),
		t.Stride(3), t.Dim(3),
	}

	if step == 1 {
		args[dim*2+1] = high - low
		return t.View(ctx, args[0], args[1:]...)
	} else {
		args[dim*2] = step * t.Stride(dim)
		args[dim*2+1] = (high - low + 1) / step
		return t.View(ctx, args[0], args[1:]...)
	}
}

// Chunk splits the tensor into chunk-sized tensors along dim. Each sub-tensor is a view of
// the original.
func (t *Tensor) Chunk(ctx ml.Context, dim, chunk int) []ml.Tensor {
	sections := make([]int, 0, t.Dim(dim)/chunk+1)
	for rest := t.Dim(dim); rest > 0; rest -= chunk {
		sections = append(sections, min(chunk, rest))
	}
	return t.ChunkSections(ctx, dim, sections...)
}

// ChunkSections splits the tensor into section-sized tensors along dim. Each sub-tensor is a
// view of the original. The size of dim must equal the sum of sections.
func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Tensor {
	var offset int
	s := make([]ml.Tensor, len(sections))
	for i, section := range sections {
		s[i] = t.Slice(ctx, dim, offset, offset+section, 1)
		offset += section
	}
	if offset != t.Dim(dim) {
		panic("sections do not sum to tensor dimension")
	}
	return s
}
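A quick sketch of how these three helpers compose, following the semantics of the code above; `ctx` and `t` here are placeholders for a real backend context and tensor, and the dimensions are hypothetical:

```go
// Sketch only: assumes a tensor t with t.Dim(0) == 8 and a valid ctx.
// Slice(ctx, dim, low, high, step) views elements low..high-1 along dim;
// step > 1 selects every step-th element (and copies when dim == 0).
evens := t.Slice(ctx, 0, 0, 8, 2) // elements 0, 2, 4, 6 -> length 4

// Chunk splits dim 0 into fixed-size pieces, keeping a short remainder,
// so an 8-wide dim chunked by 3 yields sections of 3, 3, and 2.
parts := t.Chunk(ctx, 0, 3) // len(parts) == 3

// ChunkSections splits into explicit sizes, which must sum to t.Dim(0),
// e.g. a fused query/key/value style split.
qkv := t.ChunkSections(ctx, 0, 4, 2, 2)
_, _, _ = evens, parts, qkv
```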
16
ml/backend/ggml/ggml/src/ggml-impl.h
vendored
@@ -693,6 +693,7 @@ GGML_API void ggml_dxgi_pdh_release();
#endif

#ifdef __cplusplus
#include <array>
#include <initializer_list>
#include <vector>

@@ -708,6 +709,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
    return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}

// Return true if the edges in the graph match expectations.
inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
                             int start_idx,
                             std::initializer_list<std::array<int, 3>> edges) {
    for (const auto & edge : edges) {
        int dst_node = edge[0];
        int src_idx = edge[1];
        int src_node = edge[2];
        if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
            return false;
        }
    }
    return true;
}
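The helper above verifies that a candidate fusion window is wired the way a fused kernel expects before fusing it. A rough Go rendering of the same check, with a simplified `Node` type invented purely for illustration:

```go
// Node is a stand-in for ggml's graph node; Src mirrors the src[] array.
type Node struct {
	Src []*Node
}

// checkEdges mirrors ggml_check_edges: each edge is {dstNode, srcIdx, srcNode},
// with node positions relative to startIdx in the flattened graph, and the
// check confirms that dst's srcIdx-th input is exactly the expected node.
func checkEdges(nodes []*Node, startIdx int, edges [][3]int) bool {
	for _, e := range edges {
		dst, srcIdx, src := e[0], e[1], e[2]
		if nodes[startIdx+dst].Src[srcIdx] != nodes[startIdx+src] {
			return false
		}
	}
	return true
}
```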

// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
    char name[256];

    snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
    snprintf(name, 256, "%s", base);
    snprintf(name, 256, "%s_ne02=%d", base, ne02);

    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (res) {

814
ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp
vendored
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};

layout (push_constant) uniform parameter {
    uint ncols;
    uint nrows;
    uint order;
} p;

@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
    dst_row[idx1] = tmp;
}

void argsort(bool needs_bounds_check) {
void argsort(bool needs_bounds_check, const uint row) {
    // bitonic sort
    const int col = int(gl_LocalInvocationID.x);
    const uint row = gl_WorkGroupID.y;

    const uint row_offset = row * p.ncols;

@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {

void main() {
    if (p.ncols == BLOCK_SIZE) {
        argsort(false);
        uint row = gl_WorkGroupID.y;
        while (row < p.nrows) {
            argsort(false, row);
            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
        }
    } else {
        argsort(true);
        uint row = gl_WorkGroupID.y;
        while (row < p.nrows) {
            argsort(true, row);
            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
        }
    }
}
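The reworked main() replaces the one-workgroup-per-row mapping with a stride loop, so a dispatch of fixed size can cover any number of rows. The same scheduling pattern in Go, sketched with goroutines standing in for workgroups (the function and names are illustrative, and plain value sorting stands in for the bitonic argsort):

```go
package main

import (
	"sort"
	"sync"
)

// sortRows sorts each row of an nrows x ncols matrix using a fixed pool of
// workers, each striding by the worker count -- the CPU analogue of
// row += gl_WorkGroupSize.y * gl_NumWorkGroups.y in the shader.
func sortRows(data []int, nrows, ncols, workers int) {
	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func(start int) {
			defer wg.Done()
			for row := start; row < nrows; row += workers {
				rowData := data[row*ncols : (row+1)*ncols]
				sort.Ints(rowData) // stands in for the per-row bitonic sort
			}
		}(w)
	}
	wg.Wait()
}
```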

@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
    vec2 v0 = dequantize(ib, iqs, a_offset);
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {

    const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
    const uint scales = data_a[a_offset + ib].scales[scalesi];
    const vec2 d = vec2(data_a[a_offset + ib].d);
    const vec2 dm = vec2(data_a[a_offset + ib].dm);

    return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
    return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
    return vec2(1, 0);
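These hunks rename the packed scale pair from `d` to `dm`, making explicit that the two halves are a block-wide scale and a block-wide minimum. A hedged Go sketch of the per-sub-block arithmetic the shader applies for Q2_K-style blocks (the parameter names are illustrative, not the shader's):

```go
// dequantQ2K expands one quantized value: the scales byte packs a 4-bit
// sub-block scale (low nibble) and a 4-bit sub-block min (high nibble),
// each stretched by the block-wide dm pair, matching the shader's
// dm.x*(scales&0xF)*q - dm.y*(scales>>4).
func dequantQ2K(dmX, dmY float32, scales, q uint8) float32 {
	d := dmX * float32(scales&0xF) // sub-block scale
	m := dmY * float32(scales>>4)  // sub-block minimum
	return d*float32(q) - m
}
```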

@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    const uint is = 2 * n + b; // 0..7
    const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126

    const vec2 loadd = vec2(data_a[a_offset + ib].d);
    const vec2 loadd = vec2(data_a[a_offset + ib].dm);

    const uint scidx0 = (is < 4) ? is : (is + 4);
    const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {

    const uint8_t hm = uint8_t(1 << (iqs / 16));

    const vec2 loadd = vec2(data_a[a_offset + ib].d);
    const vec2 loadd = vec2(data_a[a_offset + ib].dm);

    const uint scidx0 = (is < 4) ? is : (is + 4);
    const uint scidx1 = (is < 4) ? is : (is - 4);

@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
    const f16vec2 d = bl.block.d;
    const f16vec2 dm = bl.block.dm;
    const uint idx = coordInBlock[1];

    const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
    qs = unpack8(qs)[idx & 1];

    const uint scales = bl.block.scales[scalesi];
    float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
    float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
    return ret;
}

@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
    uint32_t qs = bl.block.qs[iqs];
    qs >>= shift;
    qs &= 0xF;
    float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
    float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
    return ret;
}
#endif

@@ -26,7 +26,7 @@ void main() {
    const float d = e8m0_to_fp32(data_a[ib].e);

    [[unroll]] for (uint l = 0; l < 8; ++l) {
        data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
        data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
        data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
        data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
    }
}

@@ -24,8 +24,8 @@ void main() {
    const uint ql_idx = 32 * ip + il;
    const uint8_t qs = data_a[i].qs[32 * ip + il];

    FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
    FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
    FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
    FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
    data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
    data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
    data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));

@@ -20,8 +20,8 @@ void main() {
    const uint is = 2 * il;
    const uint n = 4;

    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);

    const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
    const uint qs_idx = 32*il + n * ir;

@@ -19,8 +19,8 @@ void main() {
    const uint ir = tid % 16;
    const uint is = 2 * il;

    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
    const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
    const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);

    const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
    const uint qs_idx = 32*il + 2 * ir;

@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
    const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
    const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));

    vec2 d = vec2(data_a[ib0 + i].d);
    const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
    const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
    const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
            fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
            fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
        }
        temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
        temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
    }
}
}

@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,

    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
        vec2 d = vec2(data_a[ib0 + i].d);
        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);

        const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
        const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
            fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
            fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
            fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
            temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
        }
    }
}

@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,

    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
        vec2 d = vec2(data_a[ib0 + i].d);
        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);

        const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
        const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
            fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
            fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
            (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
            temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
        }
    }
}

@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];

#define NUM_WARPS (BLOCK_SIZE / WARP)

#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;

#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];

void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
    _ne1 = 0;
    uint num_elements = p.nei1 * p.nei0;
    uint nei0shift = findLSB(p.nei0);

    uint ids[16];
    uint iter = 0;

    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
        // prefetch up to 16 elements
        if (iter == 0) {
            [[unroll]] for (uint k = 0; k < 16; ++k) {
                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
                bool in_range = i < num_elements;
                uint ii1;
                if (nei0_is_pow2) {
                    ii1 = i >> nei0shift;
                } else {
                    ii1 = i / p.nei0;
                }
                uint ii0 = i - ii1 * p.nei0;
                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
            }
        }
        uint i = j + gl_LocalInvocationIndex;
        bool in_range = i < num_elements;
        uint ii1;
        if (nei0_is_pow2) {
            ii1 = i >> nei0shift;
        } else {
            ii1 = i / p.nei0;
        }
        uint ii0 = i - ii1 * p.nei0;
        uint id = ids[iter++];
        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);

        ballots_sh[gl_SubgroupID] = ballot;
        barrier();

        uint subgroup_base = 0;
        uint total = 0;
        for (uint k = 0; k < gl_NumSubgroups; ++k) {
            if (k == gl_SubgroupID) {
                subgroup_base = total;
            }
            total += subgroupBallotBitCount(ballots_sh[k]);
        }
        barrier();

        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
        }
        _ne1 += total;
        iter &= 15;
        if (_ne1 >= (ic + 1) * BN) {
            break;
        }
    }
    barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID

#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif

#include "mul_mm_id_funcs.glsl"
#include "mul_mm_funcs.glsl"

void main() {

@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    const uint ib = idx / 128; // 2 values per idx
    const uint iqs = idx % 128; // 0..127

    const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
    const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
    const uint scalesi = iqs / 8; // 0..15
    const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6

    const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
    const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
    const uint scales = data_a[ib].scales[scalesi];
    const vec2 d = vec2(data_a[ib].d);
    const vec2 dm = vec2(data_a[ib].dm);

    const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
    const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);

    buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    const uint is = 2 * n + b; // 0..7
    const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126

    const vec2 loadd = vec2(data_a[ib].d);
    const vec2 loadd = vec2(data_a[ib].dm);

    const uint scidx0 = (is < 4) ? is : (is + 4);
    const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin

    const uint8_t hm = uint8_t(1 << (iqs / 16));

    const vec2 loadd = vec2(data_a[ib].d);
    const vec2 loadd = vec2(data_a[ib].dm);

    const uint scidx0 = (is < 4) ? is : (is + 4);
    const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
    const uint ib = idx / 8;
    const uint iqs = (idx & 0x07) * 2;

    const float d = e8m0_to_fp32(data_a[ib].e);
    const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
    const uint vui = uint(data_a[ib].qs[iqs]);
    const uint vui2 = uint(data_a[ib].qs[iqs+1]);

@@ -0,0 +1,70 @@
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;

#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];

void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
    _ne1 = 0;
    uint num_elements = p.nei1 * p.nei0;
    uint nei0shift = findLSB(p.nei0);

    uint ids[16];
    uint iter = 0;

    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
        // prefetch up to 16 elements
        if (iter == 0) {
            [[unroll]] for (uint k = 0; k < 16; ++k) {
                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
                bool in_range = i < num_elements;
                uint ii1;
                if (nei0_is_pow2) {
                    ii1 = i >> nei0shift;
                } else {
                    ii1 = i / p.nei0;
                }
                uint ii0 = i - ii1 * p.nei0;
                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
            }
        }
        uint i = j + gl_LocalInvocationIndex;
        bool in_range = i < num_elements;
        uint ii1;
        if (nei0_is_pow2) {
            ii1 = i >> nei0shift;
        } else {
            ii1 = i / p.nei0;
        }
        uint ii0 = i - ii1 * p.nei0;
        uint id = ids[iter++];
        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);

        ballots_sh[gl_SubgroupID] = ballot;
        barrier();

        uint subgroup_base = 0;
        uint total = 0;
        for (uint k = 0; k < gl_NumSubgroups; ++k) {
            if (k == gl_SubgroupID) {
                subgroup_base = total;
            }
            total += subgroupBallotBitCount(ballots_sh[k]);
        }
        barrier();

        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
        }
        _ne1 += total;
        iter &= 15;
        if (_ne1 >= (ic + 1) * BN) {
            break;
        }
    }
    barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID
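load_row_ids compacts the (ii0, ii1) coordinates whose expert id matches into one BN-sized window of row_ids: each invocation votes via a subgroup ballot, an exclusive bit count gives its slot within the subgroup, and per-subgroup totals are summed across the workgroup. A sequential Go sketch of the same compaction, minus the parallelism (the function and parameter names are invented for illustration):

```go
// loadRowIDs collects coordinates of entries equal to expertIdx into the
// window [ic*bn, (ic+1)*bn) of the match sequence, mirroring the shader's
// ballot-based compaction. ids is the nei1 x nei0 expert-id matrix,
// row-major.
func loadRowIDs(ids []uint16, nei0, nei1 int, expertIdx uint16, ic, bn int) [][2]uint16 {
	window := make([][2]uint16, 0, bn)
	ne1 := 0 // running count of matches, like _ne1 in the shader
	for ii1 := 0; ii1 < nei1; ii1++ {
		for ii0 := 0; ii0 < nei0; ii0++ {
			if ids[ii1*nei0+ii0] != expertIdx {
				continue
			}
			if ne1 >= ic*bn && ne1 < (ic+1)*bn {
				window = append(window, [2]uint16{uint16(ii0), uint16(ii1)})
			}
			ne1++
		}
	}
	return window
}
```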
@@ -10,10 +10,9 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#if defined(MUL_MAT_ID_USE_SUBGROUPS)
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -24,7 +23,10 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
||||
#endif
|
||||
@@ -76,40 +78,31 @@ layout (constant_id = 10) const uint WARP = 32;
|
||||
|
||||
#define BK 32
|
||||
|
||||
#ifdef COOPMAT
|
||||
#define SHMEM_STRIDE (BK / 4 + 4)
|
||||
#else
|
||||
#define SHMEM_STRIDE (BK / 4 + 1)
|
||||
#endif
|
||||
#define MMQ_SHMEM
|
||||
|
||||
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
|
||||
|
||||
#ifndef COOPMAT
|
||||
#if QUANT_AUXF == 1
|
||||
shared FLOAT_TYPE buf_a_dm[BM];
|
||||
#else
|
||||
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
|
||||
#endif
|
||||
#endif
|
||||
|
||||
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
|
||||
#ifndef COOPMAT
|
||||
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
|
||||
#endif
|
||||
|
||||
#define LOAD_VEC_A (4 * QUANT_R)
|
||||
#define LOAD_VEC_B 16
|
||||
#include "mul_mmq_shmem_types.glsl"
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[4096];
|
||||
#endif // MUL_MAT_ID
|
||||
#define BK_STEP 1
|
||||
#else
|
||||
#ifndef BK_STEP
|
||||
#define BK_STEP 4
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Shared memory cache
|
||||
shared block_a_cache buf_a[BM * BK_STEP];
|
||||
shared block_b_cache buf_b[BN * BK_STEP];
|
||||
// Register cache
|
||||
block_a_cache cache_a[WMITER * TM];
|
||||
block_b_cache cache_b;
|
||||
|
||||
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
|
||||
#define LOAD_VEC_B 16
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
#ifdef COOPMAT
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
|
||||
#include "mul_mm_id_funcs.glsl"
|
||||
#include "mul_mmq_funcs.glsl"
|
||||
|
||||
void main() {
|
||||
@@ -139,26 +132,12 @@ void main() {
|
||||
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const uint WSUBM = WM / WMITER;
|
||||
const uint WSUBN = WN / WNITER;
|
||||
|
||||
#ifdef COOPMAT
|
||||
const uint warp_i = gl_SubgroupID;
|
||||
|
||||
const uint tiw = gl_SubgroupInvocationID;
|
||||
|
||||
const uint cms_per_row = WM / TM;
|
||||
const uint cms_per_col = WN / TN;
|
||||
|
||||
const uint storestride = WARP / TM;
|
||||
const uint store_r = tiw % TM;
|
||||
const uint store_c = tiw / TM;
|
||||
#else
|
||||
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
||||
|
||||
const uint tiw = gl_LocalInvocationID.x % WARP;
|
||||
|
||||
const uint tiwr = tiw % (WSUBM / TM);
|
||||
const uint tiwc = tiw / (WSUBM / TM);
|
||||
#endif
|
||||
|
||||
const uint warp_r = warp_i % (BM / WM);
|
||||
const uint warp_c = warp_i / (BM / WM);
|
||||
@@ -172,17 +151,27 @@ void main() {
|
||||
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint _ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
if (bitCount(p.nei0) == 1) {
|
||||
load_row_ids(expert_idx, true, ic);
|
||||
} else {
|
||||
load_row_ids(expert_idx, false, ic);
|
||||
}
|
||||
#else
|
||||
_ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
|
||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||
row_ids[_ne1] = u16vec2(ii0, ii1);
|
||||
if (_ne1 >= ic * BN) {
|
||||
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
#endif
|
||||
|
||||
// Workgroup has no work
|
||||
if (ic * BN >= _ne1) return;
|
||||
@@ -209,159 +198,70 @@ void main() {
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif

#ifdef COOPMAT
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;

coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];

coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];

[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
int32_t cache_a_qs[WMITER * TM * BK / 4];

int32_t cache_b_qs[TN * BK / 4];

ACC_TYPE sums[WMITER * TM * WNITER * TN];

[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
#endif

#if QUANT_AUXF == 1
FLOAT_TYPE cache_a_dm[WMITER * TM];
#else
FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
#endif

FLOAT_TYPE_VEC2 cache_b_ds[TN];

for (uint block = start_k; block < end_k; block += BK) {
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
const uint iqs = loadr_a;

if (iqs == 0) {
#if QUANT_AUXF == 1
buf_a_dm[buf_ib] = get_d(ib);
#else
buf_a_dm[buf_ib] = get_dm(ib);
#endif
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
}
#if QUANT_R == 1
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
#else
const i32vec2 vals = repack(ib, iqs);
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
const uint ib = idx / 8;
const uint iqs = idx & 0x7;
#else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;

const uint iqs = loadr_b;
#endif

const uint buf_ib = loadc_b + l;

if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[buf_ib];
const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
const uint iqs = loadr_b;

[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
}

barrier();

pos_a_ib += 1;
pos_b_ib += 1;
pos_a_ib += BK_STEP;
pos_b_ib += BK_STEP;

#ifdef COOPMAT
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint ib_a = warp_r * WM + cm_row * TM;
// Load from shared into cache
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);

// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
}

[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint ib_b = warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);

// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
}

cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);

[[unroll]] for (uint col = 0; col < TN; col += storestride) {
coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
}

coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
}
}
#else
for (uint k_step = 0; k_step < BK_STEP; k_step++) {
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
}
const uint reg_ib = wsir * TM + cr;
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;

block_a_to_registers(reg_ib, k_step * BM + buf_ib);
}
}

[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
cache_b_ds[cc] = buf_b_ds[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
}
}
const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
block_b_to_registers(ib);

[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
int32_t q_sum = 0;
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
cache_b_qs[cc * (BK / 4) + idx_k]);
}

sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
sums[sums_idx] += mmq_dot_product(cache_a_idx);
}
}
}
}
}
#endif

barrier();
}
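The loop restructure above amortizes one shared-memory fill over BK_STEP consecutive K-blocks instead of reloading per block. A minimal host-side C sketch of the same control flow, with illustrative sizes only (the BK, BK_STEP, and K values here are stand-ins, not the shader's actual configuration):

```c
#include <stdio.h>

// Illustrative stand-ins for the shader's tile sizes; the real values come
// from the compile-time constants of the chosen preset, not from here.
#define BK      32
#define BK_STEP 4
#define K       (BK * BK_STEP * 3)

int main(void) {
    float a[K], b[K], shmem[BK * BK_STEP];
    for (int i = 0; i < K; i++) { a[i] = 1.0f; b[i] = 2.0f; }

    float sum = 0.0f;
    // One staging pass (the shader's block_*_to_shmem + barrier) now covers
    // BK_STEP K-blocks; only 'a' is staged here for brevity.
    for (int block = 0; block < K; block += BK * BK_STEP) {
        for (int i = 0; i < BK * BK_STEP; i++) {
            shmem[i] = a[block + i];
        }
        // Inner compute loop, mirroring "for (uint k_step = 0; ...)".
        for (int k_step = 0; k_step < BK_STEP; k_step++) {
            for (int k = 0; k < BK; k++) {
                sum += shmem[k_step * BK + k] * b[block + k_step * BK + k];
            }
        }
    }
    printf("sum = %f (expect %f)\n", sum, 2.0f * K);
    return 0;
}
```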
@@ -373,54 +273,6 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

#ifdef COOPMAT
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;

const u16vec2 row_idx = row_ids[row_i];

data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float

[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;

if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

@@ -431,19 +283,21 @@ void main() {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;

const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
if (dr_warp + cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#endif // MUL_MAT_ID
}
}
}
}
#endif // COOPMAT
}

@@ -6,41 +6,89 @@

// Each iqs value maps to a 32-bit integer

#if defined(DATA_A_Q4_0)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// 2-byte loads for Q4_0 blocks (18 bytes)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
#ifdef DATA_A_Q4_0
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#else // DATA_A_Q4_1
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#endif
}

#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif

#if defined(DATA_A_Q4_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}

#else // DATA_A_Q4_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
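Why the two mul_q8_1 shapes differ: assuming the usual Q8_1 convention that cache_b.ds packs (d_b, s_b) with s_b = d_b · Σ q_b (an assumption about the surrounding library, not stated in this diff), the scale-and-offset algebra works out as follows (sum_divisor = 1 case):

```latex
% Q4_0: x_i = d_a (q^a_i - 8),\quad y_i = d_b q^b_i,\quad s_b = d_b \sum_i q^b_i
\sum_i x_i y_i
  = d_a d_b \sum_i q^a_i q^b_i - 8\, d_a\, d_b \sum_i q^b_i
  = d_a \left(\texttt{q\_sum}\cdot\texttt{dsb.x} - 8\cdot\texttt{dsb.y}\right)

% Q4_1: x_i = d_a q^a_i + m_a
\sum_i x_i y_i
  = d_a d_b \sum_i q^a_i q^b_i + m_a\, d_b \sum_i q^b_i
  = \texttt{q\_sum}\cdot\texttt{dma.x}\cdot\texttt{dsb.x} + \texttt{dma.y}\cdot\texttt{dsb.y}
```

Q4_0/Q5_0 store their quants with an implicit offset (8 and 16 respectively), which is where the subtracted dsb.y term comes from; Q4_1/Q5_1 carry an explicit (d, m) pair, giving the additive dma.y · dsb.y term instead.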

#if defined(DATA_A_Q5_0)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));

if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
}
#else // DATA_A_Q4_1
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];

if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;

[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const uint32_t vui = cache_a[ib_a].qs[iqs];
const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);

const int32_t qs_b0 = cache_b.qs[iqs];
const int32_t qs_b1 = cache_b.qs[iqs + 4];

q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
}

return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM

#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// 2-byte loads for Q5_0 blocks (22 bytes)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
#ifdef DATA_A_Q5_0
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
#else // DATA_A_Q5_1
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)

@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1);
}

#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif

#if defined(DATA_A_Q5_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)

const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)

return i32vec2(v0, v1);
}

#else // DATA_A_Q5_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q5_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));

if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
}
#else // DATA_A_Q5_1
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];

if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
buf_a[buf_ib].qh = data_a_packed32[ib].qh;
}
#endif
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
cache_a[reg_ib].qh = buf_a[buf_ib].qh;

[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const uint32_t vui = cache_a[ib_a].qs[iqs];
const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)

const int32_t qs_b0 = cache_b.qs[iqs];
const int32_t qs_b1 = cache_b.qs[iqs + 4];

q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
}

return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
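The `(qh & 0xF) * 0x02040810` expression above is a multiply-based bit scatter: it routes the four high bits held in qh into bit 4 of each byte of the packed quant word, which is what the `(0,1,2,3) -> (4,12,20,28)` comments describe. A small C check of that property (plain C99; nothing here is specific to the shader):

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    for (uint32_t qh = 0; qh < 16; qh++) {
        // The scatter: low four bits of qh -> bit 4 of bytes 0..3.
        // Bit k of 0x02040810 is placed so that multiplying by bit k of qh
        // lands exactly on bit position 4 + 8*k after masking.
        uint32_t scattered = (qh * 0x02040810u) & 0x10101010u;
        for (int k = 0; k < 4; k++) {
            uint32_t expect = ((qh >> k) & 1u) << (4 + 8 * k);
            assert((scattered & (0x10u << (8 * k))) == expect);
        }
    }
    printf("qh bit k lands at bit position 4 + 8*k for all 16 qh values\n");
    return 0;
}
```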

#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]));
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}

ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));

if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
}
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;

[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
const int32_t qs_b = cache_b.qs[iqs];

q_sum += dotPacked4x8EXT(qs_a, qs_b);
}

return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif

#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));

return i32vec2( quants & 0x0F0F0F0F,
(quants >> 4) & 0x0F0F0F0F);
}

ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * dsb.x * float(q_sum));
}

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));

const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);

buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));

if (iqs == 0) {
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
}
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d = buf_a[buf_ib].d;

[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];

q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}

return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
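The switch from a FLOAT_TYPE lookup table to an int8 one (see the types change at the end of this diff) works because every entry of the fp4 value set is an exact multiple of 0.5, so the 0.5 can be folded into the per-block scale, as block_a_to_shmem does above:

```latex
k^{\mathrm{fp}}_i = \tfrac{1}{2}\, k^{\mathrm{int}}_i \quad (i = 0,\dots,15)
\;\;\Longrightarrow\;\;
d \cdot k^{\mathrm{fp}}_{q} = \underbrace{\left(d \cdot 0.5\right)}_{\texttt{buf\_a[..].d}} \cdot\, k^{\mathrm{int}}_{q},
\qquad d = \mathrm{e8m0\_to\_fp32}(e)
```

Keeping the table in int8 lets the dequantized quants feed dotPacked4x8EXT directly instead of going through float math.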

// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
int32_t repack(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;

const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;

return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
}

uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;

return data_a[ib_k].scales[iqs_k / 4];
}

ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
}

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;

const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;

// Repack 4x4 quants into one int
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;

buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);

if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
}
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
cache_a[reg_ib].scales = buf_a[buf_ib].scales;

[[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t sum_d = 0;
int32_t sum_m = 0;

[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);

sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
}

return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
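The k-quant helpers above keep addressing data in 32-wide units even though a k-quant block covers 256 values; ib_k selects the 256-wide block and iqs_k the 32-bit word inside it. A C sketch that simply replays the Q2_K index arithmetic for a few sample inputs (the printed values are only for inspection):

```c
#include <stdio.h>

// Reproduces the index remap from the Q2_K repack()/block_a_to_shmem() above:
// (ib, iqs) address 32-wide sub-blocks; eight of them share one 256-wide block.
int main(void) {
    for (unsigned ib = 0; ib < 16; ib += 7) {        // sample 32-wide sub-blocks
        for (unsigned iqs = 0; iqs < 8; iqs += 3) {  // 32-bit word within sub-block
            unsigned ib_k     = ib / 8;                   // 256-wide block index
            unsigned iqs_k    = (ib % 8) * 8 + iqs;       // 32-bit word within it
            unsigned qs_idx   = (iqs_k / 32) * 8 + (iqs_k % 8);
            unsigned qs_shift = ((iqs_k % 32) / 8) * 2;   // 2-bit crumb position
            printf("ib=%2u iqs=%u -> ib_k=%u iqs_k=%2u qs_idx=%2u shift=%u\n",
                   ib, iqs, ib_k, iqs_k, qs_idx, qs_shift);
        }
    }
    return 0;
}
```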

#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint hm_idx = iqs * QUANT_R_MMQ;
const uint iqs_k = (ib % 8) * 8 + hm_idx;

const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
const uint hm_shift = iqs_k / 8;

// Repack 2x4 quants into one int
// Add the 3rd bit instead of subtracting it to allow packing the quants
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);

if (iqs == 0) {
const uint is = iqs_k / 4;
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));

buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
}
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;

[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
float result = 0.0;
int32_t q_sum = 0;

[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
// Subtract 4 from the quants to correct the 3rd bit offset
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));

q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
q_sum = 0;

[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));

q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);

return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif

#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}

#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;

const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 16) / 8) * 4;

// Repack 2x4 quants into one int
#if defined(DATA_A_Q4_K)
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;

buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
#else // defined(DATA_A_Q5_K)
const uint qh_idx = iqs * QUANT_R_MMQ;
const uint qh_shift = iqs_k / 8;

buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
#endif


if (iqs == 0) {
// Scale index
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}

buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
}
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;

[[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;

[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
#if defined(DATA_A_Q4_K)
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
#else // defined(DATA_A_Q5_K)
const int32_t qs_a = cache_a[ib_a].qs[iqs];
#endif

q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}

return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif

#ifdef MMQ_SHMEM
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;

if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}

const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4    ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
}

void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
}
#endif

#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;

const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
const uint ql_shift = ((iqs_k % 32) / 16) * 4;

const uint qh_idx = (iqs_k / 32) * 8 + iqs;
const uint qh_shift = ((iqs_k % 32) / 8) * 2;

const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2    ] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2    ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));

if (iqs == 0) {
const uint is = iqs_k / 4;
const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);

buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
}
}

void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;

[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}

ACC_TYPE mmq_dot_product(const uint ib_a) {
float result = 0.0;
int32_t q_sum = 0;

[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];

q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
q_sum = 0;

[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];

q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);

return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif

#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif

#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
}
#endif

@@ -0,0 +1,78 @@
#if defined(DATA_A_Q4_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q4_1)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q5_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
uint32_t qh;
FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q5_1)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
uint32_t qh;
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
// #define BK_STEP 1
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_MXFP4)
#define QUANT_R_MMQ 2
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE d;
};
#elif defined(DATA_A_Q2_K)
#define QUANT_R_MMQ 4
struct block_a_cache {
uint32_t qs[2];
u8vec2 scales;
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q3_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 d_scales;
};
#elif defined(DATA_A_Q4_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q5_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q6_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE_VEC2 d_scales;
};
#endif

struct block_b_cache
{
int32_t qs[8];
FLOAT_TYPE_VEC2 ds;
};
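For a sense of scale, a rough shared-memory budget for these caches can be computed on the host. The struct layout, FLOAT_TYPE width, and tile sizes below are assumptions for illustration only (16-bit floats, BM = BN = 64, BK_STEP = 4); the real sizes depend on the chosen preset and the target's padding rules:

```c
#include <stdint.h>
#include <stdio.h>

// Hypothetical C mirror of the Q4_0 block_a_cache and block_b_cache above,
// with uint16_t standing in for a 16-bit FLOAT_TYPE. Host-side sizeof()
// includes C padding, so treat the result as an estimate, not a spec.
typedef struct { uint32_t qs[16 / 4]; uint16_t dm; } block_a_cache_q4_0;
typedef struct { int32_t qs[8]; uint16_t ds[2]; } block_b_cache_c;

int main(void) {
    const size_t BM = 64, BN = 64, BK_STEP = 4; // illustrative tile sizes
    size_t bytes = BM * BK_STEP * sizeof(block_a_cache_q4_0)
                 + BN * BK_STEP * sizeof(block_b_cache_c);
    printf("approx. buf_a + buf_b footprint: %zu bytes\n", bytes);
    return 0;
}
```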
@@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_pos[];};
layout (binding = 2) readonly buffer Z {float data_ff[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows

layout (push_constant) uniform parameter {
uint ncols;
@@ -27,6 +28,7 @@ layout (push_constant) uniform parameter {
uint s2;
int sections[4];
uint is_back;
uint set_rows_stride;
} p;

float rope_yarn_ramp(const float low, const float high, const uint i0) {

@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;

const uint idst = row_dst*ne0 + i0/2;
uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;

// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
if (p.set_rows_stride != 0) {
idst = row_x*ne0 + i0/2;
idst += data_i[channel_x].x * p.set_rows_stride;
}

if (i0 >= p.n_dims) {
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);

return;
}
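The fused path above redirects each output row through data_i instead of writing rows contiguously. A toy C model of just the destination-index arithmetic, with made-up ne0, stride, and row indices (nothing here comes from the actual tensors):

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint32_t ne0 = 8, set_rows_stride = 8; // illustrative sizes only
    const uint32_t data_i[2] = {5, 2};           // hypothetical target rows

    for (uint32_t channel_x = 0; channel_x < 2; channel_x++) {
        for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
            const uint32_t row_x = 0; // single row per channel in this toy case
            // Without fusion: idst = row_dst*ne0 + i0/2. With fusion, the row
            // is redirected through data_i, exactly as in the shader above.
            uint32_t idst = row_x * ne0 + i0 / 2
                          + data_i[channel_x] * set_rows_stride;
            printf("channel %u, i0=%u -> idst=%u\n", channel_x, i0, idst);
        }
    }
    return 0;
}
```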

@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;

const uint idst = row_dst*ne0 + i0;
uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;

// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
if (p.set_rows_stride != 0) {
idst = row_x*ne0 + i0;
idst += data_i[channel_x].x * p.set_rows_stride;
}

if (i0 >= p.n_dims) {
data_d[idst + 0] = data_a[ix + 0];
data_d[idst + 1] = data_a[ix + 1];
data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
data_d[idst + 1] = D_TYPE(data_a[ix + 1]);

return;
}

@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
float clamp_min;
float clamp_max;
};

layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
layout(constant_id = 3) const bool late_softmax = false;

const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

@@ -25,6 +28,52 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};

const float INFINITY = 1.0 / 0.0;

// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
float max_val = -INFINITY;

[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
max_val = max(max_val, vals[i]);
}
}

max_val = subgroupMax(max_val);

float sum = 0.f;

[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
const float val = exp(vals[i] - max_val);
vals[i] = val;
sum += val;
} else {
vals[i] = 0.f;
}
}

sum = subgroupAdd(sum);

const float inv_sum = 1.0f / sum;

[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
vals[i] *= inv_sum;
}
}
}
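softmax_warp_inplace distributes the value array across subgroup lanes and uses subgroupMax/subgroupAdd for the two reductions; numerically it is the ordinary max-subtracted softmax with masked-out entries forced to zero. A scalar C reference of the same math (a hypothetical helper for comparison, not part of the shader):

```c
#include <math.h>
#include <stdio.h>

// Max-subtracted softmax over the first `limit` of n entries; entries past
// `limit` are zeroed, matching the use_limit path of the GLSL version. The
// GLSL code splits vals[] across lanes and replaces the two loops' reductions
// with subgroupMax()/subgroupAdd(); the arithmetic is otherwise identical.
static void softmax_inplace(float *vals, int n, int limit) {
    float max_val = -INFINITY;
    for (int i = 0; i < limit; i++) max_val = fmaxf(max_val, vals[i]);

    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        vals[i] = (i < limit) ? expf(vals[i] - max_val) : 0.0f;
        sum += vals[i];
    }
    for (int i = 0; i < limit; i++) vals[i] /= sum;
}

int main(void) {
    float v[4] = {1.0f, 2.0f, 3.0f, 9.9f}; // last entry masked out below
    softmax_inplace(v, 4, 3);
    printf("%f %f %f %f\n", v[0], v[1], v[2], v[3]);
    return 0;
}
```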

void main() {
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
if (row >= n_rows) {
@@ -35,43 +84,16 @@ void main() {
const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row;

float logits_r[experts_per_thread];

const float INFINITY = 1.0 / 0.0;
float wt[experts_per_thread];

[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + gl_LocalInvocationID.x;
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
}

float max_val = logits_r[0];

[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const float val = logits_r[i];
max_val = max(val, max_val);
}

max_val = subgroupMax(max_val);

float wt[experts_per_thread];
float tmp = 0.f;

[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const float val = logits_r[i];
wt[i] = exp(val - max_val);
tmp += wt[i];
}

tmp = subgroupAdd(tmp);

const float inv_sum = 1.0f / tmp;

[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
if (!late_softmax) {
softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
}

// at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@ void main() {

float output_weights[experts_per_thread];

[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
output_weights[i] = 0.f;
}

for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@ void main() {

if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;

[[unroll]]
@@ -129,6 +157,10 @@ void main() {
}
}

if (late_softmax) {
softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
}

[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;

@@ -66,6 +66,7 @@ struct block_q4_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
#define DATA_A_QUANT_LEGACY
#endif

#define QUANT_K_Q4_1 32
@@ -98,6 +99,7 @@ struct block_q4_1_packed32
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#define A_TYPE_PACKED32 block_q4_1_packed32
#define DATA_A_QUANT_LEGACY
#endif

#define QUANT_K_Q5_0 32
@@ -123,6 +125,7 @@ struct block_q5_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
#define DATA_A_QUANT_LEGACY
#endif

#define QUANT_K_Q5_1 32
@@ -158,6 +161,7 @@ struct block_q5_1_packed32
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#define A_TYPE_PACKED32 block_q5_1_packed32
#define DATA_A_QUANT_LEGACY
#endif

#define QUANT_K_Q8_0 32
@@ -186,6 +190,7 @@ struct block_q8_0_packed32
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
#define DATA_A_QUANT_LEGACY
#endif

#define QUANT_K_Q8_1 32
@@ -226,21 +231,21 @@ struct block_q2_K
{
uint8_t scales[QUANT_K_Q2_K/16];
uint8_t qs[QUANT_K_Q2_K/4];
f16vec2 d;
f16vec2 dm;
};

struct block_q2_K_packed16
{
uint16_t scales[QUANT_K_Q2_K/16/2];
uint16_t qs[QUANT_K_Q2_K/4/2];
f16vec2 d;
f16vec2 dm;
};

struct block_q2_K_packed32
{
uint32_t scales[QUANT_K_Q2_K/16/4];
uint32_t qs[QUANT_K_Q2_K/4/4];
f16vec2 d;
f16vec2 dm;
};

#if defined(DATA_A_Q2_K)
@@ -249,6 +254,8 @@ struct block_q2_K_packed32
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32
#define SCALES_PER_32 2
#define DATA_A_QUANT_K
#endif

#define QUANT_K_Q3_K 256
@@ -274,27 +281,28 @@ struct block_q3_K_packed16
#define QUANT_R 1
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
#define DATA_A_QUANT_K
#endif

#define QUANT_K_Q4_K 256

struct block_q4_K
{
f16vec2 d;
f16vec2 dm;
uint8_t scales[3*QUANT_K_Q4_K/64];
uint8_t qs[QUANT_K_Q4_K/2];
};

struct block_q4_K_packed16
{
f16vec2 d;
f16vec2 dm;
uint16_t scales[3*QUANT_K_Q4_K/64/2];
uint16_t qs[QUANT_K_Q4_K/2/2];
};

struct block_q4_K_packed32
{
f16vec2 d;
f16vec2 dm;
uint32_t scales[3*QUANT_K_Q4_K/64/4];
uint32_t qs[QUANT_K_Q4_K/2/4];
};
@@ -310,13 +318,14 @@ struct block_q4_K_packed128
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32
#define DATA_A_QUANT_K
#endif

#define QUANT_K_Q5_K 256

struct block_q5_K
{
f16vec2 d;
f16vec2 dm;
uint8_t scales[12];
uint8_t qh[QUANT_K_Q5_K/8];
uint8_t qs[QUANT_K_Q5_K/2];
@@ -324,12 +333,20 @@ struct block_q5_K

struct block_q5_K_packed16
{
f16vec2 d;
f16vec2 dm;
uint16_t scales[12/2];
uint16_t qh[QUANT_K_Q5_K/8/2];
uint16_t qs[QUANT_K_Q5_K/2/2];
};

struct block_q5_K_packed32
{
f16vec2 dm;
uint32_t scales[12/4];
uint32_t qh[QUANT_K_Q5_K/8/4];
uint32_t qs[QUANT_K_Q5_K/2/4];
};

struct block_q5_K_packed128
{
uvec4 q5k[11];
@@ -340,6 +357,8 @@ struct block_q5_K_packed128
#define QUANT_R 1
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
#define A_TYPE_PACKED32 block_q5_K_packed32
#define DATA_A_QUANT_K
#endif

#define QUANT_K_Q6_K 256
@@ -356,7 +375,7 @@ struct block_q6_K_packed16
{
uint16_t ql[QUANT_K_Q6_K/2/2];
uint16_t qh[QUANT_K_Q6_K/4/2];
int8_t scales[QUANT_K_Q6_K/16];
int16_t scales[QUANT_K_Q6_K/16/2];
float16_t d;
};

@@ -365,6 +384,7 @@ struct block_q6_K_packed16
#define QUANT_R 1
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
#define DATA_A_QUANT_K
#endif

// IQuants
@@ -1363,18 +1383,11 @@ struct block_mxfp4
uint8_t qs[QUANT_K_MXFP4/2];
};

//struct block_mxfp4_packed16
//{
//    uint8_t e;
//    uint16_t qs[QUANT_K_MXFP4/2/2];
//};

#if defined(DATA_A_MXFP4)
#define QUANT_K QUANT_K_MXFP4
#define QUANT_R QUANT_R_MXFP4
#define QUANT_AUXF 1
#define A_TYPE block_mxfp4
//#define A_TYPE_PACKED16 block_mxfp4_packed16
#endif

#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
@@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize)
#endif

#if defined(DATA_A_MXFP4)
const FLOAT_TYPE kvalues_mxfp4_const[16] = {
FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
const int8_t kvalues_mxfp4_const[16] = {
int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
};

shared FLOAT_TYPE kvalues_mxfp4[16];
shared int8_t kvalues_mxfp4[16];

#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)