Mirror of https://github.com/likelovewant/ollama-for-amd.git, synced 2025-12-23 23:18:26 +00:00

Compare commits (117 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | cbb2f09129 |  |
|  | 8852220f59 |  |
|  | 7325791599 |  |
|  | 522c11a763 |  |
|  | 0fadeffaee |  |
|  | 49a9c9ba6a |  |
|  | 1c094038bc |  |
|  | a013693f80 |  |
|  | f6a016f49d |  |
|  | 45c4739374 |  |
|  | 2dd029de12 |  |
|  | 903b1fc97f |  |
|  | 89eb795293 |  |
|  | 7e3ea813c1 |  |
|  | 7b95087b9d |  |
|  | 971d62595a |  |
|  | ffbe8e076d |  |
|  | 2c639431b1 |  |
|  | aacd1cb394 |  |
|  | e3731fb160 |  |
|  | 8dbc9e7b68 |  |
|  | abe67acf8a |  |
|  | ff2011376d |  |
|  | 4ff8a691bc |  |
|  | 1b308e1d2a |  |
|  | bd6c1d6b49 |  |
|  | 3af5d3b738 |  |
|  | 7730895158 |  |
|  | de9ecfd01c |  |
|  | 95fdd8d619 |  |
|  | 9f7822851c |  |
|  | 9b2035d194 |  |
|  | 93d45d7a04 |  |
|  | 709f842457 |  |
|  | 2dfb74410d |  |
|  | 1eb5e75972 |  |
|  | 3475d915cb |  |
|  | 48e78e9be1 |  |
|  | a838421ea3 |  |
|  | 1c4e85b4df |  |
|  | dac4f17fea |  |
|  | 56b8fb024c |  |
|  | b95693056c |  |
|  | c34fc64688 |  |
|  | 7cf6f18c1f |  |
|  | bbbb6b2a01 |  |
|  | 76f88caf43 |  |
|  | 2bccf8c624 |  |
|  | 0c5e5f6630 |  |
|  | d475d1f081 |  |
|  | d2f334c1f7 |  |
|  | 603ceefaa6 |  |
|  | e082d60a24 |  |
|  | 5dae738067 |  |
|  | 0c78723174 |  |
|  | 5a41d69b2a |  |
|  | c146a138e3 |  |
|  | 31b8c6a214 |  |
|  | 2dd3f3c67c |  |
|  | 9191dfaf05 |  |
|  | 1108d8b34e |  |
|  | 7837a5bc7e |  |
|  | 0a844f8e96 |  |
|  | a03223b86f |  |
|  | 0cf7794b16 |  |
|  | 854d40edc5 |  |
|  | 84a2cedf18 |  |
|  | 3f30836734 |  |
|  | cc9555aff0 |  |
|  | 20aee96706 |  |
|  | 18b5958d46 |  |
|  | 5317202c38 |  |
|  | d771043e88 |  |
|  | f8f1071818 |  |
|  | d3e0a0dee4 |  |
|  | 554172759c |  |
|  | 5b6a8e6001 |  |
|  | 467bbc0dd5 |  |
|  | 6d9f9323c5 |  |
|  | 0c2489605d |  |
|  | 8b1b89a984 |  |
|  | 58a46a6e73 |  |
|  | 47e272c35a |  |
|  | 417a81fda3 |  |
|  | dba62ff3a5 |  |
|  | d70e935526 |  |
|  | 5c1063df7f |  |
|  | cb485b2019 |  |
|  | b2af50960f |  |
|  | eac5b8bfbd |  |
|  | 604e43b28d |  |
|  | 53985b3c4d |  |
|  | b6e02cbbd2 |  |
|  | 91935631ac |  |
|  | 8de30b568a |  |
|  | 485da9fd35 |  |
|  | 0796d79d19 |  |
|  | 92981ae3f2 |  |
|  | 8ed1adf3db |  |
|  | 440a3823a6 |  |
|  | 718961de68 |  |
|  | 330f62a7fa |  |
|  | 584e2d646f |  |
|  | 1fd4cb87b2 |  |
|  | 4aba2e8b72 |  |
|  | 2f36d769aa |  |
|  | 399eacf486 |  |
|  | 231cc878cb |  |
|  | aa676b313f |  |
|  | dd0ed0ef17 |  |
|  | d5649821ae |  |
|  | 4cea757e70 |  |
|  | a751bc159c |  |
|  | 5d31242fbf |  |
|  | d7fd72193f |  |
|  | 72ff5b9d8c |  |
|  | ce29f695b4 |  |
.gitattributes (vendored, 4 changes)

@@ -15,8 +15,12 @@ ml/backend/**/*.cu linguist-vendored
 ml/backend/**/*.cuh linguist-vendored
 ml/backend/**/*.m linguist-vendored
 ml/backend/**/*.metal linguist-vendored
+ml/backend/**/*.comp linguist-vendored
+ml/backend/**/*.glsl linguist-vendored
 ml/backend/**/CMakeLists.txt linguist-vendored
 
+app/webview linguist-vendored
+
 llama/build-info.cpp linguist-generated
 ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated
 
.github/workflows/release.yaml (vendored, 18 changes)

@@ -16,13 +16,15 @@ jobs:
 outputs:
 GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }}
 VERSION: ${{ steps.goflags.outputs.VERSION }}
+vendorsha: ${{ steps.changes.outputs.vendorsha }}
 steps:
 - uses: actions/checkout@v4
 - name: Set environment
 id: goflags
 run: |
-echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
-echo VERSION="${GITHUB_REF_NAME#v}" >>$GITHUB_OUTPUT
+echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" | tee -a $GITHUB_OUTPUT
+echo VERSION="${GITHUB_REF_NAME#v}" | tee -a $GITHUB_OUTPUT
+echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
 
 darwin-build:
 runs-on: macos-14-xlarge

@@ -53,6 +55,9 @@ jobs:
 - uses: actions/setup-go@v5
 with:
 go-version-file: go.mod
+cache-dependency-path: |
+go.sum
+Makefile.sync
 - run: |
 ./scripts/build_darwin.sh
 - name: Log build results

@@ -185,7 +190,7 @@ jobs:
 - uses: actions/cache@v4
 with:
 path: ${{ github.workspace }}\.ccache
-key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
+key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}-${{ needs.setup-environment.outputs.vendorsha }}
 - name: Build target "${{ matrix.preset }}"
 run: |
 Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'

@@ -249,6 +254,9 @@ jobs:
 - uses: actions/setup-go@v5
 with:
 go-version-file: go.mod
+cache-dependency-path: |
+go.sum
+Makefile.sync
 - name: Verify gcc is actually clang
 run: |
 $ErrorActionPreference='Continue'

@@ -302,6 +310,9 @@ jobs:
 - uses: actions/setup-go@v5
 with:
 go-version-file: go.mod
+cache-dependency-path: |
+go.sum
+Makefile.sync
 - uses: actions/download-artifact@v4
 with:
 pattern: depends-windows*

@@ -366,6 +377,7 @@ jobs:
 bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
 lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
 lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
 lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
 lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
 lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
.github/workflows/test.yaml (vendored, 16 changes)

@@ -22,6 +22,7 @@ jobs:
 runs-on: ubuntu-latest
 outputs:
 changed: ${{ steps.changes.outputs.changed }}
+vendorsha: ${{ steps.changes.outputs.vendorsha }}
 steps:
 - uses: actions/checkout@v4
 with:

@@ -37,6 +38,7 @@ jobs:
 }
 
 echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
+echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
 
 linux:
 needs: [changes]

@@ -83,7 +85,7 @@ jobs:
 - uses: actions/cache@v4
 with:
 path: /github/home/.cache/ccache
-key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
+key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
 - run: |
 cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
 cmake --build --preset ${{ matrix.preset }} --parallel

@@ -178,7 +180,7 @@ jobs:
 - uses: actions/cache@v4
 with:
 path: ${{ github.workspace }}\.ccache
-key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
+key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
 - run: |
 Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
 Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'

@@ -206,6 +208,9 @@ jobs:
 - uses: actions/setup-go@v5
 with:
 go-version-file: 'go.mod'
+cache-dependency-path: |
+go.sum
+Makefile.sync
 - uses: actions/setup-node@v4
 with:
 node-version: '20'

@@ -226,12 +231,9 @@ jobs:
 if: always()
 run: go test -count=1 -benchtime=1x ./...
 
-# TODO(bmizerany): replace this heavy tool with just the
-# tools/checks/binaries we want and then make them all run in parallel
-# across jobs, not on a single tiny vm on Github Actions.
-- uses: golangci/golangci-lint-action@v6
+- uses: golangci/golangci-lint-action@v9
 with:
-args: --timeout 10m0s -v
+only-new-issues: true
 
 patches:
 runs-on: ubuntu-latest
@@ -1,5 +1,4 @@
-run:
-timeout: 5m
+version: "2"
 linters:
 enable:
 - asasalint

@@ -7,35 +6,46 @@ linters:
 - bodyclose
 - containedctx
 - gocheckcompilerdirectives
-- gofmt
-- gofumpt
-- gosimple
-- govet
-- ineffassign
 - intrange
 - makezero
 - misspell
 - nilerr
 - nolintlint
 - nosprintfhostport
-- staticcheck
 - unconvert
 - usetesting
 - wastedassign
 - whitespace
 disable:
-- usestdlibvars
 - errcheck
-linters-settings:
-staticcheck:
-checks:
-- all
-- -SA1019 # omit Deprecated check
+- usestdlibvars
+settings:
+govet:
+disable:
+- unusedresult
+staticcheck:
+checks:
+- all
+- -QF* # disable quick fix suggestions
+- -SA1019
+- -ST1000 # package comment format
+- -ST1003 # underscores in package names
+- -ST1005 # error strings should not be capitalized
+- -ST1012 # error var naming (ErrFoo)
+- -ST1016 # receiver name consistency
+- -ST1020 # comment on exported function format
+- -ST1021 # comment on exported type format
+- -ST1022 # comment on exported var format
+- -ST1023 # omit type from declaration
 severity:
-default-severity: error
+default: error
 rules:
 - linters:
 - gofmt
 - goimports
 - intrange
 severity: info
+formatters:
+enable:
+- gofmt
+- gofumpt
@@ -54,6 +54,13 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp
 
 add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
+
+# Define GGML version variables for shared library SOVERSION
+# These are required by ggml/src/CMakeLists.txt for proper library versioning
+set(GGML_VERSION_MAJOR 0)
+set(GGML_VERSION_MINOR 0)
+set(GGML_VERSION_PATCH 0)
+set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
 
 set(GGML_CPU ON)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
 set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
@@ -16,7 +16,7 @@ See the [development documentation](./docs/development.md) for instructions on h
 
 * New features: new features (e.g. API fields, environment variables) add surface area to Ollama and make it harder to maintain in the long run as they cannot be removed without potentially breaking users in the future.
 * Refactoring: large code improvements are important, but can be harder or take longer to review and merge.
-* Documentation: small updates to fill in or correct missing documentation is helpful, however large documentation additions can be hard to maintain over time.
+* Documentation: small updates to fill in or correct missing documentation are helpful, however large documentation additions can be hard to maintain over time.
 
 ### Issues that may not be accepted
 

@@ -43,7 +43,7 @@ Tips for proposals:
 * Explain how the change will be tested.
 
 Additionally, for bonus points: Provide draft documentation you would expect to
-see if the change were accepted.
+see if the changes were accepted.
 
 ## Pull requests
 

@@ -66,7 +66,6 @@ Examples:
 
 llm/backend/mlx: support the llama architecture
 CONTRIBUTING: provide clarity on good commit messages, and bad
-docs: simplify manual installation with shorter curl commands
 
 Bad Examples:
 
Dockerfile (14 changes)

@@ -39,14 +39,14 @@ ENV CC=clang CXX=clang++
 FROM base-${TARGETARCH} AS base
 ARG CMAKEVERSION
 RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
-COPY CMakeLists.txt CMakePresets.json .
-COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 ENV LDFLAGS=-s
 
 FROM base AS cpu
 RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
 ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
 ARG PARALLEL
+COPY CMakeLists.txt CMakePresets.json .
+COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 RUN --mount=type=cache,target=/root/.ccache \
 cmake --preset 'CPU' \
 && cmake --build --parallel ${PARALLEL} --preset 'CPU' \

@@ -57,6 +57,8 @@ ARG CUDA11VERSION=11.8
 RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
 ENV PATH=/usr/local/cuda-11/bin:$PATH
 ARG PARALLEL
+COPY CMakeLists.txt CMakePresets.json .
+COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 RUN --mount=type=cache,target=/root/.ccache \
 cmake --preset 'CUDA 11' \
 && cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \

@@ -67,6 +69,8 @@ ARG CUDA12VERSION=12.8
 RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
 ENV PATH=/usr/local/cuda-12/bin:$PATH
 ARG PARALLEL
+COPY CMakeLists.txt CMakePresets.json .
+COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 RUN --mount=type=cache,target=/root/.ccache \
 cmake --preset 'CUDA 12' \
 && cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \

@@ -78,6 +82,8 @@ ARG CUDA13VERSION=13.0
 RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
 ENV PATH=/usr/local/cuda-13/bin:$PATH
 ARG PARALLEL
+COPY CMakeLists.txt CMakePresets.json .
+COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 RUN --mount=type=cache,target=/root/.ccache \
 cmake --preset 'CUDA 13' \
 && cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \

@@ -87,6 +93,8 @@ RUN --mount=type=cache,target=/root/.ccache \
 FROM base AS rocm-6
 ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
 ARG PARALLEL
+COPY CMakeLists.txt CMakePresets.json .
+COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 RUN --mount=type=cache,target=/root/.ccache \
 cmake --preset 'ROCm 6' \
 && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \

@@ -118,6 +126,8 @@ RUN --mount=type=cache,target=/root/.ccache \
 && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
 
 FROM base AS vulkan
+COPY CMakeLists.txt CMakePresets.json .
+COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 RUN --mount=type=cache,target=/root/.ccache \
 cmake --preset 'Vulkan' \
 && cmake --build --parallel --preset 'Vulkan' \
@@ -1,6 +1,6 @@
 UPSTREAM=https://github.com/ggml-org/llama.cpp.git
 WORKDIR=llama/vendor
-FETCH_HEAD=3cfa9c3f125763305b4226bc032f1954f08990dc
+FETCH_HEAD=ec98e2002
 
 .PHONY: help
 help:

@@ -57,7 +57,7 @@ checkout: $(WORKDIR)
 $(WORKDIR):
 git clone $(UPSTREAM) $(WORKDIR)
 
-.PHONE: format-patches
+.PHONY: format-patches
 format-patches: llama/patches
 git -C $(WORKDIR) format-patch \
 --no-signature \

@@ -66,7 +66,11 @@ format-patches: llama/patches
 -o $(realpath $<) \
 $(FETCH_HEAD)
 
-.PHONE: clean
+.PHONY: clean
 clean: checkout
 @git -C $(WORKDIR) am --abort || true
 $(RM) llama/patches/.*.patched
+
+.PHONY: print-base
+print-base:
+@echo $(FETCH_HEAD)
README.md (10 changes)

@@ -389,6 +389,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
 - [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
 - [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VS Code extension for multi-file/whole-repo coding
+- [Void](https://github.com/voideditor/void) (Open source AI code editor and Cursor alternative)
 - [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
 - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
 - [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)

@@ -449,6 +450,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
 - [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
 - [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
+- [KDeps](https://github.com/kdeps/kdeps) (Kdeps is an offline-first AI framework for building Dockerized full-stack AI applications declaratively using Apple PKL and integrates APIs with Ollama on the backend.)
 - [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
 - [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
 - [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)

@@ -575,7 +577,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
 - [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
 - [Ollama for Swift](https://github.com/mattt/ollama-swift)
-- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
+- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
 - [GoLamify](https://github.com/prasad89/golamify)
 - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
 - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)

@@ -638,7 +640,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
 - [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
 - [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
-- [AI Summmary Helper plugin](https://github.com/philffm/ai-summary-helper)
+- [AI Summary Helper plugin](https://github.com/philffm/ai-summary-helper)
 - [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
 - [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
 - [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language

@@ -646,7 +648,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
 - [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
 - [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
-- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
+- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama)
 - [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
 - [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
 - [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)

@@ -656,7 +658,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
 
 ### Observability
-- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
+- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
 - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
 - [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
@@ -14,7 +14,7 @@ Please include the following details in your report:
 
 ## Security best practices
 
-While the maintainer team does their best to secure Ollama, users are encouraged to implement their own security best practices, such as:
+While the maintainer team does its best to secure Ollama, users are encouraged to implement their own security best practices, such as:
 
 - Regularly updating to the latest version of Ollama
 - Securing access to hosted instances of Ollama
@@ -226,7 +226,14 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 
 bts := scanner.Bytes()
 if err := json.Unmarshal(bts, &errorResponse); err != nil {
-return fmt.Errorf("unmarshal: %w", err)
+if response.StatusCode >= http.StatusBadRequest {
+return StatusError{
+StatusCode: response.StatusCode,
+Status: response.Status,
+ErrorMessage: string(bts),
+}
+}
+return errors.New(string(bts))
 }
 
 if response.StatusCode == http.StatusUnauthorized {

@@ -340,7 +347,7 @@ type CreateProgressFunc func(ProgressResponse) error
 // Create creates a model from a [Modelfile]. fn is a progress function that
 // behaves similarly to other methods (see [Client.Pull]).
 //
-// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
+// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx
 func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
 return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
 var resp ProgressResponse
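The hunk above changes how the streaming client reports errors: when the body is not the expected JSON error object and the status code is 400 or higher, it now returns a StatusError carrying the raw body instead of failing with "unmarshal: ...". Below is a minimal, hedged sketch of how a caller could inspect that error; the model name and prompt are illustrative assumptions, not part of this diff, and the StatusError fields are the ones shown above.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Illustrative request; any call routed through the client's stream helper behaves the same way.
	req := &api.GenerateRequest{Model: "llama3.2", Prompt: "hello"}
	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	})

	// With the change above, a plain text 500 or an HTML 404 from a proxy
	// surfaces as a StatusError instead of an unmarshal failure.
	var statusErr api.StatusError
	if errors.As(err, &statusErr) {
		log.Printf("server returned %d: %s", statusErr.StatusCode, statusErr.ErrorMessage)
	} else if err != nil {
		log.Fatal(err)
	}
}
```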
@@ -55,6 +55,7 @@ func TestClientFromEnvironment(t *testing.T) {
 type testError struct {
 message string
 statusCode int
+raw bool // if true, write message as-is instead of JSON encoding
 }
 
 func (e testError) Error() string {

@@ -111,6 +112,20 @@ func TestClientStream(t *testing.T) {
 },
 },
 },
+{
+name: "plain text error response",
+responses: []any{
+"internal server error",
+},
+wantErr: "internal server error",
+},
+{
+name: "HTML error page",
+responses: []any{
+"<html><body>404 Not Found</body></html>",
+},
+wantErr: "404 Not Found",
+},
 }
 
 for _, tc := range testCases {

@@ -135,6 +150,12 @@ func TestClientStream(t *testing.T) {
 return
 }
 
+if str, ok := resp.(string); ok {
+fmt.Fprintln(w, str)
+flusher.Flush()
+continue
+}
+
 if err := json.NewEncoder(w).Encode(resp); err != nil {
 t.Fatalf("failed to encode response: %v", err)
 }

@@ -173,9 +194,10 @@ func TestClientStream(t *testing.T) {
 
 func TestClientDo(t *testing.T) {
 testCases := []struct {
 name string
 response any
 wantErr string
+wantStatusCode int
 }{
 {
 name: "immediate error response",

@@ -183,7 +205,8 @@ func TestClientDo(t *testing.T) {
 message: "test error message",
 statusCode: http.StatusBadRequest,
 },
 wantErr: "test error message",
+wantStatusCode: http.StatusBadRequest,
 },
 {
 name: "server error response",

@@ -191,7 +214,8 @@ func TestClientDo(t *testing.T) {
 message: "internal error",
 statusCode: http.StatusInternalServerError,
 },
 wantErr: "internal error",
+wantStatusCode: http.StatusInternalServerError,
 },
 {
 name: "successful response",

@@ -203,6 +227,26 @@ func TestClientDo(t *testing.T) {
 Success: true,
 },
 },
+{
+name: "plain text error response",
+response: testError{
+message: "internal server error",
+statusCode: http.StatusInternalServerError,
+raw: true,
+},
+wantErr: "internal server error",
+wantStatusCode: http.StatusInternalServerError,
+},
+{
+name: "HTML error page",
+response: testError{
+message: "<html><body>404 Not Found</body></html>",
+statusCode: http.StatusNotFound,
+raw: true,
+},
+wantErr: "<html><body>404 Not Found</body></html>",
+wantStatusCode: http.StatusNotFound,
+},
 }
 
 for _, tc := range testCases {

@@ -210,11 +254,16 @@ func TestClientDo(t *testing.T) {
 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 if errResp, ok := tc.response.(testError); ok {
 w.WriteHeader(errResp.statusCode)
-err := json.NewEncoder(w).Encode(map[string]string{
-"error": errResp.message,
-})
-if err != nil {
-t.Fatal("failed to encode error response:", err)
+if !errResp.raw {
+err := json.NewEncoder(w).Encode(map[string]string{
+"error": errResp.message,
+})
+if err != nil {
+t.Fatal("failed to encode error response:", err)
+}
+} else {
+// Write raw message (simulates non-JSON error responses)
+fmt.Fprint(w, errResp.message)
 }
 return
 }

@@ -241,6 +290,15 @@ func TestClientDo(t *testing.T) {
 if err.Error() != tc.wantErr {
 t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
 }
+if tc.wantStatusCode != 0 {
+if statusErr, ok := err.(StatusError); ok {
+if statusErr.StatusCode != tc.wantStatusCode {
+t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
+}
+} else {
+t.Errorf("expected StatusError, got %T", err)
+}
+}
 return
 }
@@ -15,19 +15,19 @@ func main() {
 }
 
 messages := []api.Message{
-api.Message{
+{
 Role: "system",
 Content: "Provide very brief, concise responses",
 },
-api.Message{
+{
 Role: "user",
 Content: "Name some unusual animals",
 },
-api.Message{
+{
 Role: "assistant",
 Content: "Monotreme, platypus, echidna",
 },
-api.Message{
+{
 Role: "user",
 Content: "which of these is the most dangerous?",
 },
api/types.go (15 changes)

@@ -283,11 +283,12 @@ func (pt PropertyType) String() string {
 }
 
 type ToolProperty struct {
 AnyOf []ToolProperty `json:"anyOf,omitempty"`
 Type PropertyType `json:"type,omitempty"`
 Items any `json:"items,omitempty"`
 Description string `json:"description,omitempty"`
 Enum []any `json:"enum,omitempty"`
+Properties map[string]ToolProperty `json:"properties,omitempty"`
 }
 
 // ToTypeScriptType converts a ToolProperty to a TypeScript type string

@@ -553,6 +554,9 @@ type CreateRequest struct {
 Renderer string `json:"renderer,omitempty"`
 Parser string `json:"parser,omitempty"`
 
+// Requires is the minimum version of Ollama required by the model.
+Requires string `json:"requires,omitempty"`
+
 // Info is a map of additional information for the model
 Info map[string]any `json:"info,omitempty"`
 

@@ -603,6 +607,7 @@ type ShowResponse struct {
 Tensors []Tensor `json:"tensors,omitempty"`
 Capabilities []model.Capability `json:"capabilities,omitempty"`
 ModifiedAt time.Time `json:"modified_at,omitempty"`
+Requires string `json:"requires,omitempty"`
 }
 
 // CopyRequest is the request passed to [Client.Copy].
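The new Properties field on ToolProperty makes nested object parameters expressible in tool schemas; the test added in the next hunk exercises exactly this. As a brief illustration, here is a hedged sketch of declaring such a property using only the fields visible in this diff; the import path and surrounding program are assumptions made to keep the example self-contained.

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// A tool parameter whose value is an object with typed sub-properties,
	// which the added Properties field can now describe.
	location := api.ToolProperty{
		Type:        api.PropertyType{"object"},
		Description: "Location details",
		Properties: map[string]api.ToolProperty{
			"address": {Type: api.PropertyType{"string"}, Description: "Street address"},
			"city":    {Type: api.PropertyType{"string"}, Description: "City name"},
		},
	}

	out, err := json.MarshalIndent(location, "", "  ")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(out))
}
```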
@@ -504,6 +504,107 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
 }
 }
 
+func TestToolPropertyNestedProperties(t *testing.T) {
+tests := []struct {
+name string
+input string
+expected ToolProperty
+}{
+{
+name: "nested object properties",
+input: `{
+"type": "object",
+"description": "Location details",
+"properties": {
+"address": {
+"type": "string",
+"description": "Street address"
+},
+"city": {
+"type": "string",
+"description": "City name"
+}
+}
+}`,
+expected: ToolProperty{
+Type: PropertyType{"object"},
+Description: "Location details",
+Properties: map[string]ToolProperty{
+"address": {
+Type: PropertyType{"string"},
+Description: "Street address",
+},
+"city": {
+Type: PropertyType{"string"},
+Description: "City name",
+},
+},
+},
+},
+{
+name: "deeply nested properties",
+input: `{
+"type": "object",
+"description": "Event",
+"properties": {
+"location": {
+"type": "object",
+"description": "Location",
+"properties": {
+"coordinates": {
+"type": "object",
+"description": "GPS coordinates",
+"properties": {
+"lat": {"type": "number", "description": "Latitude"},
+"lng": {"type": "number", "description": "Longitude"}
+}
+}
+}
+}
+}
+}`,
+expected: ToolProperty{
+Type: PropertyType{"object"},
+Description: "Event",
+Properties: map[string]ToolProperty{
+"location": {
+Type: PropertyType{"object"},
+Description: "Location",
+Properties: map[string]ToolProperty{
+"coordinates": {
+Type: PropertyType{"object"},
+Description: "GPS coordinates",
+Properties: map[string]ToolProperty{
+"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
+"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
+},
+},
+},
+},
+},
+},
+},
+}
+
+for _, tt := range tests {
+t.Run(tt.name, func(t *testing.T) {
+var prop ToolProperty
+err := json.Unmarshal([]byte(tt.input), &prop)
+require.NoError(t, err)
+assert.Equal(t, tt.expected, prop)
+
+// Round-trip test: marshal and unmarshal again
+data, err := json.Marshal(prop)
+require.NoError(t, err)
+
+var prop2 ToolProperty
+err = json.Unmarshal(data, &prop2)
+require.NoError(t, err)
+assert.Equal(t, tt.expected, prop2)
+})
+}
+}
+
 func TestToolFunctionParameters_String(t *testing.T) {
 tests := []struct {
 name string
@@ -273,10 +273,6 @@ func main() {
 Handler: uiServer.Handler(),
 }
 
-if _, err := uiServer.UserData(ctx); err != nil {
-slog.Warn("failed to load user data", "error", err)
-}
-
 // Start the UI server
 slog.Info("starting ui server", "port", port)
 go func() {

@@ -320,6 +316,17 @@
 slog.Debug("no URL scheme request to handle")
 }
 
+go func() {
+slog.Debug("waiting for ollama server to be ready")
+if err := ui.WaitForServer(ctx, 10*time.Second); err != nil {
+slog.Warn("ollama server not ready, continuing anyway", "error", err)
+}
+
+if _, err := uiServer.UserData(ctx); err != nil {
+slog.Warn("failed to load user data", "error", err)
+}
+}()
+
 osRun(cancel, hasCompletedFirstRun, startHidden)
 
 slog.Info("shutting down desktop server")

@@ -361,7 +368,7 @@ func checkUserLoggedIn(uiServerPort int) bool {
 return false
 }
 
-resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/me", uiServerPort))
+resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/api/me", uiServerPort), "application/json", nil)
 if err != nil {
 slog.Debug("failed to call local auth endpoint", "error", err)
 return false

@@ -397,8 +404,8 @@
 // handleConnectURLScheme fetches the connect URL and opens it in the browser
 func handleConnectURLScheme() {
 if checkUserLoggedIn(uiServerPort) {
-slog.Info("user is already logged in, opening settings instead")
-sendUIRequestMessage("/")
+slog.Info("user is already logged in, opening app instead")
+showWindow(wv.webview.Window())
 return
 }
 

@@ -434,37 +441,30 @@ func openInBrowser(url string) {
 }
 }
 
-// parseURLScheme parses an ollama:// URL and returns whether it's a connect URL and the UI path
-func parseURLScheme(urlSchemeRequest string) (isConnect bool, uiPath string, err error) {
+// parseURLScheme parses an ollama:// URL and validates it
+// Supports: ollama:// (open app) and ollama://connect (OAuth)
+func parseURLScheme(urlSchemeRequest string) (isConnect bool, err error) {
 parsedURL, err := url.Parse(urlSchemeRequest)
 if err != nil {
-return false, "", err
+return false, fmt.Errorf("invalid URL: %w", err)
 }
 
 // Check if this is a connect URL
 if parsedURL.Host == "connect" || strings.TrimPrefix(parsedURL.Path, "/") == "connect" {
-return true, "", nil
+return true, nil
 }
 
-// Extract the UI path
-path := "/"
-if parsedURL.Path != "" && parsedURL.Path != "/" {
-// For URLs like ollama:///settings, use the path directly
-path = parsedURL.Path
-} else if parsedURL.Host != "" {
-// For URLs like ollama://settings (without triple slash),
-// the "settings" part is parsed as the host, not the path.
-// We need to convert it to a path by prepending "/"
-// This also handles ollama://settings/ where Windows adds a trailing slash
-path = "/" + parsedURL.Host
+// Allow bare ollama:// or ollama:/// to open the app
+if (parsedURL.Host == "" && parsedURL.Path == "") || parsedURL.Path == "/" {
+return false, nil
 }
 
-return false, path, nil
+return false, fmt.Errorf("unsupported ollama:// URL path: %s", urlSchemeRequest)
 }
 
 // handleURLSchemeInCurrentInstance processes URL scheme requests in the current instance
 func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
-isConnect, uiPath, err := parseURLScheme(urlSchemeRequest)
+isConnect, err := parseURLScheme(urlSchemeRequest)
 if err != nil {
 slog.Error("failed to parse URL scheme request", "url", urlSchemeRequest, "error", err)
 return

@@ -473,6 +473,8 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
 if isConnect {
 handleConnectURLScheme()
 } else {
-sendUIRequestMessage(uiPath)
+if wv.webview != nil {
+showWindow(wv.webview.Window())
+}
 }
 }
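The rewritten parseURLScheme above narrows the ollama:// scheme to two accepted forms: ollama://connect (the OAuth connect flow) and a bare ollama:// or ollama:/// that simply raises the app window; any other path is now an error, and the old per-path UI routing via sendUIRequestMessage is removed. The function itself is unexported in the desktop app, so the sketch below reimplements the same decision logic with net/url purely to illustrate which URLs are accepted; it is a stand-in for illustration, not the app's own code.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// classify mirrors the rewritten parseURLScheme for illustration only:
// "connect" as host or path selects the OAuth connect flow, a bare
// ollama:// or ollama:/// just opens the app, anything else is rejected.
func classify(raw string) (isConnect bool, err error) {
	u, err := url.Parse(raw)
	if err != nil {
		return false, fmt.Errorf("invalid URL: %w", err)
	}
	if u.Host == "connect" || strings.TrimPrefix(u.Path, "/") == "connect" {
		return true, nil
	}
	if (u.Host == "" && u.Path == "") || u.Path == "/" {
		return false, nil
	}
	return false, fmt.Errorf("unsupported ollama:// URL path: %s", raw)
}

func main() {
	for _, raw := range []string{"ollama://connect", "ollama://", "ollama:///", "ollama://settings"} {
		isConnect, err := classify(raw)
		fmt.Printf("%-18s connect=%v err=%v\n", raw, isConnect, err)
	}
}
```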
@@ -191,13 +191,6 @@ func LaunchNewApp() {
 C.launchApp(appName)
 }
 
-// Send a request to the main app thread to load a UI page
-func sendUIRequestMessage(path string) {
-p := C.CString(path)
-defer C.free(unsafe.Pointer(p))
-C.uiRequest(p)
-}
-
 func registerLaunchAgent(hasCompletedFirstRun bool) {
 // Remove any stale Login Item registrations
 C.unregisterSelfFromLoginItem()
@@ -24,27 +24,14 @@ bool firstTimeRun,startHidden; // Set in run before initialization
 for (NSURL *url in urls) {
 if ([url.scheme isEqualToString:@"ollama"]) {
 NSString *path = url.path;
-if (!path || [path isEqualToString:@""]) {
-// For URLs like ollama://settings (without triple slash),
-// the "settings" part is parsed as the host, not the path.
-// We need to convert it to a path by prepending "/"
-if (url.host && ![url.host isEqualToString:@""]) {
-path = [@"/" stringByAppendingString:url.host];
-} else {
-path = @"/";
-}
-}
 
-if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {
+if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) {
 // Special case: handle connect by opening browser instead of app
 handleConnectURL();
 } else {
 // Set app to be active and visible
 [NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
 [NSApp activateIgnoringOtherApps:YES];
-
-// Open the path with the UI
-[self uiRequest:path];
 }
 
 break;

@@ -260,7 +247,7 @@
 }
 
 - (void)openHelp:(id)sender {
-NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
+NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"];
 [[NSWorkspace sharedWorkspace] openURL:url];
 }
 
@@ -138,7 +138,7 @@ func (app *appCallbacks) HandleURLScheme(urlScheme string) {
|
|||||||
|
|
||||||
// handleURLSchemeRequest processes URL scheme requests from other instances
|
// handleURLSchemeRequest processes URL scheme requests from other instances
|
||||||
func handleURLSchemeRequest(urlScheme string) {
|
func handleURLSchemeRequest(urlScheme string) {
|
||||||
isConnect, uiPath, err := parseURLScheme(urlScheme)
|
isConnect, err := parseURLScheme(urlScheme)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to parse URL scheme request", "url", urlScheme, "error", err)
|
slog.Error("failed to parse URL scheme request", "url", urlScheme, "error", err)
|
||||||
return
|
return
|
||||||
@@ -147,7 +147,9 @@ func handleURLSchemeRequest(urlScheme string) {
|
|||||||
if isConnect {
|
if isConnect {
|
||||||
handleConnectURLScheme()
|
handleConnectURLScheme()
|
||||||
} else {
|
} else {
|
||||||
sendUIRequestMessage(uiPath)
|
if wv.webview != nil {
|
||||||
|
showWindow(wv.webview.Window())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,11 +263,6 @@ func createLoginShortcut() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a request to the main app thread to load a UI page
|
|
||||||
func sendUIRequestMessage(path string) {
|
|
||||||
wintray.SendUIRequestMessage(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LaunchNewApp() {
|
func LaunchNewApp() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -169,37 +169,47 @@ DlgResult fileDlg(FileDlgParams* params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
NSArray* urls = [panel URLs];
|
NSArray* urls = [panel URLs];
|
||||||
if(self->params->allowMultiple && [urls count] >= 1) {
|
if([urls count] == 0) {
|
||||||
|
return DLG_CANCEL;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(self->params->allowMultiple) {
|
||||||
// For multiple files, we need to return all paths separated by null bytes
|
// For multiple files, we need to return all paths separated by null bytes
|
||||||
char* bufPtr = self->params->buf;
|
char* bufPtr = self->params->buf;
|
||||||
int remainingBuf = self->params->nbuf;
|
int remainingBuf = self->params->nbuf;
|
||||||
|
|
||||||
// Calculate total required buffer size first
|
// Calculate total required buffer size first
|
||||||
int totalSize = 0;
|
int totalSize = 0;
|
||||||
for(NSURL* url in urls) {
|
for(NSURL* url in urls) {
|
||||||
char tempBuf[PATH_MAX];
|
char tempBuf[PATH_MAX];
|
||||||
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
|
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
|
||||||
return DLG_URLFAIL;
|
return DLG_URLFAIL;
|
||||||
}
|
}
|
||||||
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
|
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
|
||||||
}
|
}
|
||||||
totalSize += 1; // Final null terminator
|
totalSize += 1; // Final null terminator
|
||||||
|
|
||||||
if(totalSize > self->params->nbuf) {
|
if(totalSize > self->params->nbuf) {
|
||||||
// Not enough buffer space
|
// Not enough buffer space
|
||||||
return DLG_URLFAIL;
|
return DLG_URLFAIL;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now actually copy the paths (we know we have space)
|
// Now actually copy the paths (we know we have space)
|
||||||
bufPtr = self->params->buf;
|
bufPtr = self->params->buf;
|
||||||
for(NSURL* url in urls) {
|
for(NSURL* url in urls) {
|
||||||
char tempBuf[PATH_MAX];
|
char tempBuf[PATH_MAX];
|
||||||
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
|
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
|
||||||
int pathLen = strlen(tempBuf);
|
int pathLen = strlen(tempBuf);
|
||||||
strcpy(bufPtr, tempBuf);
|
strcpy(bufPtr, tempBuf);
|
||||||
bufPtr += pathLen + 1;
|
bufPtr += pathLen + 1;
|
||||||
}
|
}
|
||||||
*bufPtr = '\0'; // Final null terminator
|
*bufPtr = '\0'; // Final null terminator
|
||||||
|
} else {
|
||||||
|
// Single file/directory selection - write path to buffer
|
||||||
|
NSURL* url = [urls firstObject];
|
||||||
|
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
|
||||||
|
return DLG_URLFAIL;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return DLG_OK;
|
return DLG_OK;
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10
|
|||||||
type WinDlgError int
|
type WinDlgError int
|
||||||
|
|
||||||
func (e WinDlgError) Error() string {
|
func (e WinDlgError) Error() string {
|
||||||
return fmt.Sprintf("CommDlgExtendedError: %#x", e)
|
return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
|
||||||
}
|
}
|
||||||
|
|
||||||
func err() error {
|
func err() error {
|
||||||
|
|||||||
@@ -224,9 +224,7 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
|
|||||||
if _, err := os.Stat(settings.Models); err == nil {
|
if _, err := os.Stat(settings.Models); err == nil {
|
||||||
env["OLLAMA_MODELS"] = settings.Models
|
env["OLLAMA_MODELS"] = settings.Models
|
||||||
} else {
|
} else {
|
||||||
slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err)
|
slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err)
|
||||||
settings.Models = ""
|
|
||||||
s.store.SetSettings(settings)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if settings.ContextLength > 0 {
|
if settings.ContextLength > 0 {
|
||||||
|
|||||||
@@ -469,26 +469,24 @@ export class HealthResponse {
|
|||||||
}
|
}
|
||||||
export class User {
|
export class User {
|
||||||
id: string;
|
id: string;
|
||||||
name: string;
|
|
||||||
email: string;
|
email: string;
|
||||||
avatarURL: string;
|
name: string;
|
||||||
plan: string;
|
bio?: string;
|
||||||
bio: string;
|
avatarurl?: string;
|
||||||
firstName: string;
|
firstname?: string;
|
||||||
lastName: string;
|
lastname?: string;
|
||||||
overThreshold: boolean;
|
plan?: string;
|
||||||
|
|
||||||
constructor(source: any = {}) {
|
constructor(source: any = {}) {
|
||||||
if ('string' === typeof source) source = JSON.parse(source);
|
if ('string' === typeof source) source = JSON.parse(source);
|
||||||
this.id = source["id"];
|
this.id = source["id"];
|
||||||
this.name = source["name"];
|
|
||||||
this.email = source["email"];
|
this.email = source["email"];
|
||||||
this.avatarURL = source["avatarURL"];
|
this.name = source["name"];
|
||||||
this.plan = source["plan"];
|
|
||||||
this.bio = source["bio"];
|
this.bio = source["bio"];
|
||||||
this.firstName = source["firstName"];
|
this.avatarurl = source["avatarurl"];
|
||||||
this.lastName = source["lastName"];
|
this.firstname = source["firstname"];
|
||||||
this.overThreshold = source["overThreshold"];
|
this.lastname = source["lastname"];
|
||||||
|
this.plan = source["plan"];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
export class Attachment {
|
export class Attachment {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import {
|
|||||||
import { parseJsonlFromResponse } from "./util/jsonl-parsing";
|
import { parseJsonlFromResponse } from "./util/jsonl-parsing";
|
||||||
import { ollamaClient as ollama } from "./lib/ollama-client";
|
import { ollamaClient as ollama } from "./lib/ollama-client";
|
||||||
import type { ModelResponse } from "ollama/browser";
|
import type { ModelResponse } from "ollama/browser";
|
||||||
|
import { API_BASE, OLLAMA_DOT_COM } from "./lib/config";
|
||||||
|
|
||||||
// Extend Model class with utility methods
|
// Extend Model class with utility methods
|
||||||
declare module "@/gotypes" {
|
declare module "@/gotypes" {
|
||||||
@@ -26,9 +27,6 @@ declare module "@/gotypes" {
|
|||||||
Model.prototype.isCloud = function (): boolean {
|
Model.prototype.isCloud = function (): boolean {
|
||||||
return this.model.endsWith("cloud");
|
return this.model.endsWith("cloud");
|
||||||
};
|
};
|
||||||
|
|
||||||
const API_BASE = import.meta.env.DEV ? "http://127.0.0.1:3001" : "";
|
|
||||||
|
|
||||||
// Helper function to convert Uint8Array to base64
|
// Helper function to convert Uint8Array to base64
|
||||||
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
|
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
|
||||||
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
|
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
|
||||||
@@ -43,44 +41,50 @@ function uint8ArrayToBase64(uint8Array: Uint8Array): string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function fetchUser(): Promise<User | null> {
|
export async function fetchUser(): Promise<User | null> {
|
||||||
try {
|
const response = await fetch(`${API_BASE}/api/me`, {
|
||||||
const response = await fetch(`${API_BASE}/api/v1/me`, {
|
method: "POST",
|
||||||
method: "GET",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
if (response.ok) {
|
|
||||||
const userData: User = await response.json();
|
|
||||||
return userData;
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
} catch (error) {
|
|
||||||
console.error("Error fetching user:", error);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function fetchConnectUrl(): Promise<string> {
|
|
||||||
const response = await fetch(`${API_BASE}/api/v1/connect`, {
|
|
||||||
method: "GET",
|
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!response.ok) {
|
if (response.ok) {
|
||||||
throw new Error("Failed to fetch connect URL");
|
const userData: User = await response.json();
|
||||||
|
|
||||||
|
if (userData.avatarurl && !userData.avatarurl.startsWith("http")) {
|
||||||
|
userData.avatarurl = `${OLLAMA_DOT_COM}${userData.avatarurl}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
return userData;
|
||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json();
|
if (response.status === 401 || response.status === 403) {
|
||||||
return data.connect_url;
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error(`Failed to fetch user: ${response.status}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function fetchConnectUrl(): Promise<string> {
|
||||||
|
const response = await fetch(`${API_BASE}/api/me`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (response.status === 401) {
|
||||||
|
const data = await response.json();
|
||||||
|
if (data.signin_url) {
|
||||||
|
return data.signin_url;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error("Failed to fetch connect URL");
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function disconnectUser(): Promise<void> {
|
export async function disconnectUser(): Promise<void> {
|
||||||
const response = await fetch(`${API_BASE}/api/v1/disconnect`, {
|
const response = await fetch(`${API_BASE}/api/signout`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
@@ -205,12 +209,10 @@ export async function* sendMessage(
|
|||||||
data: uint8ArrayToBase64(att.data),
|
data: uint8ArrayToBase64(att.data),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Only send think parameter when actually requesting thinking
|
// Send think parameter when it's explicitly set (true, false, or a non-empty string).
|
||||||
// Don't send false as it causes issues with some providers
|
|
||||||
const shouldSendThink =
|
const shouldSendThink =
|
||||||
think !== undefined &&
|
think !== undefined &&
|
||||||
((typeof think === "boolean" && think) ||
|
(typeof think === "boolean" || (typeof think === "string" && think !== ""));
|
||||||
(typeof think === "string" && think !== ""));
|
|
||||||
|
|
||||||
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
|
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
@@ -392,7 +394,8 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
|
|||||||
|
|
||||||
export async function fetchHealth(): Promise<boolean> {
|
export async function fetchHealth(): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`${API_BASE}/api/v1/health`, {
|
// Use the /api/version endpoint as a health check
|
||||||
|
const response = await fetch(`${API_BASE}/api/version`, {
|
||||||
method: "GET",
|
method: "GET",
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
@@ -401,7 +404,8 @@ export async function fetchHealth(): Promise<boolean> {
|
|||||||
|
|
||||||
if (response.ok) {
|
if (response.ok) {
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
return data.healthy || false;
|
// If we get a version back, the server is healthy
|
||||||
|
return !!data.version;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -299,9 +299,9 @@ export default function Settings() {
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{user?.avatarURL && (
|
{user?.avatarurl && (
|
||||||
<img
|
<img
|
||||||
src={user.avatarURL}
|
src={user.avatarurl}
|
||||||
alt={user?.name}
|
alt={user?.name}
|
||||||
className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0"
|
className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0"
|
||||||
onError={(e) => {
|
onError={(e) => {
|
||||||
|
|||||||
@@ -50,21 +50,33 @@ export default function Thinking({
|
|||||||
// Position content to show bottom when collapsed
|
// Position content to show bottom when collapsed
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (isCollapsed && contentRef.current && wrapperRef.current) {
|
if (isCollapsed && contentRef.current && wrapperRef.current) {
|
||||||
const contentHeight = contentRef.current.scrollHeight;
|
requestAnimationFrame(() => {
|
||||||
const wrapperHeight = wrapperRef.current.clientHeight;
|
if (!contentRef.current || !wrapperRef.current) return;
|
||||||
if (contentHeight > wrapperHeight) {
|
|
||||||
const translateY = -(contentHeight - wrapperHeight);
|
const contentHeight = contentRef.current.scrollHeight;
|
||||||
contentRef.current.style.transform = `translateY(${translateY}px)`;
|
const wrapperHeight = wrapperRef.current.clientHeight;
|
||||||
setHasOverflow(true);
|
if (contentHeight > wrapperHeight) {
|
||||||
} else {
|
const translateY = -(contentHeight - wrapperHeight);
|
||||||
setHasOverflow(false);
|
contentRef.current.style.transform = `translateY(${translateY}px)`;
|
||||||
}
|
setHasOverflow(true);
|
||||||
|
} else {
|
||||||
|
contentRef.current.style.transform = "translateY(0)";
|
||||||
|
setHasOverflow(false);
|
||||||
|
}
|
||||||
|
});
|
||||||
} else if (contentRef.current) {
|
} else if (contentRef.current) {
|
||||||
contentRef.current.style.transform = "translateY(0)";
|
contentRef.current.style.transform = "translateY(0)";
|
||||||
setHasOverflow(false);
|
setHasOverflow(false);
|
||||||
}
|
}
|
||||||
}, [thinking, isCollapsed]);
|
}, [thinking, isCollapsed]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (activelyThinking && wrapperRef.current && !isCollapsed) {
|
||||||
|
// When expanded and actively thinking, scroll to bottom
|
||||||
|
wrapperRef.current.scrollTop = wrapperRef.current.scrollHeight;
|
||||||
|
}
|
||||||
|
}, [thinking, activelyThinking, isCollapsed]);
|
||||||
|
|
||||||
const handleToggle = () => {
|
const handleToggle = () => {
|
||||||
setIsCollapsed(!isCollapsed);
|
setIsCollapsed(!isCollapsed);
|
||||||
setHasUserInteracted(true);
|
setHasUserInteracted(true);
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import { createQueryBatcher } from "./useQueryBatcher";
|
|||||||
import { useRefetchModels } from "./useModels";
|
import { useRefetchModels } from "./useModels";
|
||||||
import { useStreamingContext } from "@/contexts/StreamingContext";
|
import { useStreamingContext } from "@/contexts/StreamingContext";
|
||||||
import { useSettings } from "./useSettings";
|
import { useSettings } from "./useSettings";
|
||||||
|
import { getModelCapabilities } from "@/api";
|
||||||
|
|
||||||
export const useChats = () => {
|
export const useChats = () => {
|
||||||
return useQuery({
|
return useQuery({
|
||||||
@@ -606,6 +607,24 @@ export const useSendMessage = (chatId: string) => {
|
|||||||
queryClient.setQueryData(["staleModels"], newStaleMap);
|
queryClient.setQueryData(["staleModels"], newStaleMap);
|
||||||
|
|
||||||
queryClient.invalidateQueries({ queryKey: ["models"] });
|
queryClient.invalidateQueries({ queryKey: ["models"] });
|
||||||
|
|
||||||
|
// Fetch fresh capabilities for the downloaded model
|
||||||
|
getModelCapabilities(selectedModel.model)
|
||||||
|
.then((capabilities) => {
|
||||||
|
queryClient.setQueryData(
|
||||||
|
["modelCapabilities", selectedModel.model],
|
||||||
|
capabilities,
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error(
|
||||||
|
"Failed to fetch capabilities after download:",
|
||||||
|
error,
|
||||||
|
);
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: ["modelCapabilities", selectedModel.model],
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,114 +0,0 @@
|
|||||||
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
|
||||||
import { useState } from "react";
|
|
||||||
import { pullModel } from "@/api";
|
|
||||||
import { useSelectedModel } from "./useSelectedModel";
|
|
||||||
import { useSettings } from "./useSettings";
|
|
||||||
|
|
||||||
interface DownloadProgress {
|
|
||||||
status: string;
|
|
||||||
digest?: string;
|
|
||||||
total?: number;
|
|
||||||
completed?: number;
|
|
||||||
done?: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function useDownloadModel(chatId?: string) {
|
|
||||||
const queryClient = useQueryClient();
|
|
||||||
const { selectedModel } = useSelectedModel(chatId);
|
|
||||||
const { setSettings } = useSettings();
|
|
||||||
const [downloadProgress, setDownloadProgress] =
|
|
||||||
useState<DownloadProgress | null>(null);
|
|
||||||
const [abortController, setAbortController] =
|
|
||||||
useState<AbortController | null>(null);
|
|
||||||
const [downloadingChatIds, setDownloadingChatIds] = useState<Set<string>>(
|
|
||||||
new Set(),
|
|
||||||
);
|
|
||||||
|
|
||||||
const mutation = useMutation({
|
|
||||||
mutationFn: async (modelName: string) => {
|
|
||||||
const controller = new AbortController();
|
|
||||||
setAbortController(controller);
|
|
||||||
setDownloadProgress({ status: "Starting download..." });
|
|
||||||
if (chatId) {
|
|
||||||
setDownloadingChatIds((prev) => new Set(prev).add(chatId));
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
for await (const progress of pullModel(modelName, controller.signal)) {
|
|
||||||
setDownloadProgress(progress);
|
|
||||||
|
|
||||||
if (progress.status === "success") {
|
|
||||||
// Update selected model to indicate it's now available locally
|
|
||||||
if (selectedModel && selectedModel.model === modelName) {
|
|
||||||
setSettings({ SelectedModel: modelName });
|
|
||||||
}
|
|
||||||
// Invalidate models query to refresh the list
|
|
||||||
await queryClient.invalidateQueries({ queryKey: ["models"] });
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
setAbortController(null);
|
|
||||||
if (chatId) {
|
|
||||||
setDownloadingChatIds((prev) => {
|
|
||||||
const newSet = new Set(prev);
|
|
||||||
newSet.delete(chatId);
|
|
||||||
return newSet;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
onSuccess: () => {
|
|
||||||
setDownloadProgress(null);
|
|
||||||
if (chatId) {
|
|
||||||
setDownloadingChatIds((prev) => {
|
|
||||||
const newSet = new Set(prev);
|
|
||||||
newSet.delete(chatId);
|
|
||||||
return newSet;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
},
|
|
||||||
onError: (error: Error) => {
|
|
||||||
const status =
|
|
||||||
error.name === "AbortError" ? "Download cancelled" : "Download failed";
|
|
||||||
setDownloadProgress({ status, done: true });
|
|
||||||
|
|
||||||
// Clear error message after delay
|
|
||||||
const delay = error.name === "AbortError" ? 1500 : 3000;
|
|
||||||
setTimeout(() => {
|
|
||||||
setDownloadProgress(null);
|
|
||||||
if (chatId) {
|
|
||||||
setDownloadingChatIds((prev) => {
|
|
||||||
const newSet = new Set(prev);
|
|
||||||
newSet.delete(chatId);
|
|
||||||
return newSet;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, delay);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const cancelDownload = () => {
|
|
||||||
if (abortController) {
|
|
||||||
abortController.abort();
|
|
||||||
setAbortController(null);
|
|
||||||
if (chatId) {
|
|
||||||
setDownloadingChatIds((prev) => {
|
|
||||||
const newSet = new Set(prev);
|
|
||||||
newSet.delete(chatId);
|
|
||||||
return newSet;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return {
|
|
||||||
downloadModel: mutation.mutate,
|
|
||||||
isDownloading:
|
|
||||||
mutation.isPending && chatId ? downloadingChatIds.has(chatId) : false,
|
|
||||||
downloadProgress:
|
|
||||||
chatId && downloadingChatIds.has(chatId) ? downloadProgress : null,
|
|
||||||
error: mutation.error,
|
|
||||||
cancelDownload,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,29 +1,20 @@
|
|||||||
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
||||||
import { useEffect, useState } from "react";
|
|
||||||
import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api";
|
import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api";
|
||||||
|
|
||||||
export function useUser() {
|
export function useUser() {
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
const [initialDataLoaded, setInitialDataLoaded] = useState(false);
|
|
||||||
|
|
||||||
// Wait for initial data to be loaded
|
|
||||||
useEffect(() => {
|
|
||||||
const initialPromise = window.__initialUserDataPromise;
|
|
||||||
if (initialPromise) {
|
|
||||||
initialPromise.finally(() => {
|
|
||||||
setInitialDataLoaded(true);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
setInitialDataLoaded(true);
|
|
||||||
}
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const userQuery = useQuery({
|
const userQuery = useQuery({
|
||||||
queryKey: ["user"],
|
queryKey: ["user"],
|
||||||
queryFn: () => fetchUser(),
|
queryFn: async () => {
|
||||||
|
const result = await fetchUser();
|
||||||
|
return result;
|
||||||
|
},
|
||||||
staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes
|
staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes
|
||||||
gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
|
gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
|
||||||
initialData: null, // Start with null to prevent flashing
|
retry: 10,
|
||||||
|
retryDelay: (attemptIndex) => Math.min(500 * attemptIndex, 2000),
|
||||||
|
refetchOnMount: true, // Always fetch when component mounts
|
||||||
});
|
});
|
||||||
|
|
||||||
// Mutation to refresh user data
|
// Mutation to refresh user data
|
||||||
@@ -49,14 +40,15 @@ export function useUser() {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const isLoading = userQuery.isLoading || userQuery.isFetching;
|
||||||
|
const isAuthenticated = Boolean(userQuery.data?.name);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
user: userQuery.data,
|
user: userQuery.data,
|
||||||
isLoading:
|
isLoading,
|
||||||
!initialDataLoaded ||
|
|
||||||
(userQuery.isLoading && userQuery.data === undefined), // Show loading until initial data is loaded
|
|
||||||
isError: userQuery.isError,
|
isError: userQuery.isError,
|
||||||
error: userQuery.error,
|
error: userQuery.error,
|
||||||
isAuthenticated: Boolean(userQuery.data?.name),
|
isAuthenticated,
|
||||||
refreshUser: refreshUser.mutate,
|
refreshUser: refreshUser.mutate,
|
||||||
isRefreshing: refreshUser.isPending,
|
isRefreshing: refreshUser.isPending,
|
||||||
refetchUser: userQuery.refetch,
|
refetchUser: userQuery.refetch,
|
||||||
|
|||||||
13
app/ui/app/src/lib/config.ts
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
// API configuration
|
||||||
|
const DEV_API_URL = "http://127.0.0.1:3001";
|
||||||
|
|
||||||
|
// Base URL for fetch API calls (can be relative in production)
|
||||||
|
export const API_BASE = import.meta.env.DEV ? DEV_API_URL : "";
|
||||||
|
|
||||||
|
// Full host URL for Ollama client (needs full origin in production)
|
||||||
|
export const OLLAMA_HOST = import.meta.env.DEV
|
||||||
|
? DEV_API_URL
|
||||||
|
: window.location.origin;
|
||||||
|
|
||||||
|
export const OLLAMA_DOT_COM =
|
||||||
|
import.meta.env.VITE_OLLAMA_DOT_COM_URL || "https://ollama.com";
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import { Ollama } from "ollama/browser";
|
import { Ollama } from "ollama/browser";
|
||||||
|
import { OLLAMA_HOST } from "./config";
|
||||||
|
|
||||||
let _ollamaClient: Ollama | null = null;
|
let _ollamaClient: Ollama | null = null;
|
||||||
|
|
||||||
@@ -6,7 +7,7 @@ export const ollamaClient = new Proxy({} as Ollama, {
|
|||||||
get(_target, prop) {
|
get(_target, prop) {
|
||||||
if (!_ollamaClient) {
|
if (!_ollamaClient) {
|
||||||
_ollamaClient = new Ollama({
|
_ollamaClient = new Ollama({
|
||||||
host: window.location.origin,
|
host: OLLAMA_HOST,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
const value = _ollamaClient[prop as keyof Ollama];
|
const value = _ollamaClient[prop as keyof Ollama];
|
||||||
|
|||||||
@@ -5,13 +5,6 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
|||||||
import { routeTree } from "./routeTree.gen";
|
import { routeTree } from "./routeTree.gen";
|
||||||
import { fetchUser } from "./api";
|
import { fetchUser } from "./api";
|
||||||
import { StreamingProvider } from "./contexts/StreamingContext";
|
import { StreamingProvider } from "./contexts/StreamingContext";
|
||||||
import { User } from "@/gotypes";
|
|
||||||
|
|
||||||
declare global {
|
|
||||||
interface Window {
|
|
||||||
__initialUserDataPromise?: Promise<User | null>;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const queryClient = new QueryClient({
|
const queryClient = new QueryClient({
|
||||||
defaultOptions: {
|
defaultOptions: {
|
||||||
@@ -24,27 +17,11 @@ const queryClient = new QueryClient({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
// Track initial user data fetch
|
fetchUser().then((userData) => {
|
||||||
let initialUserDataPromise: Promise<User | null> | null = null;
|
if (userData) {
|
||||||
|
|
||||||
// Initialize user data on app startup
|
|
||||||
const initializeUserData = async () => {
|
|
||||||
try {
|
|
||||||
const userData = await fetchUser();
|
|
||||||
queryClient.setQueryData(["user"], userData);
|
queryClient.setQueryData(["user"], userData);
|
||||||
return userData;
|
|
||||||
} catch (error) {
|
|
||||||
console.error("Error initializing user data:", error);
|
|
||||||
queryClient.setQueryData(["user"], null);
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
};
|
});
|
||||||
|
|
||||||
// Start initialization immediately and track the promise
|
|
||||||
initialUserDataPromise = initializeUserData();
|
|
||||||
|
|
||||||
// Export the promise so hooks can await it
|
|
||||||
window.__initialUserDataPromise = initialUserDataPromise;
|
|
||||||
|
|
||||||
const router = createRouter({
|
const router = createRouter({
|
||||||
routeTree,
|
routeTree,
|
||||||
|
|||||||
@@ -101,15 +101,14 @@ type HealthResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Name string `json:"name"`
|
Email string `json:"email"`
|
||||||
Email string `json:"email"`
|
Name string `json:"name"`
|
||||||
AvatarURL string `json:"avatarURL"`
|
Bio string `json:"bio,omitempty"`
|
||||||
Plan string `json:"plan"`
|
AvatarURL string `json:"avatarurl,omitempty"`
|
||||||
Bio string `json:"bio"`
|
FirstName string `json:"firstname,omitempty"`
|
||||||
FirstName string `json:"firstName"`
|
LastName string `json:"lastname,omitempty"`
|
||||||
LastName string `json:"lastName"`
|
Plan string `json:"plan,omitempty"`
|
||||||
OverThreshold bool `json:"overThreshold"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Attachment struct {
|
type Attachment struct {
|
||||||
|
|||||||
241
app/ui/ui.go
@@ -12,18 +12,17 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/app/auth"
|
|
||||||
"github.com/ollama/ollama/app/server"
|
"github.com/ollama/ollama/app/server"
|
||||||
"github.com/ollama/ollama/app/store"
|
"github.com/ollama/ollama/app/store"
|
||||||
"github.com/ollama/ollama/app/tools"
|
"github.com/ollama/ollama/app/tools"
|
||||||
@@ -118,40 +117,66 @@ func (s *Server) log() *slog.Logger {
|
|||||||
|
|
||||||
// ollamaProxy creates a reverse proxy handler to the Ollama server
|
// ollamaProxy creates a reverse proxy handler to the Ollama server
|
||||||
func (s *Server) ollamaProxy() http.Handler {
|
func (s *Server) ollamaProxy() http.Handler {
|
||||||
ollamaHost := os.Getenv("OLLAMA_HOST")
|
var (
|
||||||
if ollamaHost == "" {
|
proxy http.Handler
|
||||||
ollamaHost = "http://127.0.0.1:11434"
|
proxyMu sync.Mutex
|
||||||
}
|
)
|
||||||
|
|
||||||
if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ollamaHost = "http://" + ollamaHost
|
proxyMu.Lock()
|
||||||
}
|
p := proxy
|
||||||
|
proxyMu.Unlock()
|
||||||
|
|
||||||
target, err := url.Parse(ollamaHost)
|
if p == nil {
|
||||||
if err != nil {
|
proxyMu.Lock()
|
||||||
s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
|
if proxy == nil {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
var err error
|
||||||
http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
|
for i := range 2 {
|
||||||
})
|
if i > 0 {
|
||||||
}
|
s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
err = WaitForServer(context.Background(), 10*time.Second)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
if err != nil {
|
||||||
|
proxyMu.Unlock()
|
||||||
|
s.log().Error("ollama server not ready after retries", "error", err)
|
||||||
|
http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
originalDirector := proxy.Director
|
target := envconfig.Host()
|
||||||
proxy.Director = func(req *http.Request) {
|
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||||
originalDirector(req)
|
|
||||||
req.Host = target.Host
|
|
||||||
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||||
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
|
|
||||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
|
||||||
}
|
|
||||||
|
|
||||||
return proxy
|
originalDirector := newProxy.Director
|
||||||
|
newProxy.Director = func(req *http.Request) {
|
||||||
|
originalDirector(req)
|
||||||
|
req.Host = target.Host
|
||||||
|
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
|
||||||
|
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy = newProxy
|
||||||
|
p = newProxy
|
||||||
|
} else {
|
||||||
|
p = proxy
|
||||||
|
}
|
||||||
|
proxyMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
p.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
|
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
|
||||||
@@ -264,11 +289,10 @@ func (s *Server) Handler() http.Handler {
|
|||||||
ollamaProxy := s.ollamaProxy()
|
ollamaProxy := s.ollamaProxy()
|
||||||
mux.Handle("GET /api/tags", ollamaProxy)
|
mux.Handle("GET /api/tags", ollamaProxy)
|
||||||
mux.Handle("POST /api/show", ollamaProxy)
|
mux.Handle("POST /api/show", ollamaProxy)
|
||||||
|
mux.Handle("GET /api/version", ollamaProxy)
|
||||||
mux.Handle("GET /api/v1/me", handle(s.me))
|
mux.Handle("HEAD /api/version", ollamaProxy)
|
||||||
mux.Handle("POST /api/v1/disconnect", handle(s.disconnect))
|
mux.Handle("POST /api/me", ollamaProxy)
|
||||||
mux.Handle("GET /api/v1/connect", handle(s.connectURL))
|
mux.Handle("POST /api/signout", ollamaProxy)
|
||||||
mux.Handle("GET /api/v1/health", handle(s.health))
|
|
||||||
|
|
||||||
// React app - catch all non-API routes and serve the React app
|
// React app - catch all non-API routes and serve the React app
|
||||||
mux.Handle("GET /", s.appHandler())
|
mux.Handle("GET /", s.appHandler())
|
||||||
@@ -338,7 +362,7 @@ func (s *Server) doSelfSigned(ctx context.Context, method, path string) (*http.R
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UserData fetches user data from ollama.com API for the current ollama key
|
// UserData fetches user data from ollama.com API for the current ollama key
|
||||||
func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
|
func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
|
||||||
resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me")
|
resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err)
|
return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err)
|
||||||
@@ -349,7 +373,7 @@ func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
|
|||||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var user responses.User
|
var user api.UserResponse
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse user response: %w", err)
|
return nil, fmt.Errorf("failed to parse user response: %w", err)
|
||||||
}
|
}
|
||||||
@@ -368,29 +392,27 @@ func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitForServer(ctx context.Context) error {
|
// WaitForServer waits for the Ollama server to be ready
|
||||||
timeout := time.Now().Add(10 * time.Second)
|
func WaitForServer(ctx context.Context, timeout time.Duration) error {
|
||||||
// TODO: this avoids an error on first load of the app
|
deadline := time.Now().Add(timeout)
|
||||||
// however we should either show a loading state or
|
for time.Now().Before(deadline) {
|
||||||
// wait for the Ollama server to be ready before redirecting
|
|
||||||
for {
|
|
||||||
c, err := api.ClientFromEnvironment()
|
c, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := c.Version(ctx); err == nil {
|
if _, err := c.Version(ctx); err == nil {
|
||||||
break
|
slog.Debug("ollama server is ready")
|
||||||
}
|
return nil
|
||||||
if time.Now().After(timeout) {
|
|
||||||
return fmt.Errorf("timeout waiting for Ollama server to be ready")
|
|
||||||
}
|
}
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
return nil
|
return errors.New("timeout waiting for Ollama server to be ready")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error {
|
func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error {
|
||||||
waitForServer(r.Context())
|
if err := WaitForServer(r.Context(), 10*time.Second); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
id, err := uuid.NewV7()
|
id, err := uuid.NewV7()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1438,129 +1460,6 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) me(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
if r.Method != http.MethodGet {
|
|
||||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := s.UserData(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
// If fetching from API fails, try to return cached user data if available
|
|
||||||
if cachedUser, cacheErr := s.Store.User(); cacheErr == nil && cachedUser != nil {
|
|
||||||
s.log().Info("API request failed, returning cached user data", "error", err)
|
|
||||||
responseUser := &responses.User{
|
|
||||||
Name: cachedUser.Name,
|
|
||||||
Email: cachedUser.Email,
|
|
||||||
Plan: cachedUser.Plan,
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
return json.NewEncoder(w).Encode(responseUser)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.log().Error("failed to get user data", "error", err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return json.NewEncoder(w).Encode(responses.Error{
|
|
||||||
Error: "failed to get user data",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
return json.NewEncoder(w).Encode(user)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) disconnect(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
if r.Method != http.MethodPost {
|
|
||||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.Store.ClearUser(); err != nil {
|
|
||||||
s.log().Warn("failed to clear cached user data", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the SSH public key to encode for the delete request
|
|
||||||
pubKey, err := ollamaAuth.GetPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
s.log().Error("failed to get public key", "error", err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return json.NewEncoder(w).Encode(responses.Error{
|
|
||||||
Error: "failed to get public key",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode the key using base64 URL encoding
|
|
||||||
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
|
||||||
|
|
||||||
// Call the /api/user/keys/{encodedKey} endpoint with DELETE
|
|
||||||
resp, err := s.doSelfSigned(r.Context(), http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey))
|
|
||||||
if err != nil {
|
|
||||||
s.log().Error("failed to call ollama.com/api/user/keys", "error", err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return json.NewEncoder(w).Encode(responses.Error{
|
|
||||||
Error: "failed to disconnect from ollama.com",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
s.log().Error("disconnect request failed", "status", resp.StatusCode)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return json.NewEncoder(w).Encode(responses.Error{
|
|
||||||
Error: "failed to disconnect from ollama.com",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
return json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) connectURL(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
if r.Method != http.MethodGet {
|
|
||||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
connectURL, err := auth.BuildConnectURL(OllamaDotCom)
|
|
||||||
if err != nil {
|
|
||||||
s.log().Error("failed to build connect URL", "error", err)
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
return json.NewEncoder(w).Encode(responses.Error{
|
|
||||||
Error: "failed to build connect URL",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
return json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"connect_url": connectURL,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
if r.Method != http.MethodGet {
|
|
||||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
healthy := false
|
|
||||||
c, err := api.ClientFromEnvironment()
|
|
||||||
if err == nil {
|
|
||||||
if _, err := c.Version(r.Context()); err == nil {
|
|
||||||
healthy = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
return json.NewEncoder(w).Encode(responses.HealthResponse{
|
|
||||||
Healthy: healthy,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
|
|||||||
case uint32(UI_REQUEST_MSG_ID):
|
case uint32(UI_REQUEST_MSG_ID):
|
||||||
// Requests for the UI must always come from the main event thread
|
// Requests for the UI must always come from the main event thread
|
||||||
l := int(wParam)
|
l := int(wParam)
|
||||||
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l)
|
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec
|
||||||
t.app.UIRun(path)
|
t.app.UIRun(path)
|
||||||
case WM_COPYDATA:
|
case WM_COPYDATA:
|
||||||
// Handle URL scheme requests from other instances
|
// Handle URL scheme requests from other instances
|
||||||
if lParam != 0 {
|
if lParam != 0 {
|
||||||
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam))
|
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec
|
||||||
if cds.DwData == 1 { // Our identifier for URL scheme messages
|
if cds.DwData == 1 { // Our identifier for URL scheme messages
|
||||||
// Convert the data back to string
|
// Convert the data back to string
|
||||||
data := make([]byte, cds.CbData)
|
data := make([]byte, cds.CbData)
|
||||||
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData])
|
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) //nolint:govet,gosec
|
||||||
urlScheme := string(data)
|
urlScheme := string(data)
|
||||||
handleURLSchemeRequest(urlScheme)
|
handleURLSchemeRequest(urlScheme)
|
||||||
lResult = 1 // Return non-zero to indicate success
|
lResult = 1 // Return non-zero to indicate success
|
||||||
|
|||||||
115
cmd/bench/README.md
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
Ollama Benchmark Tool
|
||||||
|
---------------------
|
||||||
|
|
||||||
|
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
* Benchmark multiple models in a single run
|
||||||
|
* Support for both text and image prompts
|
||||||
|
* Configurable generation parameters (temperature, max tokens, seed, etc.)
|
||||||
|
* Supports markdown, benchstat, and CSV output formats
|
||||||
|
* Detailed performance metrics (prefill, generate, load, total durations)
|
||||||
|
|
||||||
|
## Building from Source
|
||||||
|
|
||||||
|
```
|
||||||
|
go build -o ollama-bench bench.go
|
||||||
|
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using Go Run (without building)
|
||||||
|
|
||||||
|
```
|
||||||
|
go run bench.go -model gpt-oss:20b -epochs 3
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Basic Example
|
||||||
|
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3 -epochs 6
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark Multiple Models
|
||||||
|
|
||||||
|
```
|
||||||
|
./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
|
||||||
|
benchstat -col /name gemma.bench
|
||||||
|
```
|
||||||
|
|
||||||
|
### With Image Prompt
|
||||||
|
|
||||||
|
```
|
||||||
|
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Advanced Example
|
||||||
|
|
||||||
|
```
|
||||||
|
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
|
||||||
|
```
|
||||||
|
|
||||||
|
## Command Line Options
|
||||||
|
|
||||||
|
| Option | Description | Default |
|
||||||
|
|----------|-------------|---------|
|
||||||
|
| -model | Comma-separated list of models to benchmark | (required) |
|
||||||
|
| -epochs | Number of iterations per model | 1 |
|
||||||
|
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
|
||||||
|
| -temperature | Temperature parameter | 0.0 |
|
||||||
|
| -seed | Random seed | 0 (random) |
|
||||||
|
| -timeout | Timeout in seconds | 300 |
|
||||||
|
| -p | Prompt text | "Write a long story." |
|
||||||
|
| -image | Image file to include in prompt | |
|
||||||
|
| -k | Keep-alive duration in seconds | 0 |
|
||||||
|
| -format | Output format (markdown, benchstat, csv) | benchstat |
|
||||||
|
| -output | Output file for results | "" (stdout) |
|
||||||
|
| -v | Verbose mode | false |
|
||||||
|
| -debug | Show debug information | false |
|
||||||
|
|
||||||
|
## Output Formats
|
||||||
|
|
||||||
|
### Markdown Format
|
||||||
|
|
||||||
|
The markdown format is suitable for copying and pasting into a GitHub issue and will look like:
|
||||||
|
```
|
||||||
|
| Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|
||||||
|
|-------|------|-------|----------|------------|--------------|
|
||||||
|
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
|
||||||
|
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
|
||||||
|
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
|
||||||
|
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchstat Format
|
||||||
|
|
||||||
|
Compatible with Go's benchstat tool for statistical analysis:
|
||||||
|
|
||||||
|
```
|
||||||
|
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
|
||||||
|
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
|
||||||
|
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
|
||||||
|
```
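
If benchstat is not already installed, it can be fetched with `go install golang.org/x/perf/cmd/benchstat@latest`. Two saved runs can then be compared with `benchstat old.bench new.bench` (the file names here are placeholders), or grouped by model with `benchstat -col /name results.bench` as shown in the multiple-model example above.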
|
||||||
|
|
||||||
|
### CSV Format
|
||||||
|
|
||||||
|
Machine-readable comma-separated values:
|
||||||
|
|
||||||
|
```
|
||||||
|
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
|
||||||
|
gpt-oss:20b,prefill,128,78125.00,12800.00
|
||||||
|
gpt-oss:20b,generate,512,19531.25,51200.00
|
||||||
|
gpt-oss:20b,load,1,1500000000,0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Metrics Explained
|
||||||
|
|
||||||
|
The tool reports four types of metrics for each model:
|
||||||
|
|
||||||
|
* prefill: Time spent processing the prompt
|
||||||
|
* generate: Time spent generating the response
|
||||||
|
* load: Model loading time (one-time cost)
|
||||||
|
* total: Total request duration
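
A rough sketch of how the per-token columns are derived from these metrics (the helper name below is illustrative and not part of the tool; the formula is the same one used in bench.go):

```
package main

import (
	"fmt"
	"time"
)

// tokensPerSec mirrors the throughput formula in bench.go:
// token count divided by the step duration, scaled to seconds.
func tokensPerSec(count int, d time.Duration) float64 {
	return float64(count) / (float64(d.Nanoseconds()) + 1e-12) * 1e9
}

func main() {
	// Numbers from the generate row of the markdown example above:
	// 200 tokens generated in 2.646843954s.
	d := 2646843954 * time.Nanosecond
	fmt.Printf("%.2f tokens/sec\n", tokensPerSec(200, d)) // prints 75.56
}
```

The ns/token column is the same data the other way around: 2,646,843,954 ns / 200 tokens ≈ 13,234,219.77 ns/token, matching the markdown example.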
|
||||||
|
|
||||||
321
cmd/bench/bench.go
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type flagOptions struct {
|
||||||
|
models *string
|
||||||
|
epochs *int
|
||||||
|
maxTokens *int
|
||||||
|
temperature *float64
|
||||||
|
seed *int
|
||||||
|
timeout *int
|
||||||
|
prompt *string
|
||||||
|
imageFile *string
|
||||||
|
keepAlive *float64
|
||||||
|
format *string
|
||||||
|
outputFile *string
|
||||||
|
debug *bool
|
||||||
|
verbose *bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Metrics struct {
|
||||||
|
Model string
|
||||||
|
Step string
|
||||||
|
Count int
|
||||||
|
Duration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
var once sync.Once
|
||||||
|
|
||||||
|
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
|
||||||
|
|
||||||
|
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
|
||||||
|
switch format {
|
||||||
|
case "benchstat":
|
||||||
|
if verbose {
|
||||||
|
printHeader := func() {
|
||||||
|
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
|
||||||
|
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
|
||||||
|
}
|
||||||
|
once.Do(printHeader)
|
||||||
|
}
|
||||||
|
for _, m := range metrics {
|
||||||
|
if m.Step == "generate" || m.Step == "prefill" {
|
||||||
|
if m.Count > 0 {
|
||||||
|
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
|
||||||
|
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
|
||||||
|
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
|
||||||
|
				} else {
					fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
						m.Model, m.Step, m.Count)
				}
			} else {
				var suffix string
				if m.Step == "load" {
					suffix = "/step=load"
				}
				fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
					m.Model, suffix, m.Duration.Nanoseconds())
			}
		}
	case "csv":
		printHeader := func() {
			headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
			fmt.Fprintln(w, strings.Join(headings, ","))
		}
		once.Do(printHeader)

		for _, m := range metrics {
			if m.Step == "generate" || m.Step == "prefill" {
				var nsPerToken float64
				var tokensPerSec float64
				if m.Count > 0 {
					nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
					tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
				}
				fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
			} else {
				fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
			}
		}
	case "markdown":
		printHeader := func() {
			fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
			fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
		}
		once.Do(printHeader)

		for _, m := range metrics {
			var nsPerToken, tokensPerSec float64
			var nsPerTokenStr, tokensPerSecStr string

			if m.Step == "generate" || m.Step == "prefill" {
				nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
				tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
				nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
				tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
			} else {
				nsPerTokenStr = "-"
				tokensPerSecStr = "-"
			}

			fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
				m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
		}
	default:
		fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
	}
}

func BenchmarkChat(fOpt flagOptions) error {
	models := strings.Split(*fOpt.models, ",")

	// todo - add multi-image support
	var imgData api.ImageData
	var err error
	if *fOpt.imageFile != "" {
		imgData, err = readImage(*fOpt.imageFile)
		if err != nil {
			fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
			return err
		}
	}

	if *fOpt.debug && imgData != nil {
		fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
		return err
	}

	var out io.Writer = os.Stdout
	if fOpt.outputFile != nil && *fOpt.outputFile != "" {
		f, err := os.OpenFile(*fOpt.outputFile, os.O_CREATE|os.O_WRONLY, 0o644)
		if err != nil {
			fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *fOpt.outputFile, err)
			return err
		}
		defer f.Close()
		out = f
	}

	for _, model := range models {
		for range *fOpt.epochs {
			options := make(map[string]interface{})
			if *fOpt.maxTokens > 0 {
				options["num_predict"] = *fOpt.maxTokens
			}
			options["temperature"] = *fOpt.temperature
			if fOpt.seed != nil && *fOpt.seed > 0 {
				options["seed"] = *fOpt.seed
			}

			var keepAliveDuration *api.Duration
			if *fOpt.keepAlive > 0 {
				duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
				keepAliveDuration = &duration
			}

			req := &api.ChatRequest{
				Model: model,
				Messages: []api.Message{
					{
						Role:    "user",
						Content: *fOpt.prompt,
					},
				},
				Options:   options,
				KeepAlive: keepAliveDuration,
			}

			if imgData != nil {
				req.Messages[0].Images = []api.ImageData{imgData}
			}

			var responseMetrics *api.Metrics

			ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
			defer cancel()

			err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
				if *fOpt.debug {
					fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
				}

				if resp.Done {
					responseMetrics = &resp.Metrics
				}
				return nil
			})

			if *fOpt.debug {
				fmt.Fprintln(os.Stderr)
			}

			if err != nil {
				if ctx.Err() == context.DeadlineExceeded {
					fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, *fOpt.timeout)
					continue
				}
				fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
				continue
			}

			if responseMetrics == nil {
				fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
				continue
			}

			metrics := []Metrics{
				{
					Model:    model,
					Step:     "prefill",
					Count:    responseMetrics.PromptEvalCount,
					Duration: responseMetrics.PromptEvalDuration,
				},
				{
					Model:    model,
					Step:     "generate",
					Count:    responseMetrics.EvalCount,
					Duration: responseMetrics.EvalDuration,
				},
				{
					Model:    model,
					Step:     "load",
					Count:    1,
					Duration: responseMetrics.LoadDuration,
				},
				{
					Model:    model,
					Step:     "total",
					Count:    1,
					Duration: responseMetrics.TotalDuration,
				},
			}

			OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)

			if *fOpt.keepAlive > 0 {
				time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
			}
		}
	}

	return nil
}

func readImage(filePath string) (api.ImageData, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	data, err := io.ReadAll(file)
	if err != nil {
		return nil, err
	}

	return api.ImageData(data), nil
}

func main() {
	fOpt := flagOptions{
		models:      flag.String("model", "", "Model to benchmark"),
		epochs:      flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
		maxTokens:   flag.Int("max-tokens", 200, "Maximum tokens for model response"),
		temperature: flag.Float64("temperature", 0, "Temperature parameter"),
		seed:        flag.Int("seed", 0, "Random seed"),
		timeout:     flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
		prompt:      flag.String("p", DefaultPrompt, "Prompt to use"),
		imageFile:   flag.String("image", "", "Filename for an image to include"),
		keepAlive:   flag.Float64("k", 0, "Keep alive duration in seconds"),
		format:      flag.String("format", "markdown", "Output format [benchstat|csv|markdown]"),
		outputFile:  flag.String("output", "", "Output file for results (stdout if empty)"),
		verbose:     flag.Bool("v", false, "Show system information"),
		debug:       flag.Bool("debug", false, "Show debug information"),
	}

	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
		fmt.Fprintf(os.Stderr, "Description:\n")
		fmt.Fprintf(os.Stderr, "  Model benchmarking tool with configurable parameters\n\n")
		fmt.Fprintf(os.Stderr, "Options:\n")
		flag.PrintDefaults()
		fmt.Fprintf(os.Stderr, "\nExamples:\n")
		fmt.Fprintf(os.Stderr, "  bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
	}
	flag.Parse()

	if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
		fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
		os.Exit(1)
	}

	if len(*fOpt.models) == 0 {
		fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n")
		flag.Usage()
		return
	}

	BenchmarkChat(fOpt)
}
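A quick illustration of the output formats defined above, with a hypothetical model name and timings (one benchstat-style load line, one CSV row, one markdown row); these sample lines are not part of the diff:

BenchmarkModel/name=example-model/step=load 1 523000000 ns/request
example-model,generate,128,9500000.00,105.26
| example-model | generate | 128 | 1.216s | 9500000.00 | 105.26 |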
463 cmd/bench/bench_test.go Normal file
@@ -0,0 +1,463 @@
package main

import (
	"bytes"
	"crypto/rand"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func createTestFlagOptions() flagOptions {
	models := "test-model"
	format := "benchstat"
	epochs := 1
	maxTokens := 100
	temperature := 0.7
	seed := 42
	timeout := 30
	prompt := "test prompt"
	imageFile := ""
	keepAlive := 5.0
	verbose := false
	debug := false

	return flagOptions{
		models:      &models,
		format:      &format,
		epochs:      &epochs,
		maxTokens:   &maxTokens,
		temperature: &temperature,
		seed:        &seed,
		timeout:     &timeout,
		prompt:      &prompt,
		imageFile:   &imageFile,
		keepAlive:   &keepAlive,
		verbose:     &verbose,
		debug:       &debug,
	}
}

func captureOutput(f func()) string {
	oldStdout := os.Stdout
	oldStderr := os.Stderr
	defer func() {
		os.Stdout = oldStdout
		os.Stderr = oldStderr
	}()

	r, w, _ := os.Pipe()
	os.Stdout = w
	os.Stderr = w

	f()

	w.Close()
	var buf bytes.Buffer
	io.Copy(&buf, r)
	return buf.String()
}

func createMockOllamaServer(t *testing.T, responses []api.ChatResponse) *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/api/chat" {
			t.Errorf("Expected path /api/chat, got %s", r.URL.Path)
			http.Error(w, "Not found", http.StatusNotFound)
			return
		}

		if r.Method != "POST" {
			t.Errorf("Expected POST method, got %s", r.Method)
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)

		for _, resp := range responses {
			jsonData, err := json.Marshal(resp)
			if err != nil {
				t.Errorf("Failed to marshal response: %v", err)
				return
			}
			w.Write(jsonData)
			w.Write([]byte("\n"))
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
			time.Sleep(10 * time.Millisecond) // Simulate some delay
		}
	}))
}

func TestBenchmarkChat_Success(t *testing.T) {
	fOpt := createTestFlagOptions()

	mockResponses := []api.ChatResponse{
		{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "test response part 1",
			},
			Done: false,
		},
		{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "test response part 2",
			},
			Done: true,
			Metrics: api.Metrics{
				PromptEvalCount:    10,
				PromptEvalDuration: 100 * time.Millisecond,
				EvalCount:          50,
				EvalDuration:       500 * time.Millisecond,
				TotalDuration:      600 * time.Millisecond,
				LoadDuration:       50 * time.Millisecond,
			},
		},
	}

	server := createMockOllamaServer(t, mockResponses)
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}
	})

	if !strings.Contains(output, "BenchmarkModel/name=test-model/step=prefill") {
		t.Errorf("Expected output to contain prefill metrics, got: %s", output)
	}
	if !strings.Contains(output, "BenchmarkModel/name=test-model/step=generate") {
		t.Errorf("Expected output to contain generate metrics, got: %s", output)
	}
	if !strings.Contains(output, "ns/token") {
		t.Errorf("Expected output to contain ns/token metric, got: %s", output)
	}
}

func TestBenchmarkChat_ServerError(t *testing.T) {
	fOpt := createTestFlagOptions()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "Internal server error", http.StatusInternalServerError)
	}))
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected error to be handled internally, got returned error: %v", err)
		}
	})

	if !strings.Contains(output, "ERROR: Couldn't chat with model") {
		t.Errorf("Expected error message about chat failure, got: %s", output)
	}
}

func TestBenchmarkChat_Timeout(t *testing.T) {
	fOpt := createTestFlagOptions()
	shortTimeout := 1 // Very short timeout
	fOpt.timeout = &shortTimeout

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Simulate a long delay that will cause timeout
		time.Sleep(2 * time.Second)

		w.Header().Set("Content-Type", "application/json")
		response := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "test response",
			},
			Done: true,
			Metrics: api.Metrics{
				PromptEvalCount:    10,
				PromptEvalDuration: 100 * time.Millisecond,
				EvalCount:          50,
				EvalDuration:       500 * time.Millisecond,
				TotalDuration:      600 * time.Millisecond,
				LoadDuration:       50 * time.Millisecond,
			},
		}
		jsonData, _ := json.Marshal(response)
		w.Write(jsonData)
	}))
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected timeout to be handled internally, got returned error: %v", err)
		}
	})

	if !strings.Contains(output, "ERROR: Chat request timed out") {
		t.Errorf("Expected timeout error message, got: %s", output)
	}
}

func TestBenchmarkChat_NoMetrics(t *testing.T) {
	fOpt := createTestFlagOptions()

	mockResponses := []api.ChatResponse{
		{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "test response",
			},
			Done: false, // Never sends Done=true
		},
	}

	server := createMockOllamaServer(t, mockResponses)
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}
	})

	if !strings.Contains(output, "ERROR: No metrics received") {
		t.Errorf("Expected no metrics error message, got: %s", output)
	}
}

func TestBenchmarkChat_MultipleModels(t *testing.T) {
	fOpt := createTestFlagOptions()
	models := "model1,model2"
	epochs := 2
	fOpt.models = &models
	fOpt.epochs = &epochs

	callCount := 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		callCount++

		w.Header().Set("Content-Type", "application/json")

		var req api.ChatRequest
		body, _ := io.ReadAll(r.Body)
		json.Unmarshal(body, &req)

		response := api.ChatResponse{
			Model: req.Model,
			Message: api.Message{
				Role:    "assistant",
				Content: "test response for " + req.Model,
			},
			Done: true,
			Metrics: api.Metrics{
				PromptEvalCount:    10,
				PromptEvalDuration: 100 * time.Millisecond,
				EvalCount:          50,
				EvalDuration:       500 * time.Millisecond,
				TotalDuration:      600 * time.Millisecond,
				LoadDuration:       50 * time.Millisecond,
			},
		}
		jsonData, _ := json.Marshal(response)
		w.Write(jsonData)
	}))
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}
	})

	// Should be called 4 times (2 models × 2 epochs)
	if callCount != 4 {
		t.Errorf("Expected 4 API calls, got %d", callCount)
	}

	if !strings.Contains(output, "BenchmarkModel/name=model1") || !strings.Contains(output, "BenchmarkModel/name=model2") {
		t.Errorf("Expected output for both models, got: %s", output)
	}
}

func TestBenchmarkChat_WithImage(t *testing.T) {
	fOpt := createTestFlagOptions()

	tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	defer os.Remove(tmpfile.Name())

	content := []byte("fake image data")
	if _, err := tmpfile.Write(content); err != nil {
		t.Fatalf("Failed to write to temp file: %v", err)
	}
	tmpfile.Close()

	tmpfileName := tmpfile.Name()
	fOpt.imageFile = &tmpfileName

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Verify the request contains image data
		var req api.ChatRequest
		body, _ := io.ReadAll(r.Body)
		json.Unmarshal(body, &req)

		if len(req.Messages) == 0 || len(req.Messages[0].Images) == 0 {
			t.Error("Expected request to contain images")
		}

		w.Header().Set("Content-Type", "application/json")
		response := api.ChatResponse{
			Model: "test-model",
			Message: api.Message{
				Role:    "assistant",
				Content: "test response with image",
			},
			Done: true,
			Metrics: api.Metrics{
				PromptEvalCount:    10,
				PromptEvalDuration: 100 * time.Millisecond,
				EvalCount:          50,
				EvalDuration:       500 * time.Millisecond,
				TotalDuration:      600 * time.Millisecond,
				LoadDuration:       50 * time.Millisecond,
			},
		}
		jsonData, _ := json.Marshal(response)
		w.Write(jsonData)
	}))
	defer server.Close()

	t.Setenv("OLLAMA_HOST", server.URL)

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err != nil {
			t.Errorf("Expected no error, got %v", err)
		}
	})

	if !strings.Contains(output, "BenchmarkModel/name=test-model") {
		t.Errorf("Expected benchmark output, got: %s", output)
	}
}

func TestBenchmarkChat_ImageError(t *testing.T) {
	randFileName := func() string {
		const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
		const length = 8

		result := make([]byte, length)
		rand.Read(result) // Fill with random bytes

		for i := range result {
			result[i] = charset[result[i]%byte(len(charset))]
		}

		return string(result) + ".txt"
	}

	fOpt := createTestFlagOptions()
	imageFile := randFileName()
	fOpt.imageFile = &imageFile

	output := captureOutput(func() {
		err := BenchmarkChat(fOpt)
		if err == nil {
			t.Error("Expected error from image reading, got nil")
		}
	})

	if !strings.Contains(output, "ERROR: Couldn't read image") {
		t.Errorf("Expected image read error message, got: %s", output)
	}
}

func TestReadImage_Success(t *testing.T) {
	tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	defer os.Remove(tmpfile.Name())

	content := []byte("fake image data")
	if _, err := tmpfile.Write(content); err != nil {
		t.Fatalf("Failed to write to temp file: %v", err)
	}
	tmpfile.Close()

	imgData, err := readImage(tmpfile.Name())
	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}

	if imgData == nil {
		t.Error("Expected image data, got nil")
	}

	expected := api.ImageData(content)
	if string(imgData) != string(expected) {
		t.Errorf("Expected image data %v, got %v", expected, imgData)
	}
}

func TestReadImage_FileNotFound(t *testing.T) {
	imgData, err := readImage("nonexistentfile.jpg")
	if err == nil {
		t.Error("Expected error for non-existent file, got nil")
	}
	if imgData != nil {
		t.Error("Expected nil image data for non-existent file")
	}
}

func TestOptionsMapCreation(t *testing.T) {
	fOpt := createTestFlagOptions()

	options := make(map[string]interface{})
	if *fOpt.maxTokens > 0 {
		options["num_predict"] = *fOpt.maxTokens
	}
	options["temperature"] = *fOpt.temperature
	if fOpt.seed != nil && *fOpt.seed > 0 {
		options["seed"] = *fOpt.seed
	}

	if options["num_predict"] != *fOpt.maxTokens {
		t.Errorf("Expected num_predict %d, got %v", *fOpt.maxTokens, options["num_predict"])
	}
	if options["temperature"] != *fOpt.temperature {
		t.Errorf("Expected temperature %f, got %v", *fOpt.temperature, options["temperature"])
	}
	if options["seed"] != *fOpt.seed {
		t.Errorf("Expected seed %d, got %v", *fOpt.seed, options["seed"])
	}
}
@@ -943,6 +943,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 		rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
 	}
 	rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
+	if resp.Requires != "" {
+		rows = append(rows, []string{"", "requires", resp.Requires})
+	}
 	return
 })
@@ -1430,7 +1433,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		latest.Summary()
 	}

-	return &api.Message{Role: role, Content: fullResponse.String()}, nil
+	return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
 }

 func generate(cmd *cobra.Command, opts runOptions) error {
@@ -291,6 +291,31 @@ Weigh anchor!
 			t.Errorf("unexpected output (-want +got):\n%s", diff)
 		}
 	})
+
+	t.Run("min version", func(t *testing.T) {
+		var b bytes.Buffer
+		if err := showInfo(&api.ShowResponse{
+			Details: api.ModelDetails{
+				Family:            "test",
+				ParameterSize:     "7B",
+				QuantizationLevel: "FP16",
+			},
+			Requires: "0.14.0",
+		}, false, &b); err != nil {
+			t.Fatal(err)
+		}
+
+		expect := `  Model
+    architecture    test
+    parameters      7B
+    quantization    FP16
+    requires        0.14.0
+
+`
+		if diff := cmp.Diff(expect, b.String()); diff != "" {
+			t.Errorf("unexpected output (-want +got):\n%s", diff)
+		}
+	})
 }

 func TestDeleteHandler(t *testing.T) {
@@ -182,6 +182,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
 		conv = &llama4Model{}
 	case "Mistral3ForConditionalGeneration":
 		conv = &mistral3Model{}
+	case "Ministral3ForCausalLM":
+		conv = &mistral3CausalModel{}
 	case "MixtralForCausalLM":
 		conv = &mixtralModel{}
 	case "GemmaForCausalLM":
@@ -200,12 +202,20 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
 		conv = &qwen25VLModel{}
 	case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
 		conv = &qwen3VLModel{}
+	case "Olmo3ForCausalLM":
+		conv = &olmoModel{}
 	case "BertModel":
 		conv = &bertModel{}
+	case "NomicBertModel", "NomicBertMoEModel":
+		conv = &nomicbertModel{}
 	case "CohereForCausalLM":
 		conv = &commandrModel{}
 	case "GptOssForCausalLM":
 		conv = &gptossModel{}
+	case "DeepseekOCRForCausalLM":
+		conv = &deepseekocr{}
+	case "DeepseekV3ForCausalLM":
+		conv = &deepseek2Model{}
 	default:
 		return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}
173 convert/convert_deepseek2.go Normal file
@@ -0,0 +1,173 @@
package convert

import (
	"cmp"
	"fmt"
	"log/slog"
	"regexp"
	"strconv"

	"github.com/ollama/ollama/fs/ggml"
)

type deepseek2Model struct {
	ModelParameters       // architectures, vocab_size
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	HiddenSize            uint32  `json:"hidden_size"`
	HiddenLayers          uint32  `json:"num_hidden_layers"`
	IntermediateSize      uint32  `json:"intermediate_size"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	RMSNormEPS            float32 `json:"rms_norm_eps"`

	RopeTheta     float32 `json:"rope_theta"`
	QKNopeHeadDim uint32  `json:"qk_nope_head_dim"`
	QKRopeHeadDim uint32  `json:"qk_rope_head_dim"`
	KVLoraRank    uint32  `json:"kv_lora_rank"`
	QLoraRank     uint32  `json:"q_lora_rank"`
	VHeadDim      uint32  `json:"v_head_dim"`

	ExpertCount            uint32  `json:"n_routed_experts"`
	ExpertSharedCount      uint32  `json:"n_shared_experts"`
	ExpertIntermediateSize uint32  `json:"moe_intermediate_size"`
	ExpertUsedCount        uint32  `json:"num_experts_per_tok"`
	ExpertWeightsNorm      bool    `json:"norm_topk_prob"`
	ExpertWeightsScale     float32 `json:"routed_scaling_factor"`

	ScoringFunc            string `json:"scoring_func"`
	LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`

	RopeScaling struct {
		Factor                        float32 `json:"factor"`
		OriginalMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
		Type                          string  `json:"type"`
		MScaleAllDim                  float32 `json:"mscale_all_dim"`
	} `json:"rope_scaling"`

	Architecture string
}

func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "deepseek2"
	kv["general.type"] = "model"
	kv["deepseek2.block_count"] = p.HiddenLayers

	numHeads := p.NumAttentionHeads
	numKVHeads := p.NumKeyValueHeads

	kv["deepseek2.attention.head_count"] = numHeads
	kv["deepseek2.attention.head_count_kv"] = numKVHeads
	kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
	kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
	kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
	kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
	kv["deepseek2.attention.value_length"] = p.VHeadDim
	kv["deepseek2.context_length"] = p.MaxPositionEmbeddings
	kv["deepseek2.embedding_length"] = p.HiddenSize
	kv["deepseek2.expert_count"] = p.ExpertCount
	kv["deepseek2.expert_feed_forward_length"] = p.ExpertIntermediateSize
	kv["deepseek2.expert_shared_count"] = p.ExpertSharedCount

	var scoringFunc uint32
	switch p.ScoringFunc {
	case "softmax":
		// not currently supported in the model, but needed for Deepseek-OCR
		scoringFunc = 1
	case "sigmoid":
		scoringFunc = 2
	}
	kv["deepseek2.expert_gating_func"] = scoringFunc
	kv["deepseek2.expert_used_count"] = p.ExpertUsedCount
	kv["deepseek2.expert_weights_norm"] = p.ExpertWeightsNorm
	kv["deepseek2.expert_weights_scale"] = p.ExpertWeightsScale
	kv["deepseek2.feed_forward_length"] = p.IntermediateSize
	kv["deepseek2.leading_dense_block_count"] = p.LeadingDenseBlockCount

	kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
	kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
	kv["deepseek2.rope.scaling.factor"] = p.RopeScaling.Factor
	kv["deepseek2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
	kv["deepseek2.rope.scaling.type"] = p.RopeScaling.Type
	kv["deepseek2.rope.scaling.yarn_log_multiplier"] = 0.1 * p.RopeScaling.MScaleAllDim

	kv["tokenizer.ggml.pre"] = "deepseek-v3"

	return kv
}

func (p *deepseek2Model) Replacements() []string {
	return []string{
		"lm_head", "output",
		"model.embed_tokens", "token_embd",
		"model.norm", "output_norm",
		"language_model.", "",
		"model.layers", "blk",
		"input_layernorm", "attn_norm",
		"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
		"self_attn.kv_a_layernorm", "attn_kv_a_norm",
		"self_attn.kv_b_proj", "attn_kv_b",
		"self_attn.q_a_proj", "attn_q_a",
		"self_attn.q_a_layernorm", "attn_q_a_norm",
		"self_attn.q_b_proj", "attn_q_b",
		"self_attn.o_proj", "attn_output",
		"post_attention_layernorm", "ffn_norm",
		"mlp.shared_experts.down_proj", "ffn_down_shexp",
		"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
		"mlp.shared_experts.up_proj", "ffn_up_shexp",
		"mlp.gate_proj", "ffn_gate",
		"mlp.down_proj", "ffn_down",
		"mlp.up_proj", "ffn_up",
		"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
		"mlp.gate", "ffn_gate_inp",
	}
}

func (p *deepseek2Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
	merges := make([]merge, p.HiddenLayers*3)
	for i := range p.HiddenLayers {
		merges[i*3+0] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
		}
		merges[i*3+1] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
		}
		merges[i*3+2] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
		}
	}

	skipLayer := func(n string, minValue uint32) bool {
		re := regexp.MustCompile(`^blk\.(\d+)`)
		matches := re.FindStringSubmatch(n)
		if matches == nil {
			return false
		}

		blkNum, err := strconv.Atoi(matches[1])
		if err != nil {
			return false
		}

		return uint32(blkNum) >= minValue
	}

	out, s = mergeTensors(s, merges...)
	for _, t := range s {
		// skip any additional layers (such as the Multi-Token Prediction layer)
		if skipLayer(t.Name(), p.HiddenLayers) {
			slog.Debug("skipping layer", "name", t.Name())
			continue
		}
		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}
	return out
}
136 convert/convert_deepseekocr.go Normal file
@@ -0,0 +1,136 @@
package convert

import (
	"fmt"

	"github.com/ollama/ollama/fs/ggml"
)

type deepseekocr struct {
	ModelParameters
	LanguageConfig struct {
		MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
		HiddenSize            uint32 `json:"hidden_size"`
		HiddenLayers          uint32 `json:"num_hidden_layers"`
		IntermediateSize      uint32 `json:"intermediate_size"`
		NumAttentionHeads     uint32 `json:"num_attention_heads"`
		NumKeyValueHeads      uint32 `json:"num_key_value_heads"`
		NumRoutedExperts      uint32 `json:"n_routed_experts"`
		NumSharedExperts      uint32 `json:"n_shared_experts"`
		NumExpertsPerToken    uint32 `json:"num_experts_per_tok"`
		FirstKDenseReplace    uint32 `json:"first_k_dense_replace"`
	} `json:"language_config"`

	VisionConfig struct {
		ImageSize uint32 `json:"image_size"`
		Width     struct {
			Vision struct {
				Heads     uint32 `json:"heads"`
				ImageSize uint32 `json:"image_size"`
				Layers    uint32 `json:"layers"`
				PatchSize uint32 `json:"patch_size"`
				Width     uint32 `json:"width"`
			} `json:"clip-l-14-224"`
			Sam struct {
				GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
				Heads                  uint32  `json:"heads"`
				Layers                 uint32  `json:"layers"`
				Width                  uint32  `json:"width"`
			} `json:"sam_vit_b"`
		}
	} `json:"vision_config"`
}

func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
	kv := m.ModelParameters.KV(t)
	kv["general.architecture"] = "deepseekocr"
	kv["block_count"] = m.LanguageConfig.HiddenLayers
	kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
	kv["embedding_length"] = m.LanguageConfig.HiddenSize
	kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
	kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
	kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
	kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
	kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
	kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace

	kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
	kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
	kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
	kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
	kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize

	kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
	kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
	kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
	kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
	return kv
}

func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
	merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
	for i := range m.LanguageConfig.HiddenLayers {
		merges[i*3+0] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
		}
		merges[i*3+1] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
		}
		merges[i*3+2] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
		}
	}

	out, s = mergeTensors(s, merges...)
	for _, t := range s {
		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}
	return out
}

func (m *deepseekocr) Replacements() []string {
	return []string{
		"model.embed_tokens", "token_embd",
		"model.layers", "blk",
		"input_layernorm", "attn_norm",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"post_attention_layernorm", "ffn_norm",
		"mlp.gate_proj", "ffn_gate",
		"mlp.up_proj", "ffn_up",
		"mlp.down_proj", "ffn_down",
		"mlp.gate", "ffn_gate_inp",
		"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
		"mlp.shared_experts.up_proj", "ffn_up_shexp",
		"mlp.shared_experts.down_proj", "ffn_down_shexp",
		"model.norm", "output_norm",
		"lm_head", "output",

		"model.vision_model", "v",
		"embeddings.patch_embedding", "patch_embd",
		"embeddings.class_embedding", "class_embd",
		"embeddings.position_embedding", "position_embd",
		"transformer.layers", "blk",

		"model.projector", "mm",
		"model.image_newline", "mm.image_newline",
		//nolint:misspell // this misspelling is upstream. fixing it breaks the model
		"model.view_seperator", "mm.view_seperator",

		"model.sam_model.patch_embed.proj", "s.patch_embd",
		"model.sam_model.pos_embed", "s.position_embd",
		"model.sam_model.blocks", "s.blk",
		"model.sam_model.neck", "s.neck",
		"model.sam_model.net_", "s.net_",
	}
}
@@ -2,6 +2,7 @@ package convert

 import (
 	"cmp"
+	"slices"

 	"github.com/ollama/ollama/fs/ggml"
 )
@@ -26,16 +27,26 @@ type gemma3Model struct {
 		NumChannels uint32 `json:"num_channels"` // num_channels 3
 		PatchSize   uint32 `json:"patch_size"`   // patch_size 14
 	} `json:"vision_config"`
 	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
 	NumAttentionHeads     uint32  `json:"num_attention_heads"`
 	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
 	RMSNormEPS            float32 `json:"rms_norm_eps"`
 	HeadDim               uint32  `json:"head_dim"`
 	FinalLogitSoftcap     float32 `json:"final_logit_softcapping"`
 	RopeLocalTheta        float32 `json:"rope_local_base_freq"`
-	RopeGlobalTheta       float32 `json:"rope_global_base_freq"`
+	RopeTheta             float32 `json:"rope_theta"`
 	SlidingWindow         uint32  `json:"sliding_window"`
-	MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
+	SlidingWindowPattern     *uint32  `json:"sliding_window_pattern"`
+	LayerTypes               []string `json:"layer_types"`
+	MultiModalTokensPerImage uint32   `json:"mm_tokens_per_image"`
+	RopeScaling *struct {
+		Type                          string  `json:"rope_type"`
+		Factor                        float32 `json:"factor"`
+		OriginalMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
+		ExtrapolationFactor           float32 `json:"extrapolation_factor"`
+		BetaFast                      float32 `json:"beta_fast"`
+		BetaSlow                      float32 `json:"beta_slow"`
+	} `json:"rope_scaling"`
 }

 const (
@@ -81,9 +92,38 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
 		kv["gemma3.attention.key_length"] = p.HeadDim
 		kv["gemma3.attention.value_length"] = p.HeadDim
 		kv["gemma3.attention.sliding_window"] = p.SlidingWindow
-		kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
+
+		// The sliding window pattern is either provided as the sliding_window_pattern
+		// key (an int) or as the layer_types key (a list of strings).
+		if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
+			kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
+				for i := range numBlocks {
+					var isLocal bool
+					if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
+						isLocal = p.LayerTypes[i] == "sliding_attention"
+					} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
+						isLocal = (i+1)%*p.SlidingWindowPattern != 0
+					}
+					if !yield(isLocal) {
+						break
+					}
+				}
+			})
+		}
+
+		if p.FinalLogitSoftcap > 0 {
+			kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
+		}
 		kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
-		kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
+		kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
+		if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
+			kv["gemma3.rope.scaling.type"] = "yarn"
+			kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
+			kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
+			kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
+			kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
+			kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
+		}
+
 		kv["gemma3.embedding_length"] = p.HiddenSize
 		kv["gemma3.feed_forward_length"] = p.IntermediateSize
 	default:
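An aside on the sliding-window hunk above: with the integer sliding_window_pattern form, layer i (0-based) is a local sliding-attention layer unless i+1 is a multiple of the pattern. A minimal, self-contained Go sketch of that rule with hypothetical values (pattern 6, 12 blocks), given only as an illustration and not taken from the repository:

package main

import "fmt"

func main() {
	pattern := uint32(6)    // hypothetical sliding_window_pattern value
	numBlocks := uint32(12) // hypothetical block count

	var isLocal []bool
	for i := range numBlocks {
		// Same rule as the gemma3 hunk: local unless (i+1) is a multiple of the pattern.
		isLocal = append(isLocal, (i+1)%pattern != 0)
	}
	fmt.Println(isLocal) // every 6th layer uses global attention, the rest are local
}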
@@ -29,6 +29,17 @@ type mistral3Model struct {
 		SlidingWindow *uint32 `json:"sliding_window"`
 		HiddenAct     string  `json:"hidden_act"`
 		VocabSize     uint32  `json:"vocab_size"`
+		RopeParameters struct {
+			BetaFast                  float32  `json:"beta_fast"`
+			BetaSlow                  float32  `json:"beta_slow"`
+			Factor                    float32  `json:"factor"`
+			Llama4ScalingBeta         *float32 `json:"llama_4_scaling_beta"`
+			OrigMaxPositionEmbeddings uint32   `json:"original_max_position_embeddings"`
+			RopeType                  string   `json:"rope_type"`
+			RopeTheta                 float32  `json:"rope_theta"`
+			Mscale                    *float32 `json:"mscale"`
+			MscaleAllDim              *float32 `json:"mscale_all_dim"`
+		} `json:"rope_parameters"`
 	} `json:"text_config"`
 	VisionModel struct {
 		NumAttentionHeads uint32 `json:"num_attention_heads"`
@@ -41,6 +52,9 @@ type mistral3Model struct {
 		HeadDim   uint32  `json:"head_dim"`
 		HiddenAct string  `json:"hidden_act"`
 		RopeTheta float32 `json:"rope_theta"`
+		RopeParameters struct {
+			RopeTheta float32 `json:"rope_theta"`
+		} `json:"rope_parameters"`
 	} `json:"vision_config"`
 	MultiModalProjectorBias bool   `json:"multimodal_projector_bias"`
 	ProjectorHiddenAct      string `json:"projector_hidden_act"`
@@ -61,8 +75,25 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
 	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
 	kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
 	kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
-	kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
+	kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
-	kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
+	kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
+	kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
+	kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
+	kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
+	kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
+
+	if p.TextModel.RopeParameters.Mscale != nil {
+		kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
+	}
+	if p.TextModel.RopeParameters.MscaleAllDim != nil {
+		kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
+	}
+	if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
+		kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
+	}
+	if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
+		kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
+	}

 	// Vision configuration
 	kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
@@ -74,7 +105,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
 	kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
 	kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
 	// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
-	kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
+	kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)

 	// Multimodal configuration
 	kv["mistral3.image_token_index"] = p.ImageTokenIndex
181 convert/convert_mistral_causal.go Normal file
@@ -0,0 +1,181 @@
package convert

import (
	"cmp"
	"fmt"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/fs/ggml"
)

type mistral3CausalModel struct {
	ModelParameters

	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	HiddenSize            uint32  `json:"hidden_size"`
	IntermediateSize      uint32  `json:"intermediate_size"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	RopeTheta             float32 `json:"rope_theta"`
	RMSNormEPS            float32 `json:"rms_norm_eps"`
	HeadDim               uint32  `json:"head_dim"`
	SlidingWindow         *uint32 `json:"sliding_window"`
	HiddenAct             string  `json:"hidden_act"`
	VocabSize             uint32  `json:"vocab_size"`
	RopeParameters        struct {
		BetaFast                  float32  `json:"beta_fast"`
		BetaSlow                  float32  `json:"beta_slow"`
		Factor                    float32  `json:"factor"`
		Llama4ScalingBeta         *float32 `json:"llama_4_scaling_beta"`
		OrigMaxPositionEmbeddings uint32   `json:"original_max_position_embeddings"`
		RopeType                  string   `json:"rope_type"`
		RopeTheta                 float32  `json:"rope_theta"`
		Mscale                    *float32 `json:"mscale"`
		MscaleAllDim              *float32 `json:"mscale_all_dim"`
	} `json:"rope_parameters"`
}

func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "mistral3"
	kv["mistral3.vocab_size"] = p.VocabSize

	// Text configuration
	kv["mistral3.block_count"] = p.NumHiddenLayers
	kv["mistral3.context_length"] = p.MaxPositionEmbeddings
	kv["mistral3.embedding_length"] = p.HiddenSize
	kv["mistral3.feed_forward_length"] = p.IntermediateSize
	kv["mistral3.attention.head_count"] = p.NumAttentionHeads
	kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
	kv["mistral3.attention.key_length"] = p.HeadDim
	kv["mistral3.attention.value_length"] = p.HeadDim
	kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
	kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
	kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
	kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
	kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
	kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow

	if p.RopeParameters.Mscale != nil {
		kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
	}

	if p.RopeParameters.MscaleAllDim != nil {
		kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
	}

	if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
		kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
	}

	if p.RopeParameters.Llama4ScalingBeta != nil {
		kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
	}

	return kv
}

func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor

	for _, t := range ts {
		if !strings.HasPrefix(t.Name(), "v.") {
			if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
				strings.HasSuffix(t.Name(), ".attn_k.weight") {
				t.SetRepacker(p.repack)
			}
		}

		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}

	return out
}

func (p *mistral3CausalModel) Replacements() []string {
	return []string{
		"model.norm", "output_norm",
		"model.", "",
		"layers", "blk",
		"transformer.layers", "blk",
		"vision_tower", "v",
		"ln_pre", "encoder_norm",
		"input_layernorm", "attn_norm",
		"post_attention_layernorm", "ffn_norm",
		"embed_tokens", "token_embd",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"mlp.down_proj", "ffn_down",
		"mlp.gate_proj", "ffn_gate",
		"mlp.up_proj", "ffn_up",
		"attention.q_proj", "attn_q",
		"attention.k_proj", "attn_k",
		"attention.v_proj", "attn_v",
		"attention.o_proj", "attn_output",
		"attention_norm", "attn_norm",
		"feed_forward.gate_proj", "ffn_gate",
		"feed_forward.down_proj", "ffn_down",
		"feed_forward.up_proj", "ffn_up",
		"multi_modal_projector", "mm",
		"ffn_norm", "ffn_norm",
		"lm_head", "output",
	}
}

func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
	var dims []int
	for _, dim := range shape {
		dims = append(dims, int(dim))
	}

	var heads uint32
	if strings.HasSuffix(name, ".attn_q.weight") {
		heads = p.NumAttentionHeads
	} else if strings.HasSuffix(name, ".attn_k.weight") {
		heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
	} else {
		return nil, fmt.Errorf("unknown tensor for repack: %s", name)
	}

	n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
	if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
		return nil, err
	}

	if err := n.T(0, 2, 1, 3); err != nil {
		return nil, err
	}

	if err := n.Reshape(dims...); err != nil {
		return nil, err
	}

	if err := n.Transpose(); err != nil {
		return nil, err
	}

	ts, err := native.SelectF32(n, 1)
	if err != nil {
		return nil, err
	}

	var f32s []float32
	for _, t := range ts {
		f32s = append(f32s, t...)
	}

	return f32s, nil
}
convert/convert_nomicbert.go (new file, 213 lines)
@@ -0,0 +1,213 @@
package convert

import (
    "cmp"
    "encoding/json"
    "io/fs"
    "path/filepath"
    "slices"
    "strings"

    "github.com/ollama/ollama/fs/ggml"
)

type nomicbertModel struct {
    ModelParameters
    NLayers               uint32  `json:"n_layers"`
    NumHiddenLayers       uint32  `json:"num_hidden_layers"`
    MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
    HiddenSize            uint32  `json:"hidden_size"`
    IntermediateSize      uint32  `json:"intermediate_size"`
    NumAttentionHeads     uint32  `json:"num_attention_heads"`
    NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
    LayerNormEPS          float32 `json:"layer_norm_eps"`
    LayerNormEpsilon      float32 `json:"layer_norm_epsilon"`
    RopeFreqBase          float32 `json:"rope_theta"`
    normalizeEmbeddings   bool
    PoolingType           uint32

    // MoE parameters (only present in v2 models)
    NumExperts      uint32 `json:"num_local_experts"`
    NumExpertsUsed  uint32 `json:"num_experts_per_tok"`
    MoEEveryNLayers uint32 `json:"moe_every_n_layers"`
}

var (
    _ ModelConverter = (*nomicbertModel)(nil)
    _ moreParser     = (*nomicbertModel)(nil)
)

func (p *nomicbertModel) parseMore(fsys fs.FS) error {
    bts, err := fs.ReadFile(fsys, "modules.json")
    if err != nil {
        return err
    }

    var modules []struct {
        Type string `json:"type"`
        Path string `json:"path"`
    }

    if err := json.Unmarshal(bts, &modules); err != nil {
        return err
    }

    var pooling string
    for _, m := range modules {
        switch m.Type {
        case "sentence_transformers.models.Pooling":
            pooling = m.Path
        case "sentence_transformers.models.Normalize":
            p.normalizeEmbeddings = true
        }
    }

    if pooling != "" {
        bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
        if err != nil {
            return err
        }

        var pc struct {
            PoolingModeCLSToken   bool `json:"pooling_mode_cls_token"`
            PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
        }

        if err := json.Unmarshal(bts, &pc); err != nil {
            return err
        }

        if pc.PoolingModeMeanTokens {
            p.PoolingType = 1
        } else if pc.PoolingModeCLSToken {
            p.PoolingType = 2
        }
    }

    return nil
}

func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
    kv := p.ModelParameters.KV(t)

    // Determine architecture based on MoE parameters (following qwen3 pattern)
    arch := "nomic-bert"
    if p.MoEEveryNLayers > 0 {
        arch += "-moe"
    }

    kv["general.architecture"] = arch
    kv["attention.causal"] = false
    kv["pooling_type"] = p.PoolingType
    kv["normalize_embeddings"] = p.normalizeEmbeddings

    kv["block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers)

    if contextLength := p.MaxPositionEmbeddings; contextLength > 0 {
        kv["context_length"] = contextLength
    }

    if embeddingLength := p.HiddenSize; embeddingLength > 0 {
        kv["embedding_length"] = p.HiddenSize
    }

    if feedForwardLength := p.IntermediateSize; feedForwardLength > 0 {
        kv["feed_forward_length"] = p.IntermediateSize
    }

    if headCount := p.NumAttentionHeads; headCount > 0 {
        kv["attention.head_count"] = p.NumAttentionHeads
    }

    if kvHeadCount := p.NumKeyValueHeads; kvHeadCount > 0 {
        kv["attention.head_count_kv"] = p.NumKeyValueHeads
    }

    if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon); layerNormEpsilon > 0 {
        kv["attention.layer_norm_epsilon"] = layerNormEpsilon
    }

    if p.RopeFreqBase > 0 {
        kv["rope.freq_base"] = p.RopeFreqBase
    }

    // MoE specific parameters (only if MoE is enabled)
    if p.NumExperts > 0 {
        kv["expert_count"] = p.NumExperts
    }

    if p.NumExpertsUsed > 0 {
        kv["expert_used_count"] = p.NumExpertsUsed
    }

    if p.MoEEveryNLayers > 0 {
        kv["moe_every_n_layers"] = p.MoEEveryNLayers
    }

    kv["tokenizer.ggml.model"] = "bert"
    kv["tokenizer.ggml.token_type_count"] = uint32(2)

    // convert to phantom space tokens
    for i, e := range t.Tokens {
        switch {
        case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
            // noop - keep special tokens as-is
        case strings.HasPrefix(e, "##"):
            t.Tokens[i] = e[2:]
        default:
            t.Tokens[i] = "\u2581" + e
        }
    }

    kv["tokenizer.ggml.tokens"] = t.Tokens

    return kv
}

func (p *nomicbertModel) Tensors(ts []Tensor) []*ggml.Tensor {
    out := make([]*ggml.Tensor, 0, len(ts))
    for _, t := range ts {
        if slices.Contains([]string{
            "embeddings.position_ids",
            "pooler.dense.weight",
            "pooler.dense.bias",
        }, t.Name()) {
            continue
        }

        out = append(out, &ggml.Tensor{
            Name:     t.Name(),
            Kind:     t.Kind(),
            Shape:    t.Shape(),
            WriterTo: t,
        })
    }

    return out
}

func (nomicbertModel) Replacements() []string {
    return []string{
        "encoder.layer", "blk",
        "encoder.layers", "blk",
        "embeddings.word_embeddings", "token_embd",
        "embeddings.token_type_embeddings", "token_types",
        "embeddings.LayerNorm", "token_embd_norm",

        "attention.self.qkv", "attn_qkv",

        "attention.output.dense", "attn_output",
        "attention.output.LayerNorm", "attn_output_norm",

        "mlp.up", "ffn_up",
        "mlp.down", "ffn_down",

        "mlp.router", "ffn_gate_inp",
        "mlp.experts.up", "ffn_up_exps",
        "mlp.experts.down", "ffn_down_exps",

        "intermediate.dense", "ffn_up",
        "output.dense", "ffn_down",
        "output.LayerNorm", "layer_output_norm",
    }
}
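The `KV` method above rewrites the WordPiece vocabulary into the "phantom space" form the GGUF tokenizer expects: bracketed special tokens are kept as-is, `##` continuation markers are stripped, and every other token gains a leading `\u2581`. A small, self-contained sketch of that switch on a few illustrative tokens (the sample vocabulary is made up, not taken from a real model):

```go
package main

import (
    "fmt"
    "strings"
)

func main() {
    // illustrative WordPiece tokens, not a real vocabulary
    tokens := []string{"[CLS]", "hello", "##ing", "[SEP]"}
    for i, e := range tokens {
        switch {
        case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
            // keep special tokens as-is
        case strings.HasPrefix(e, "##"):
            tokens[i] = e[2:] // continuation piece: drop the "##" marker
        default:
            tokens[i] = "\u2581" + e // word-initial piece: prepend the phantom space
        }
    }
    fmt.Println(tokens) // [[CLS] ▁hello ing [SEP]]
}
```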
convert/convert_olmo.go (new file, 117 lines)
@@ -0,0 +1,117 @@
package convert

import (
    "cmp"

    "github.com/ollama/ollama/fs/ggml"
)

type ropeScaling struct {
    Factor                    float32 `json:"factor"`
    OriginalMaxPositionEmbeds uint32  `json:"original_max_position_embeddings"`
    AttentionFactor           float32 `json:"attention_factor"`
    BetaFast                  float32 `json:"beta_fast"`
    BetaSlow                  float32 `json:"beta_slow"`
    RopeType                  string  `json:"rope_type"`
    ExtrapolationFactor       float32 `json:"extrapolation_factor"`
}

type olmoModel struct {
    ModelParameters

    HiddenSize            uint32       `json:"hidden_size"`
    NumHiddenLayers       uint32       `json:"num_hidden_layers"`
    IntermediateSize      uint32       `json:"intermediate_size"`
    NumAttentionHeads     uint32       `json:"num_attention_heads"`
    NumKeyValueHeads      uint32       `json:"num_key_value_heads"`
    MaxPositionEmbeddings uint32       `json:"max_position_embeddings"`
    RMSNormEPS            float32      `json:"rms_norm_eps"`
    RopeTheta             float32      `json:"rope_theta"`
    RopeScaling           *ropeScaling `json:"rope_scaling"`
    SlidingWindow         uint32       `json:"sliding_window"`
    LayerTypes            []string     `json:"layer_types"`
}

var _ ModelConverter = (*olmoModel)(nil)

func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
    kv := p.ModelParameters.KV(t)
    kv["general.architecture"] = "olmo3"
    kv["olmo3.block_count"] = p.NumHiddenLayers
    kv["olmo3.context_length"] = p.MaxPositionEmbeddings
    kv["olmo3.embedding_length"] = p.HiddenSize
    kv["olmo3.feed_forward_length"] = p.IntermediateSize
    kv["olmo3.attention.head_count"] = p.NumAttentionHeads
    kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)

    if p.RopeTheta > 0 {
        kv["olmo3.rope.freq_base"] = p.RopeTheta
    }

    if p.RopeScaling != nil {
        if p.RopeScaling.Factor > 0 {
            kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor
        }
        if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
            kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
        }
        if p.RopeScaling.AttentionFactor > 0 {
            kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
        }
        if p.RopeScaling.RopeType != "" {
            kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType
        }
    }

    if p.RMSNormEPS > 0 {
        kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
    }

    if p.SlidingWindow > 0 {
        kv["olmo3.attention.sliding_window"] = p.SlidingWindow
    }

    if len(p.LayerTypes) > 0 {
        slidingPattern := make([]bool, len(p.LayerTypes))
        for i, layerType := range p.LayerTypes {
            slidingPattern[i] = (layerType == "sliding_attention")
        }
        kv["olmo3.attention.sliding_window_pattern"] = slidingPattern
    }

    return kv
}

func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
    out := make([]*ggml.Tensor, 0, len(ts))
    for _, t := range ts {
        out = append(out, &ggml.Tensor{
            Name:     t.Name(),
            Kind:     t.Kind(),
            Shape:    t.Shape(),
            WriterTo: t,
        })
    }

    return out
}

func (p *olmoModel) Replacements() []string {
    return []string{
        "lm_head", "output",
        "model.embed_tokens", "token_embd",
        "model.layers", "blk",
        "model.norm", "output_norm",
        "self_attn.q_proj", "attn_q",
        "self_attn.k_proj", "attn_k",
        "self_attn.v_proj", "attn_v",
        "self_attn.o_proj", "attn_output",
        "self_attn.q_norm", "attn_q_norm",
        "self_attn.k_norm", "attn_k_norm",
        "post_attention_layernorm", "post_attention_norm",
        "post_feedforward_layernorm", "post_ffw_norm",
        "mlp.gate_proj", "ffn_gate",
        "mlp.down_proj", "ffn_down",
        "mlp.up_proj", "ffn_up",
    }
}
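The olmo3 converter above derives `olmo3.attention.sliding_window_pattern` from the per-layer `layer_types` list: a layer is marked `true` exactly when its type is `sliding_attention`. A minimal sketch of that mapping with illustrative layer types (the sample values are examples, not from a real config.json):

```go
package main

import "fmt"

func main() {
    // illustrative layer_types, not taken from a real olmo3 config
    layerTypes := []string{"sliding_attention", "sliding_attention", "full_attention"}

    slidingPattern := make([]bool, len(layerTypes))
    for i, layerType := range layerTypes {
        slidingPattern[i] = layerType == "sliding_attention"
    }
    fmt.Println(slidingPattern) // [true true false]
}
```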
@@ -44,7 +44,10 @@ func (t tensorBase) Kind() uint32 {
 		t.name == "v.positional_embedding_vlm" ||
 		t.name == "v.tile_position_embd.weight" ||
 		t.name == "v.pre_tile_position_embd.weight" ||
-		t.name == "v.post_tile_position_embd.weight" {
+		t.name == "v.post_tile_position_embd.weight" ||
+		t.name == "s.position_embd" ||
+		strings.HasSuffix(t.name, "rel_pos_h") ||
+		strings.HasSuffix(t.name, "rel_pos_w") {
 		// these tensors are always F32
 		return tensorKindFP32
 	}
@@ -96,7 +96,10 @@ type safetensor struct {
 
 func (st safetensor) Kind() uint32 {
 	kind := st.tensorBase.Kind()
-	if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
+	if st.dtype == "BF16" &&
+		!strings.HasPrefix(st.name, "v.") &&
+		!strings.HasPrefix(st.name, "s.") &&
+		kind != tensorKindFP32 {
 		kind = tensorKindBF16
 	}
 
@@ -49,7 +49,8 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 		tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
 
 		// temporary fix to handle gemma3 broken configs
-		if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
+		// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm
+		if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
 			tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
 		}
 
@@ -2,6 +2,7 @@ package discover
 
 import (
 	"bufio"
+	"errors"
 	"fmt"
 	"io"
 	"log/slog"
@@ -10,12 +11,21 @@ import (
 	"reflect"
 	"regexp"
 	"sort"
+	"strconv"
 	"strings"
 
 	"github.com/ollama/ollama/format"
 )
 
 func GetCPUMem() (memInfo, error) {
+	mem, err := getCPUMem()
+	if err != nil {
+		return memInfo{}, err
+	}
+	return getCPUMemByCgroups(mem), nil
+}
+
+func getCPUMem() (memInfo, error) {
 	var mem memInfo
 	var total, available, free, buffers, cached, freeSwap uint64
 	f, err := os.Open("/proc/meminfo")
@@ -56,6 +66,32 @@ func GetCPUMem() (memInfo, error) {
 	return mem, nil
 }
 
+func getCPUMemByCgroups(mem memInfo) memInfo {
+	total, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.max")
+	if err == nil {
+		mem.TotalMemory = total
+	}
+	used, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.current")
+	if err == nil {
+		mem.FreeMemory = mem.TotalMemory - used
+	}
+	return mem
+}
+
+func getUint64ValueFromFile(path string) (uint64, error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return 0, err
+	}
+	defer f.Close()
+	s := bufio.NewScanner(f)
+	for s.Scan() {
+		line := s.Text()
+		return strconv.ParseUint(line, 10, 64)
+	}
+	return 0, errors.New("empty file content")
+}
+
 const CpuInfoFilename = "/proc/cpuinfo"
 
 type linuxCpuInfo struct {
@@ -74,7 +110,41 @@ func GetCPUDetails() []CPU {
 		return nil
 	}
 	defer file.Close()
-	return linuxCPUDetails(file)
+	cpus := linuxCPUDetails(file)
+	return overwriteThreadCountByLinuxCgroups(cpus)
+}
+
+func overwriteThreadCountByLinuxCgroups(cpus []CPU) []CPU {
+	file, err := os.Open("/sys/fs/cgroup/cpu.max")
+	if err != nil {
+		return cpus
+	}
+	defer file.Close()
+
+	scanner := bufio.NewScanner(file)
+	for scanner.Scan() {
+		line := scanner.Text()
+		if sl := strings.Split(line, " "); len(sl) == 2 {
+			allowdUs, err := strconv.ParseInt(sl[0], 10, 64)
+			if err != nil {
+				slog.Warn("failed to parse CPU allowed micro secs", "error", err)
+				return cpus
+			}
+			unitUs, err := strconv.ParseInt(sl[1], 10, 64)
+			if err != nil {
+				slog.Warn("failed to parse CPU unit micro secs", "error", err)
+				return cpus
+			}
+
+			threads := int(max(allowdUs/unitUs, 1))
+
+			cpu := cpus[0]
+			cpu.CoreCount = threads
+			cpu.ThreadCount = threads
+			return []CPU{cpu}
+		}
+	}
+	return cpus
 }
 
 func linuxCPUDetails(file io.Reader) []CPU {
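`overwriteThreadCountByLinuxCgroups` above caps the reported thread count using the cgroup v2 `cpu.max` quota: the allowed microseconds divided by the period, with a floor of one thread. A small sketch of that arithmetic on an illustrative quota string (the value is an example, not read from a real system):

```go
package main

import (
    "fmt"
    "strconv"
    "strings"
)

func main() {
    // illustrative /sys/fs/cgroup/cpu.max content: 200ms of CPU time per 100ms period
    line := "200000 100000"

    sl := strings.Split(line, " ")
    allowedUs, _ := strconv.ParseInt(sl[0], 10, 64) // quota
    periodUs, _ := strconv.ParseInt(sl[1], 10, 64)  // period

    threads := int(max(allowedUs/periodUs, 1))
    fmt.Println(threads) // 2
}
```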
@@ -65,6 +65,11 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 	}
 
 	slog.Info("discovering available GPUs...")
+	detectIncompatibleLibraries()
+
+	// Warn if any user-overrides are set which could lead to incorrect GPU discovery
+	overrideWarnings()
+
 	requested := envconfig.LLMLibrary()
 	jetpack := cudaJetpack()
 
@@ -90,10 +95,13 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 	var dirs []string
 	if dir != "" {
 		if requested != "" && filepath.Base(dir) != requested {
-			slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
+			slog.Debug("skipping available library at user's request", "requested", requested, "libDir", dir)
 			continue
 		} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
 			continue
+		} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
+			slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
+			continue
 		} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
 			slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
 			continue
@@ -113,7 +121,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 	// In the second pass, we more deeply initialize the GPUs to weed out devices that
 	// aren't supported by a given library. We run this phase in parallel to speed up discovery.
 	// Only devices that need verification are included in this pass
-	slog.Debug("evluating which if any devices to filter out", "initial_count", len(devices))
+	slog.Debug("evaluating which, if any, devices to filter out", "initial_count", len(devices))
 	ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second)
 	defer cancel()
 	var wg sync.WaitGroup
@@ -121,15 +129,25 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 	supportedMu := sync.Mutex{}
 	supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
 	for i := range devices {
+		libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
 		if !devices[i].NeedsInitValidation() {
+			// No need to validate, add to the supported map
+			supportedMu.Lock()
+			if _, ok := supported[devices[i].Library]; !ok {
+				supported[devices[i].Library] = make(map[string]map[string]int)
+			}
+			if _, ok := supported[devices[i].Library][libDir]; !ok {
+				supported[devices[i].Library][libDir] = make(map[string]int)
+			}
+			supported[devices[i].Library][libDir][devices[i].ID] = i
+			supportedMu.Unlock()
 			continue
 		}
-		libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
-		slog.Debug("verifying device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
+		slog.Debug("verifying if device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
 		wg.Add(1)
 		go func(i int) {
 			defer wg.Done()
-			extraEnvs := ml.GetVisibleDevicesEnv(devices[i : i+1])
+			extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true)
 			devices[i].AddInitValidation(extraEnvs)
 			if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
 				slog.Debug("filtering device which didn't fully initialize",
@@ -315,7 +333,8 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
 	defer cancel()
 
 	// Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct
-	devFilter := ml.GetVisibleDevicesEnv(devices)
+	// We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment
+	devFilter := ml.GetVisibleDevicesEnv(devices, false)
 
 	for dir := range libDirs {
 		updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter)
@@ -449,3 +468,37 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map
 
 	return devices
 }
+
+func overrideWarnings() {
+	anyFound := false
+	m := envconfig.AsMap()
+	for _, k := range []string{
+		"CUDA_VISIBLE_DEVICES",
+		"HIP_VISIBLE_DEVICES",
+		"ROCR_VISIBLE_DEVICES",
+		"GGML_VK_VISIBLE_DEVICES",
+		"GPU_DEVICE_ORDINAL",
+		"HSA_OVERRIDE_GFX_VERSION",
+	} {
+		if e, found := m[k]; found && e.Value != "" {
+			anyFound = true
+			slog.Warn("user overrode visible devices", k, e.Value)
+		}
+	}
+	if anyFound {
+		slog.Warn("if GPUs are not correctly discovered, unset and try again")
+	}
+}
+
+func detectIncompatibleLibraries() {
+	if runtime.GOOS != "windows" {
+		return
+	}
+	basePath, err := exec.LookPath("ggml-base.dll")
+	if err != nil || basePath == "" {
+		return
+	}
+	if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
+		slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
+	}
+}

docs/api.md (10 changes)
@@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
 Advanced parameters (optional):
 
 - `format`: the format to return a response in. Format can be `json` or a JSON schema
-- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
 - `system`: system message to (overrides what is defined in the `Modelfile`)
 - `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
@@ -507,7 +507,7 @@ The `message` object has the following fields:
 Advanced parameters (optional):
 
 - `format`: the format to return a response in. Format can be `json` or a JSON schema.
-- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
 
@@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
 - `template`: (optional) the prompt template for the model
 - `license`: (optional) a string or list of strings containing the license or licenses for the model
 - `system`: (optional) a string containing the system prompt for the model
-- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters)
+- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters)
 - `messages`: (optional) a list of message objects used to create a conversation
 - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
 - `quantize` (optional): quantize a non-quantized (e.g. float16) model
@@ -1698,7 +1698,7 @@ Generate embeddings from a model
 Advanced parameters:
 
 - `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
-- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
 - `dimensions`: number of dimensions for the embedding
 
@@ -1817,7 +1817,7 @@ Generate embeddings from a model
 
 Advanced parameters:
 
-- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
 
 ### Examples
File diff suppressed because one or more lines are too long

@@ -15,7 +15,7 @@ Also known as "single-shot" tool calling.
 ```shell
 curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
   "model": "qwen3",
-  "messages": [{"role": "user", "content": "What's the temperature in New York?"}],
+  "messages": [{"role": "user", "content": "What is the temperature in New York?"}],
   "stream": false,
   "tools": [
     {
@@ -41,7 +41,7 @@ Also known as "single-shot" tool calling.
 curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
   "model": "qwen3",
   "messages": [
-    {"role": "user", "content": "What's the temperature in New York?"},
+    {"role": "user", "content": "What is the temperature in New York?"},
     {
       "role": "assistant",
       "tool_calls": [
@@ -90,7 +90,7 @@ Also known as "single-shot" tool calling.
     }
     return temperatures.get(city, "Unknown")
 
-messages = [{"role": "user", "content": "What's the temperature in New York?"}]
+messages = [{"role": "user", "content": "What is the temperature in New York?"}]
 
 # pass functions directly as tools in the tools list or as a JSON schema
 response = chat(model="qwen3", messages=messages, tools=[get_temperature], think=True)
@@ -146,7 +146,7 @@ Also known as "single-shot" tool calling.
   },
 ]
 
-const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
+const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
 
 const response = await ollama.chat({
   model: 'qwen3',
@@ -609,7 +609,7 @@ def get_temperature(city: str) -> str:
     return temperatures.get(city, 'Unknown')
 
 
-messages = [{'role': 'user', 'content': "What's the temperature in New York?"}]
+messages = [{'role': 'user', 'content': "What is the temperature in New York?"}]
 
 while True:
     stream = chat(
@@ -684,7 +684,7 @@ const getTemperatureTool = {
 }
 
 async function agentLoop() {
-  const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
+  const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
 
   while (true) {
     const stream = await ollama.chat({
@@ -9,15 +9,9 @@ sidebarTitle: Cloud
 
 Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
 
-Ollama currently supports the following cloud models, with more coming soon:
+### Supported models
 
-- `deepseek-v3.1:671b-cloud`
-- `gpt-oss:20b-cloud`
-- `gpt-oss:120b-cloud`
-- `kimi-k2:1t-cloud`
-- `qwen3-coder:480b-cloud`
-- `glm-4.6:cloud`
-- `minimax-m2:cloud`
+For a list of supported models, see Ollama's [model library](https://ollama.com/search?c=cloud).
 
 ### Running Cloud models
 
@@ -49,6 +49,8 @@ Install prerequisites:
 - [Ninja](https://github.com/ninja-build/ninja/releases)
 - (Optional) NVIDIA GPU support
   - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
+- (Optional) VULKAN GPU support
+  - [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
 
 Then, configure and build the project:
 
@@ -57,6 +59,17 @@ cmake -B build
 cmake --build build --config Release
 ```
 
+> Building for Vulkan requires VULKAN_SDK environment variable:
+>
+> PowerShell
+> ```powershell
+> $env:VULKAN_SDK="C:\VulkanSDK\<version>"
+> ```
+> CMD
+> ```cmd
+> set VULKAN_SDK=C:\VulkanSDK\<version>
+> ```
+
 > [!IMPORTANT]
 > Building for ROCm requires additional flags:
 > ```
@@ -65,6 +78,7 @@ cmake --build build --config Release
 > ```
 
 
+
 Lastly, run Ollama:
 
 ```shell
@@ -84,7 +98,9 @@ Install prerequisites:
 - [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html)
 - (Optional) NVIDIA GPU support
   - [CUDA SDK](https://developer.nvidia.com/cuda-downloads)
+- (Optional) VULKAN GPU support
+  - [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
+  - Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
 > [!IMPORTANT]
 > Ensure prerequisites are in `PATH` before running CMake.
 
@@ -57,8 +57,13 @@ ollama ps
 ```
 
 <Info>
-  **Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
-  100% GPU 4 minutes from now ```
+  **Output**:
+
+  ```
+  NAME         ID            SIZE   PROCESSOR  UNTIL
+  llama3:70b   bcfb190ca3a7  42 GB  100% GPU   4 minutes from now
+  ```
 </Info>
 
 The `Processor` column will show which memory the model was loaded in to:
@@ -9,26 +9,26 @@ Install [VS Code](https://code.visualstudio.com/download).
 ## Usage with Ollama
 
 1. Open Copilot side bar found in top right window
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
+   <div style={{ display: "flex", justifyContent: "center" }}>
     <img
       src="/images/vscode-sidebar.png"
      alt="VS Code chat Sidebar"
      width="75%"
    />
   </div>
-2. Select the model drowpdown > **Manage models**
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
+2. Select the model dropdown > **Manage models**
+   <div style={{ display: "flex", justifyContent: "center" }}>
    <img
      src="/images/vscode-models.png"
      alt="VS Code model picker"
      width="75%"
    />
   </div>
 3. Enter **Ollama** under **Provider Dropdown** and select desired models (e.g `qwen3, qwen3-coder:480b-cloud`)
-   <div style={{ display: 'flex', justifyContent: 'center' }}>
+   <div style={{ display: "flex", justifyContent: "center" }}>
    <img
      src="/images/vscode-model-options.png"
      alt="VS Code model options dropdown"
      width="75%"
    />
   </div>
@@ -41,6 +41,7 @@ INSTRUCTION arguments
 | [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
 | [`LICENSE`](#license) | Specifies the legal license. |
 | [`MESSAGE`](#message) | Specify message history. |
+| [`REQUIRES`](#requires) | Specify the minimum version of Ollama required by the model. |
 
 ## Examples
 
@@ -149,9 +150,6 @@ PARAMETER <parameter> <parametervalue>
 
 | Parameter | Description | Value Type | Example Usage |
 | -------------- | ----------- | ---------- | -------------------- |
-| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
-| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
-| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
 | num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
 | repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
 | repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
@@ -251,6 +249,16 @@ MESSAGE user Is Ontario in Canada?
 MESSAGE assistant yes
 ```
 
+### REQUIRES
+
+The `REQUIRES` instruction allows you to specify the minimum version of Ollama required by the model.
+
+```
+REQUIRES <version>
+```
+
+The version should be a valid Ollama version (e.g. 0.14.0).
+
 ## Notes
 
 - the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.
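As a worked illustration of the new `REQUIRES` instruction documented in the hunk above, a hypothetical Modelfile declaring a minimum Ollama version might look like the following; the base model name and version are examples only, written in the plain-fence style the Modelfile docs already use:

```
FROM llama3.2
REQUIRES 0.14.0
SYSTEM You are a concise assistant.
```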
@@ -111,6 +111,12 @@ components:
         description: Model keep-alive duration (for example `5m` or `0` to unload immediately)
       options:
         $ref: "#/components/schemas/ModelOptions"
+      logprobs:
+        type: boolean
+        description: Whether to return log probabilities of the output tokens
+      top_logprobs:
+        type: integer
+        description: Number of most likely tokens to return at each token position when logprobs are enabled
     GenerateResponse:
       type: object
       properties:
@@ -150,6 +156,11 @@ components:
         eval_duration:
           type: integer
           description: Time spent generating tokens in nanoseconds
+        logprobs:
+          type: array
+          items:
+            $ref: "#/components/schemas/Logprob"
+          description: Log probability information for the generated tokens when logprobs are enabled
     GenerateStreamEvent:
       type: object
       properties:
@@ -287,6 +298,12 @@ components:
           - type: string
           - type: number
         description: Model keep-alive duration (for example `5m` or `0` to unload immediately)
+      logprobs:
+        type: boolean
+        description: Whether to return log probabilities of the output tokens
+      top_logprobs:
+        type: integer
+        description: Number of most likely tokens to return at each token position when logprobs are enabled
     ChatResponse:
       type: object
       properties:
@@ -344,6 +361,11 @@ components:
         eval_duration:
          type: integer
          description: Time spent generating tokens in nanoseconds
+        logprobs:
+          type: array
+          items:
+            $ref: "#/components/schemas/Logprob"
+          description: Log probability information for the generated tokens when logprobs are enabled
     ChatStreamEvent:
       type: object
       properties:
@@ -706,6 +728,41 @@ components:
     version:
       type: string
       description: Version of Ollama
+    TokenLogprob:
+      type: object
+      description: Log probability information for a single token alternative
+      properties:
+        token:
+          type: string
+          description: The text representation of the token
+        logprob:
+          type: number
+          description: The log probability of this token
+        bytes:
+          type: array
+          items:
+            type: integer
+          description: The raw byte representation of the token
+    Logprob:
+      type: object
+      description: Log probability information for a generated token
+      properties:
+        token:
+          type: string
+          description: The text representation of the token
+        logprob:
+          type: number
+          description: The log probability of this token
+        bytes:
+          type: array
+          items:
+            type: integer
+          description: The raw byte representation of the token
+        top_logprobs:
+          type: array
+          items:
+            $ref: "#/components/schemas/TokenLogprob"
+          description: Most likely tokens and their log probabilities at this position
     ErrorResponse:
       type: object
       properties:
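The schema additions above expose `logprobs` and `top_logprobs` on the generate and chat requests, plus a `logprobs` array on the responses. A minimal sketch of exercising the new fields against a local server follows; the model name, prompt, and the subset of response fields decoded are illustrative assumptions rather than part of the spec:

```go
package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

func main() {
    reqBody, _ := json.Marshal(map[string]any{
        "model":        "qwen3", // assumed to be available locally
        "prompt":       "Why is the sky blue?",
        "stream":       false,
        "logprobs":     true, // request token log probabilities
        "top_logprobs": 2,    // include the two most likely alternatives per position
    })

    resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(reqBody))
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    // decode only the fields this sketch cares about
    var out struct {
        Response string `json:"response"`
        Logprobs []struct {
            Token   string  `json:"token"`
            Logprob float64 `json:"logprob"`
        } `json:"logprobs"`
    }
    if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
        panic(err)
    }

    fmt.Println(out.Response)
    for _, lp := range out.Logprobs {
        fmt.Printf("%q %.3f\n", lp.Token, lp.Logprob)
    }
}
```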
docs/tools/extract-examples/README.md (new file, 46 lines)
@@ -0,0 +1,46 @@
# extract-examples

Extracts code examples from MDX files to a temp directory so you can run them.

## Usage

```shell
go run docs/tools/extract-examples/main.go <mdx-file>
```

## Example

```shell
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
```

Output:

```
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368

 - 01_basic.py
 - 01_basic.js
 - 01_basic.sh
 - 02_responses.py
 - 02_responses.js
 - 02_responses.sh
 - 03_vision.py
 - 03_vision.js
 - 03_vision.sh

Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368

To run examples:

 cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
 npm install # for JS examples

then run individual files with `node file.js`, `python file.py`, `bash file.sh`
```

## How it works

- Parses MDX files looking for fenced code blocks with filenames (e.g., ` ```python basic.py `)
- Groups examples by their `<CodeGroup>` and prefixes filenames with `01_`, `02_`, etc.
- Writes all extracted files to a temp directory
docs/tools/extract-examples/main.go (new file, 137 lines)
@@ -0,0 +1,137 @@
package main

import (
    "bufio"
    "fmt"
    "os"
    "path/filepath"
    "regexp"
    "strings"
)

func main() {
    if len(os.Args) < 2 {
        fmt.Fprintln(os.Stderr, "Usage: go run extract-examples.go <mdx-file>")
        os.Exit(1)
    }

    mdxFile := os.Args[1]

    f, err := os.Open(mdxFile)
    if err != nil {
        fmt.Fprintf(os.Stderr, "Error: %v\n", err)
        os.Exit(1)
    }
    defer f.Close()

    // Create temp directory
    tempDir, err := os.MkdirTemp("", "mdx-examples-*")
    if err != nil {
        fmt.Fprintf(os.Stderr, "Error creating temp dir: %v\n", err)
        os.Exit(1)
    }

    fmt.Printf("Extracting code examples to: %s\n\n", tempDir)

    // Patterns
    codeBlockStart := regexp.MustCompile("^```([a-zA-Z0-9_-]+)\\s+([^\\s]+)$")
    codeGroupStart := regexp.MustCompile("^<CodeGroup")
    codeGroupEnd := regexp.MustCompile("^</CodeGroup>")

    scanner := bufio.NewScanner(f)
    inCodeBlock := false
    inCodeGroup := false
    var currentFile string
    var content strings.Builder
    count := 0
    codeGroupNum := 0

    for scanner.Scan() {
        line := scanner.Text()

        // Track CodeGroup boundaries
        if codeGroupStart.MatchString(line) {
            inCodeGroup = true
            codeGroupNum++
            continue
        }
        if codeGroupEnd.MatchString(line) {
            inCodeGroup = false
            continue
        }

        if inCodeBlock {
            if line == "```" {
                // End of code block - write file
                if currentFile != "" {
                    outPath := filepath.Join(tempDir, currentFile)
                    if err := os.WriteFile(outPath, []byte(content.String()), 0o644); err != nil {
                        fmt.Fprintf(os.Stderr, "Error writing %s: %v\n", currentFile, err)
                    } else {
                        fmt.Printf(" - %s\n", currentFile)
                        count++
                    }
                }
                inCodeBlock = false
                currentFile = ""
                content.Reset()
            } else {
                content.WriteString(line)
                content.WriteString("\n")
            }
        } else {
            if matches := codeBlockStart.FindStringSubmatch(line); matches != nil {
                inCodeBlock = true
                filename := matches[2]
                // Prefix with CodeGroup number if inside a CodeGroup
                if inCodeGroup {
                    currentFile = fmt.Sprintf("%02d_%s", codeGroupNum, filename)
                } else {
                    currentFile = filename
                }
                content.Reset()
            }
        }
    }

    if err := scanner.Err(); err != nil {
        fmt.Fprintf(os.Stderr, "Error reading file: %v\n", err)
        os.Exit(1)
    }

    // Write package.json for JavaScript dependencies
    packageJSON := `{
  "name": "mdx-examples",
  "type": "module",
  "dependencies": {
    "openai": "^4",
    "ollama": "^0.5"
  }
}
`
    if err := os.WriteFile(filepath.Join(tempDir, "package.json"), []byte(packageJSON), 0o644); err != nil {
        fmt.Fprintf(os.Stderr, "Error writing package.json: %v\n", err)
    }

    // Write pyproject.toml for Python dependencies
    pyprojectTOML := `[project]
name = "mdx-examples"
version = "0.0.0"
dependencies = [
  "openai",
  "ollama",
]
`
    if err := os.WriteFile(filepath.Join(tempDir, "pyproject.toml"), []byte(pyprojectTOML), 0o644); err != nil {
        fmt.Fprintf(os.Stderr, "Error writing pyproject.toml: %v\n", err)
    }

    fmt.Printf("\n")
    fmt.Printf("Extracted %d file(s) to %s\n", count, tempDir)
    fmt.Printf("\n")
    fmt.Printf("To run examples:\n")
    fmt.Printf("\n")
    fmt.Printf(" cd %s\n npm install # for JS examples\n", tempDir)
    fmt.Printf("\n")
    fmt.Printf("then run individual files with `node file.js`, `python file.py`, `bash file.sh`\n")
}
@@ -13,6 +13,7 @@ import (

"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/util/bufioutil"
+"github.com/ollama/ollama/ml"
)

type GGML struct {
@@ -240,12 +241,17 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {

func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
+"bert",
+"deepseek2",
+"deepseekocr",
"gemma3",
"gemma3n",
"gptoss", "gpt-oss",
"llama4",
"mistral3",
"mllama",
+"nomic-bert",
+"olmo3",
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
@@ -547,7 +553,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
}, nil
}

-func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
context *= uint64(numParallel)

embedding := f.KV().EmbeddingLength()
@@ -788,7 +794,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
}

partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
-if useFlashAttention {
+if useFlashAttention == ml.FlashAttentionEnabled {
// rough estimate of graph size with flash attention on
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
}
@@ -806,6 +812,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
}

+// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
+func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
+if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
+return false
+}
+return true
+}
+
// SupportsFlashAttention checks if the model supports flash attention
func (f GGML) SupportsFlashAttention() bool {
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
@@ -826,8 +840,11 @@ func (f GGML) SupportsFlashAttention() bool {
// FlashAttention checks if the model should enable flash attention
func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
+"bert",
"gemma3",
"gptoss", "gpt-oss",
+"mistral3",
+"olmo3",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))
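These hunks switch GraphSize to the tri-state ml.FlashAttentionType and add the KVCacheTypeIsQuantized helper. A rough sketch of how a caller inside this package might combine the new helpers when sizing a load; only the method names and ml.FlashAttentionEnabled come from the diff, while the zero-value handling and the f16 fallback are assumptions:

// Sketch, not repository code: f is an fs/ggml GGML value.
var fa ml.FlashAttentionType // zero value assumed to mean "disabled"
if f.SupportsFlashAttention() && f.FlashAttention() {
	fa = ml.FlashAttentionEnabled
}
kvCacheType := "q8_0"
if f.KVCacheTypeIsQuantized(kvCacheType) && fa != ml.FlashAttentionEnabled {
	kvCacheType = "f16" // assumption: fall back to an unquantized cache when flash attention is off
}
kv, partialOffload, fullOffload := f.GraphSize(8192, 512, 1, kvCacheType, fa)
_, _, _ = kv, partialOffload, fullOffload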
@@ -305,7 +305,7 @@ func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error

a.values[i] = e
} else {
-discardGGUFString(llm, r)
+_ = discardGGUFString(llm, r)
}
}

@@ -568,7 +568,6 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
g.SetLimit(runtime.GOMAXPROCS(0))
// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
for _, t := range ts {
-t := t
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
g.Go(func() error {
_, err := t.WriteTo(w)
@@ -598,6 +597,10 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {

var err error
switch v := v.(type) {
+case int32:
+err = writeGGUF(ws, ggufTypeInt32, v)
+case int64:
+err = writeGGUF(ws, ggufTypeInt64, v)
case uint32, FileType:
err = writeGGUF(ws, ggufTypeUint32, v)
case uint64:
@@ -612,6 +615,10 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
err = writeGGUFArray(ws, ggufTypeInt32, v)
case *array[int32]:
err = writeGGUFArray(ws, ggufTypeInt32, v.values)
+case []int64:
+err = writeGGUFArray(ws, ggufTypeInt64, v)
+case *array[int64]:
+err = writeGGUFArray(ws, ggufTypeInt64, v.values)
case []uint32:
err = writeGGUFArray(ws, ggufTypeUint32, v)
case *array[uint32]:
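With the new int32/int64 cases above, signed scalar and array metadata can round-trip through WriteGGUF; the test hunk that follows exercises exactly this. A minimal in-package sketch of the call (the file handle f and tensor slice ts are assumed to already exist):

kv := KV{
	"general.architecture": "test",
	"test.int32_key":       int32(-42),
	"test.int64_array":     []int64{-1, 0, 1},
}
if err := WriteGGUF(f, kv, ts); err != nil { // f *os.File and ts []*Tensor are assumed
	return err
}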
@@ -42,6 +42,10 @@ func TestWriteGGUF(t *testing.T) {
"general.architecture": "test",
"general.alignment": uint32(16),
"test.key": "value",
+"test.int32_key": int32(-42),
+"test.int64_key": int64(-9223372036854775808),
+"test.int32_array": []int32{-1, 0, 1, 2147483647, -2147483648},
+"test.int64_array": []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808},
"attention.key": "value2",
"tokenizer.key": "value3",
"adapter.key": "value4",
@@ -55,7 +59,7 @@ func TestWriteGGUF(t *testing.T) {
}
defer r.Close()

-ff, err := Decode(r, 0)
+ff, err := Decode(r, -1)
if err != nil {
t.Fatal(err)
}
@@ -65,15 +69,19 @@ func TestWriteGGUF(t *testing.T) {
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
"test.key": "value",
+"test.int32_key": int32(-42),
+"test.int64_key": int64(-9223372036854775808),
+"test.int32_array": &array[int32]{size: 5, values: []int32{-1, 0, 1, 2147483647, -2147483648}},
+"test.int64_array": &array[int64]{size: 5, values: []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808}},
"test.attention.key": "value2",
"tokenizer.key": "value3",
"adapter.key": "value4",
-}, ff.KV()); diff != "" {
+}, ff.KV(), cmp.AllowUnexported(array[int32]{}, array[int64]{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}

if diff := cmp.Diff(Tensors{
-Offset: 800,
+Offset: 992,
items: []*Tensor{
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
16 go.mod
@@ -15,9 +15,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
-golang.org/x/sync v0.12.0
+golang.org/x/sync v0.17.0
-golang.org/x/sys v0.36.0
+golang.org/x/sys v0.37.0
-
)

require (
@@ -30,7 +29,8 @@ require (
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
golang.org/x/image v0.22.0
-golang.org/x/tools v0.30.0
+golang.org/x/mod v0.30.0
+golang.org/x/tools v0.38.0
gonum.org/v1/gonum v0.15.0
)

@@ -77,11 +77,11 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
-golang.org/x/crypto v0.36.0
+golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
-golang.org/x/net v0.38.0 // indirect
+golang.org/x/net v0.46.0 // indirect
-golang.org/x/term v0.30.0
+golang.org/x/term v0.36.0
-golang.org/x/text v0.23.0
+golang.org/x/text v0.30.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)
30 go.sum
@@ -224,8 +224,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
+golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
-golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
+golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -255,6 +255,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
+golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -267,8 +269,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
+golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
-golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
+golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -278,8 +280,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
+golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
-golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -295,17 +297,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
+golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
-golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
+golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
-golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
+golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
+golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
-golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
+golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -319,8 +321,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
+golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
-golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
+golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -388,9 +388,9 @@ func NewFunctionNameMap() *FunctionNameMap {
}
}

-// Init initializes the handler with tools and optional last message
+// Init initializes the handler with tools, optional last message, and think value
// Implements the Parser interface
-func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
+func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
// Initialize the harmony parser
if h.HarmonyParser == nil {
h.HarmonyParser = &HarmonyParser{
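A minimal sketch of a call site updated for the new Init signature; handler, tools, lastMessage, and thinkValue are assumed to already exist in the caller (only the signature itself comes from this hunk):

// thinkValue may be nil when the request does not set a think option (assumption).
tools = handler.Init(tools, lastMessage, thinkValue)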
@@ -4,7 +4,9 @@ package integration

import (
"context"
+"errors"
"math"
+"strings"
"testing"
"time"

@@ -204,8 +206,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
}

-if res.PromptEvalCount != 6 {
+if res.PromptEvalCount != 8 {
-t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount)
+t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
}
}

@@ -251,8 +253,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
}

-if res.PromptEvalCount != 12 {
+if res.PromptEvalCount != 16 {
-t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount)
+t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
}
}

@@ -275,7 +277,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
cases := []struct {
name string
request api.EmbedRequest
-check func(*api.EmbedResponse, error)
+check func(*testing.T, *api.EmbedResponse, error)
}{
{
name: "target truncation",
@@ -283,7 +285,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Model: "all-minilm",
Input: "why",
},
-check: func(got *api.EmbedResponse, err error) {
+check: func(t *testing.T, got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
@@ -300,10 +302,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 3},
},
-check: func(got *api.EmbedResponse, err error) {
+check: func(t *testing.T, got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
+t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
@@ -317,10 +320,11 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 3},
},
-check: func(got *api.EmbedResponse, err error) {
+check: func(t *testing.T, got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
+t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
@@ -334,21 +338,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 3},
},
-check: func(res *api.EmbedResponse, err error) {
+check: func(t *testing.T, res *api.EmbedResponse, err error) {
-if err.Error() != "input exceeds maximum context length" {
+if err.Error() != "the input length exceeds the context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
-name: "input after truncate error",
+name: "input after truncate error with context length of 1",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
-check: func(res *api.EmbedResponse, err error) {
+check: func(t *testing.T, res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
@@ -362,7 +366,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 0},
},
-check: func(res *api.EmbedResponse, err error) {
+check: func(t *testing.T, res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
@@ -375,7 +379,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Input: "why is the sky blue? Why is the sky blue? hi there my",
Options: map[string]any{"num_ctx": 16},
},
-check: func(res *api.EmbedResponse, err error) {
+check: func(t *testing.T, res *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
@@ -385,7 +389,8 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {

for _, req := range cases {
t.Run(req.name, func(t *testing.T) {
-req.check(embedTestHelper(ctx, client, t, req.request))
+resp, err := embedTestHelper(ctx, client, t, req.request)
+req.check(t, resp, err)
})
}
}
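The change above threads the subtest's *testing.T into each table entry's check function rather than closing over the outer one, so failures are attributed to the right t.Run. A generic, self-contained sketch of the pattern (all names here are illustrative, not from the repository):

package example

import "testing"

func TestTableDriven(t *testing.T) {
	cases := []struct {
		name  string
		check func(t *testing.T, got int, err error)
	}{
		{name: "ok", check: func(t *testing.T, got int, err error) {
			if err != nil {
				t.Fatal(err)
			}
			if got != 42 {
				t.Fatalf("want 42, got %d", got)
			}
		}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := 42, error(nil) // stand-in for the real call under test
			tc.check(t, got, err)
		})
	}
}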
@@ -409,3 +414,230 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req

return client.Embed(ctx, &req)
}
+
+func TestEmbedTruncation(t *testing.T) {
+// Use test deadline if set, otherwise default to 2 minutes
+timeout := 2 * time.Minute
+if deadline, ok := t.Deadline(); ok {
+timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
+}
+ctx, cancel := context.WithTimeout(context.Background(), timeout)
+defer cancel()
+client, _, cleanup := InitServerConnection(ctx, t)
+defer cleanup()
+
+for _, model := range libraryEmbedModels {
+model := model
+t.Run(model, func(t *testing.T) {
+// Check if we're running out of time (reserve 20s for current model)
+if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
+t.Skip("skipping remaining tests to avoid timeout")
+}
+
+// Give each model its own budget to account for first-time pulls/loads
+mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
+defer mcancel()
+
+t.Run("truncation batch", func(t *testing.T) {
+truncTrue := true
+req := api.EmbedRequest{
+Model: model,
+Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
+Truncate: &truncTrue,
+Options: map[string]any{"num_ctx": 30},
+}
+
+res, err := embedTestHelper(mctx, client, t, req)
+if err != nil {
+t.Fatal(err)
+}
+
+if len(res.Embeddings) != 3 {
+t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
+}
+
+if res.PromptEvalCount > 90 {
+t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
+}
+})
+
+t.Run("runner token count accuracy", func(t *testing.T) {
+baseline := api.EmbedRequest{Model: model, Input: "test"}
+baseRes, err := embedTestHelper(mctx, client, t, baseline)
+if err != nil {
+t.Fatal(err)
+}
+
+batch := api.EmbedRequest{
+Model: model,
+Input: []string{"test", "test", "test"},
+}
+batchRes, err := embedTestHelper(mctx, client, t, batch)
+if err != nil {
+t.Fatal(err)
+}
+
+expectedCount := baseRes.PromptEvalCount * 3
+if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
+t.Fatalf("expected ~%d tokens (3 × %d), got %d",
+expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
+}
+})
+})
+}
+}
+
+// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
+func TestEmbedLargeInput(t *testing.T) {
+ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+defer cancel()
+client, _, cleanup := InitServerConnection(ctx, t)
+defer cleanup()
+
+for _, model := range libraryEmbedModels {
+model := model
+t.Run(model, func(t *testing.T) {
+mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
+defer mcancel()
+
+// Test with progressively larger inputs
+testCases := []struct {
+name string
+inputWords int
+}{
+{"medium_input_256_words", 256},
+{"large_input_512_words", 512},
+{"very_large_input_800_words", 800},
+}
+
+for _, tc := range testCases {
+t.Run(tc.name, func(t *testing.T) {
+words := make([]string, tc.inputWords)
+for i := range words {
+words[i] = "word"
+}
+input := strings.Join(words, " ")
+
+req := api.EmbedRequest{
+Model: model,
+Input: input,
+KeepAlive: &api.Duration{Duration: 30 * time.Second},
+}
+
+res, err := embedTestHelper(mctx, client, t, req)
+if err != nil {
+t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
+}
+
+if len(res.Embeddings) != 1 {
+t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
+}
+
+if len(res.Embeddings[0]) == 0 {
+t.Fatal("expected non-empty embedding")
+}
+
+t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
+})
+}
+})
+}
+}
+
+// TestEmbedStatusCode tests that errors from the embedding endpoint
+// properly preserve their HTTP status codes when returned to the client.
+// This test specifically checks the error handling path in EmbedHandler
+// where api.StatusError errors should maintain their original status code.
+func TestEmbedStatusCode(t *testing.T) {
+// Use test deadline if set, otherwise default to 2 minutes
+timeout := 2 * time.Minute
+if deadline, ok := t.Deadline(); ok {
+timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
+}
+ctx, cancel := context.WithTimeout(context.Background(), timeout)
+defer cancel()
+client, _, cleanup := InitServerConnection(ctx, t)
+defer cleanup()
+
+for _, model := range libraryEmbedModels {
+model := model
+t.Run(model, func(t *testing.T) {
+// Check if we're running out of time (reserve 20s for current model)
+if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
+t.Skip("skipping remaining tests to avoid timeout")
+}
+
+mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
+defer mcancel()
+
+// Pull the model if needed
+if err := PullIfMissing(mctx, client, model); err != nil {
+t.Fatal(err)
+}
+
+t.Run("truncation error status code", func(t *testing.T) {
+truncFalse := false
+longInput := strings.Repeat("word ", 100)
+
+req := api.EmbedRequest{
+Model: model,
+Input: longInput,
+Truncate: &truncFalse,
+Options: map[string]any{"num_ctx": 10},
+}
+
+_, err := embedTestHelper(mctx, client, t, req)
+if err == nil {
+t.Fatal("expected error when truncate=false with long input")
+}
+
+// Check that it's a StatusError with the correct status code
+var statusErr api.StatusError
+if !errors.As(err, &statusErr) {
+t.Fatalf("expected api.StatusError, got %T: %v", err, err)
+}
+
+// The error should be a 4xx client error (likely 400 Bad Request)
+// not a 500 Internal Server Error
+if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
+t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
+}
+
+// Verify the error message is meaningful
+if !strings.Contains(err.Error(), "context length") {
+t.Errorf("expected error message to mention context length, got: %v", err)
+}
+})
+
+t.Run("batch truncation error status code", func(t *testing.T) {
+truncFalse := false
+req := api.EmbedRequest{
+Model: model,
+Input: []string{
+"short input",
+strings.Repeat("very long input ", 100),
+"another short input",
+},
+Truncate: &truncFalse,
+Options: map[string]any{"num_ctx": 10},
+}
+
+_, err := embedTestHelper(mctx, client, t, req)
+if err == nil {
+t.Fatal("expected error when one input exceeds context with truncate=false")
+}
+
+// Check that it's a StatusError with the correct status code
+var statusErr api.StatusError
+if !errors.As(err, &statusErr) {
+t.Fatalf("expected api.StatusError, got %T: %v", err, err)
+}
+
+// The error should be a 4xx client error, not a 500 Internal Server Error
+if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
+t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
+}
+})
+})
+}
+}
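The new integration tests above depend on the server returning a typed api.StatusError when truncation is disabled and the input does not fit in the context. A small client-side sketch of the same check outside the test harness (the model name and num_ctx value are placeholders):

package main

import (
	"context"
	"errors"
	"log"
	"strings"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	truncate := false
	_, err = client.Embed(context.Background(), &api.EmbedRequest{
		Model:    "all-minilm",
		Input:    strings.Repeat("word ", 100),
		Truncate: &truncate,
		Options:  map[string]any{"num_ctx": 10},
	})
	var statusErr api.StatusError
	if errors.As(err, &statusErr) {
		log.Printf("status %d: %v", statusErr.StatusCode, err) // expect a 4xx, not 500
	}
}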
@@ -33,6 +33,9 @@ func TestVisionModels(t *testing.T) {
// Qwen 3 VL mixture of experts
model: "qwen3-vl:30b",
},
+{
+model: "ministral-3",
+},
}

for _, v := range testCases {
@@ -30,6 +30,7 @@ func TestAPIToolCalling(t *testing.T) {
"mistral": 6,
"qwen2.5": 6,
"qwen2": 6,
+"ministral-3": 20,
"mistral-nemo": 9,
"mistral-small": 16,
"mixtral:8x22b": 80,
@@ -38,6 +38,7 @@ var (

// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
+"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
"gemma3n:e2b",
@@ -167,6 +168,7 @@ var (
"medllama2",
"megadolphin",
"minicpm-v",
+"ministral-3",
"mistral-large",
"mistral-nemo",
"mistral-openorca",
@@ -270,6 +272,7 @@ var (
"mistral",
"qwen2.5",
"qwen2",
+"ministral-3",
"mistral-nemo",
"mistral-small",
"mixtral:8x22b",
@@ -3,7 +3,6 @@ package kvcache
import (
"errors"
"fmt"
-"log/slog"
"math"
"slices"

@@ -40,18 +39,18 @@ type Causal struct {

// ** current forward pass **

-// the active layer for Get and Put
-curLayer int
-
-// starting location for data storage for this batch
-curLoc int
-
// size of the current batch
curBatchSize int

+// locations for data storage for this batch
+curLoc ml.Tensor
+
// mask of the cache as used by this batch
curMask ml.Tensor

+// the active layer for Get and Put
+curLayer int
+
// locations in the cache that are needed for this batch
curCellRange cellRange

@@ -141,10 +140,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
c.config.CachePadding = 1
}

-if c.config.MaskBatchPadding == 0 {
-c.config.MaskBatchPadding = 1
-}
-
if c.config.MaskDType == ml.DTypeOther {
c.config.MaskDType = ml.DTypeF32
}
@@ -206,45 +201,47 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curPositions = batch.Positions
c.opts.Except = nil

+var locs []int32
if !reserve {
c.updateSlidingWindow()

var err error
-c.curLoc, err = c.findStartLoc()
+locs, err = c.findLocs()
-if errors.Is(err, ErrKvCacheFull) {
-c.defrag()
-c.curLoc, err = c.findStartLoc()
-}
if err != nil {
return err
}

for i, pos := range batch.Positions {
seq := batch.Sequences[i]
+loc := int(locs[i])

-c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
+c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}

seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}

-seqRange.min = min(seqRange.min, c.curLoc+i)
+seqRange.min = min(seqRange.min, loc)
-c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
+c.curCellRange.min = min(c.curCellRange.min, loc)

-seqRange.max = max(seqRange.max, c.curLoc+i)
+seqRange.max = max(seqRange.max, loc)
-c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
+c.curCellRange.max = max(c.curCellRange.max, loc)

c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
-c.curLoc = 0
+locs = make([]int32, c.curBatchSize)
+for i := range locs {
+locs[i] = int32(i)
+}
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}

+c.curLoc = ctx.Input().FromInts(locs, len(locs))
c.curMask = c.buildMask(ctx)

return nil
@@ -257,22 +254,20 @@ func newRange() cellRange {
}
}

-// Find the first contiguous block of at least curBatchSize
+// Returns a slice of locations where each token in the batch should be stored
-func (c *Causal) findStartLoc() (int, error) {
+func (c *Causal) findLocs() ([]int32, error) {
-var start, count int
+loc := make([]int32, 0, c.curBatchSize)

for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
-count++
+loc = append(loc, int32(i))
-if count >= c.curBatchSize {
+if len(loc) >= c.curBatchSize {
-return start, nil
+return loc, nil
}
-} else {
-start = i + 1
-count = 0
}
}

-return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
+return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
}

func (c *Causal) updateSlidingWindow() {
@@ -365,15 +360,12 @@ func roundUp(length, pad int) int {
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
-// Align and pad the two dimensions as required by the backend
-batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
-
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1

length := c.curCellRange.max - c.curCellRange.min + 1

-mask := make([]float32, batchSize*length)
+mask := make([]float32, c.curBatchSize*length)

for i := range c.curBatchSize {
enabled := !slices.Contains(c.opts.Except, i)
@@ -387,13 +379,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
}
}

-// Mask out any padding tokens we added. For padding that we added to the cache history, this
-// has already been masked out because the sequence doesn't match.
-for i := c.curBatchSize * length; i < len(mask); i++ {
-mask[i] = float32(math.Inf(-1))
-}
-
-maskTensor := ctx.Input().FromFloats(mask, length, batchSize)
+maskTensor := ctx.Input().FromFloats(mask, length, c.curBatchSize)

if c.config.MaskDType != ml.DTypeF32 {
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
@@ -402,145 +388,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
return maskTensor
}

-func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
-for i, key := range c.keys {
-if key == nil {
-continue
-}
-
-kHeadDim := key.Dim(0)
-numKVHeads := key.Dim(1)
-rowSize := key.Stride(2)
-
-kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
-kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
-
-value := c.values[i]
-var vSrcView, vDstView ml.Tensor
-if c.config.PermutedV {
-vHeadDim := value.Dim(1)
-elemSize := value.Stride(0)
-
-vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
-vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
-} else {
-vHeadDim := value.Dim(0)
-rowSize := value.Stride(2)
-
-vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
-vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
-}
-
-ctx.Forward(
-kSrcView.Copy(ctx, kDstView),
-vSrcView.Copy(ctx, vDstView),
-)
-}
-}
-
-func (c *Causal) defrag() {
-slog.Debug("defragmenting kv cache")
-
-// Defrag strategy:
-// - Search for empty holes at the beginning of the cache,
-// filling them with active data starting at the end
-// - If there are contiguous elements that need to be moved,
-// combine them into a single operation by holding new moves
-// until we see that the next one is non-contiguous
-// - Fill up the context with the maximum number of operations it
-// can hold then compute that and continue with a new context
-//
-// We could try to optimize placement by grouping blocks from
-// the same sequences together but most likely the next forward
-// pass will disrupt this anyways, so the real world benefit
-// seems limited as this time.
-
-ctx := c.backend.NewContext()
-
-// For every move, 6 tensors are required per layer (2 views and a
-// copy for each of k and v). We also need to refer to the original
-// k and v cache tensors - once per layer, not per move.
-layers := 0
-for _, key := range c.keys {
-if key == nil {
-continue
-}
-layers++
-}
-
-maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
-moves := 0
-
-var pendingSrc, pendingDst, pendingLen int
-src := len(c.cells) - 1
-
-for dst := 0; dst < src; dst++ {
-if len(c.cells[dst].sequences) == 0 {
-for ; src > dst; src-- {
-if len(c.cells[src].sequences) != 0 {
-c.cells[dst] = c.cells[src]
-c.cells[src] = cacheCell{}
-
-if pendingLen > 0 {
-if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
-pendingSrc = src
-pendingLen++
-break
-} else {
-c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
-moves++
-}
-}
-
-pendingSrc = src
-pendingDst = dst
-pendingLen = 1
-
-break
-}
-}
-}
-
-if moves >= maxMoves {
-ctx.Compute()
-ctx.Close()
-ctx = c.backend.NewContext()
-
-moves = 0
-}
-}
-
-if pendingLen > 0 {
-c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
-moves++
-}
-
-if moves > 0 {
-ctx.Compute()
-}
-ctx.Close()
-
-// Reset range metadata
-for seq := range c.cellRanges {
-seqRange := newRange()
-
-for i, cell := range c.cells {
-if slices.Contains(cell.sequences, seq) {
-if i < seqRange.min {
-seqRange.min = i
-}
-if i > seqRange.max {
-seqRange.max = i
-}
-}
-}
-
-c.cellRanges[seq] = seqRange
-}
-
-c.updateSlidingWindow()
-}

func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
@@ -625,18 +472,25 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
}

-rowSize := c.keys[c.curLayer].Stride(2)
-ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
+key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
+keyCache := c.keys[c.curLayer]
+keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
+ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))

if c.config.PermutedV {
-elemSize := c.values[c.curLayer].Stride(0)
-
-value = value.Permute(ctx, 1, 2, 0, 3)
-ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
+value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
+value = value.Permute(ctx, 2, 0, 1, 3)
+
+valueCache := c.values[c.curLayer]
+valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
+ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
} else {
-rowSize := c.values[c.curLayer].Stride(2)
-ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
+value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
+valueCache := c.values[c.curLayer]
+valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
+
+ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
}
}

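The KV-cache hunks above replace the contiguous start-offset plus View/Copy scheme with a per-token location tensor: free cells are collected by findLocs, uploaded with FromInts, and the new key/value rows are scattered into the cache with SetRows, which is why the defrag path could be removed. A simplified sketch of the idea using the same tensor operations (shapes and surrounding variables are illustrative only):

// locs[i] is the cache cell that token i of the batch will occupy.
locs, err := c.findLocs()
if err != nil {
	return err
}
c.curLoc = ctx.Input().FromInts(locs, len(locs))

// Later, in Put: flatten the per-token K rows and scatter them into the
// cache at those (possibly non-contiguous) row indices.
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
keyCache := c.keys[c.curLayer].Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))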
File diff suppressed because it is too large
2 llama/build-info.cpp generated vendored
@@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0;
-char const *LLAMA_COMMIT = "3cfa9c3f125763305b4226bc032f1954f08990dc";
+char const *LLAMA_COMMIT = "ec98e2002";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";
@@ -17,11 +17,17 @@ include /tools/mtmd/clip.cpp
include /tools/mtmd/mtmd.cpp
include /tools/mtmd/mtmd-audio.cpp
include /tools/mtmd/mtmd-helper.cpp
+include /tools/mtmd/models/
+include /tools/mtmd/models/*.h
+include /tools/mtmd/models/*.cpp
include /src/
include /src/llama.*
include /src/llama-*.*
include /src/unicode-data.*
include /src/unicode.*
+include /src/models/
+include /src/models/*.h
+include /src/models/*.cpp
include /vendor/
include /vendor/miniaudio/
include /vendor/miniaudio/*.h
359 llama/llama.cpp/common/common.cpp vendored
@@ -8,6 +8,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
+#include "sampling.h"

#include <algorithm>
#include <cinttypes>
@@ -26,7 +27,6 @@
#include <sstream>
#include <string>
#include <thread>
-#include <unordered_map>
#include <unordered_set>
#include <vector>

@@ -60,6 +60,14 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

+common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
+
+common_time_meas::~common_time_meas() {
+if (t_start_us >= 0) {
+t_acc += ggml_time_us() - t_start_us;
+}
+}
+
//
// CPU utils
//
@@ -355,11 +363,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
}

void common_init() {
-llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
-if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
-common_log_add(common_log_main(), level, "%s", text);
-}
-}, NULL);
+llama_log_set(common_log_default_callback, NULL);

#ifdef NDEBUG
const char * build_type = "";
@@ -690,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over

// Validate if a filename is safe to use
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
-bool fs_validate_filename(const std::string & filename) {
+bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
if (!filename.length()) {
// Empty filename invalid
return false;
@@ -750,10 +754,14 @@ bool fs_validate_filename(const std::string & filename) {
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|| c == 0xFFFD // Replacement Character (UTF-8)
|| c == 0xFEFF // Byte Order Mark (BOM)
-|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
+|| c == ':' || c == '*' // Illegal characters
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
return false;
}
+if (!allow_subdirs && (c == '/' || c == '\\')) {
+// Subdirectories not allowed, reject path separators
+return false;
+}
}

// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@@ -778,11 +786,29 @@ bool fs_validate_filename(const std::string & filename) {
#include <iostream>

+#ifdef _WIN32
+static std::wstring utf8_to_wstring(const std::string & str) {
+if (str.empty()) {
+return std::wstring();
+}
+
+int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
+
+if (size <= 0) {
+return std::wstring();
+}
+
+std::wstring wstr(size, 0);
+MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
+
+return wstr;
+}
+#endif
+
// returns true if successful, false otherwise
bool fs_create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
-std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
-std::wstring wpath = converter.from_bytes(path);
+std::wstring wpath = utf8_to_wstring(path);

// if the path already exists, check whether it's a directory
const DWORD attributes = GetFileAttributesW(wpath.c_str());
@@ -855,6 +881,11 @@ bool fs_create_directory_with_parents(const std::string & path) {
|
|||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool fs_is_directory(const std::string & path) {
|
||||||
|
std::filesystem::path dir(path);
|
||||||
|
return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
|
||||||
|
}
|
||||||
|
|
||||||
std::string fs_get_cache_directory() {
|
std::string fs_get_cache_directory() {
|
||||||
std::string cache_directory = "";
|
std::string cache_directory = "";
|
||||||
auto ensure_trailing_slash = [](std::string p) {
|
auto ensure_trailing_slash = [](std::string p) {
|
||||||
@@ -889,6 +920,8 @@ std::string fs_get_cache_directory() {
|
|||||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||||
#elif defined(_WIN32)
|
#elif defined(_WIN32)
|
||||||
cache_directory = std::getenv("LOCALAPPDATA");
|
cache_directory = std::getenv("LOCALAPPDATA");
|
||||||
|
#elif defined(__EMSCRIPTEN__)
|
||||||
|
GGML_ABORT("not implemented on this platform");
|
||||||
#else
|
#else
|
||||||
# error Unknown architecture
|
# error Unknown architecture
|
||||||
#endif
|
#endif
|
||||||
@@ -908,34 +941,258 @@ std::string fs_get_cache_file(const std::string & filename) {
|
|||||||
return cache_directory + filename;
|
return cache_directory + filename;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories) {
|
||||||
|
std::vector<common_file_info> files;
|
||||||
|
if (path.empty()) return files;
|
||||||
|
|
||||||
|
std::filesystem::path dir(path);
|
||||||
|
if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
|
||||||
|
return files;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & entry : std::filesystem::directory_iterator(dir)) {
|
||||||
|
try {
|
||||||
|
// Only include regular files (skip directories)
|
||||||
|
const auto & p = entry.path();
|
||||||
|
if (std::filesystem::is_regular_file(p)) {
|
||||||
|
common_file_info info;
|
||||||
|
info.path = p.string();
|
||||||
|
info.name = p.filename().string();
|
||||||
|
info.is_dir = false;
|
||||||
|
try {
|
||||||
|
info.size = static_cast<size_t>(std::filesystem::file_size(p));
|
||||||
|
} catch (const std::filesystem::filesystem_error &) {
|
||||||
|
info.size = 0;
|
||||||
|
}
|
||||||
|
files.push_back(std::move(info));
|
||||||
|
} else if (include_directories && std::filesystem::is_directory(p)) {
|
||||||
|
common_file_info info;
|
||||||
|
info.path = p.string();
|
||||||
|
info.name = p.filename().string();
|
||||||
|
info.size = 0; // Directories have no size
|
||||||
|
info.is_dir = true;
|
||||||
|
files.push_back(std::move(info));
|
||||||
|
}
|
||||||
|
} catch (const std::filesystem::filesystem_error &) {
|
||||||
|
// skip entries we cannot inspect
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return files;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// TTY utils
|
||||||
|
//
|
||||||
|
|
||||||
|
bool tty_can_use_colors() {
|
||||||
|
// Check NO_COLOR environment variable (https://no-color.org/)
|
||||||
|
if (const char * no_color = std::getenv("NO_COLOR")) {
|
||||||
|
if (no_color[0] != '\0') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check TERM environment variable
|
||||||
|
if (const char * term = std::getenv("TERM")) {
|
||||||
|
if (std::strcmp(term, "dumb") == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if stdout and stderr are connected to a terminal
|
||||||
|
// We check both because log messages can go to either
|
||||||
|
bool stdout_is_tty = isatty(fileno(stdout));
|
||||||
|
bool stderr_is_tty = isatty(fileno(stderr));
|
||||||
|
|
||||||
|
return stdout_is_tty || stderr_is_tty;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Model utils
|
// Model utils
|
||||||
//
|
//
|
||||||
|
|
||||||
struct common_init_result common_init_from_params(common_params & params) {
|
// TODO: move to common/sampling
|
||||||
common_init_result iparams;
|
static void common_init_sampler_from_model(
|
||||||
|
const llama_model * model,
|
||||||
|
common_params_sampling & sparams) {
|
||||||
|
|
||||||
|
const uint64_t config = sparams.user_sampling_config;
|
||||||
|
|
||||||
|
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
|
||||||
|
if (config & user_config) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
char buf[64] = {0};
|
||||||
|
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||||
|
char * end = nullptr;
|
||||||
|
int32_t v = strtol(buf, &end, 10);
|
||||||
|
if (end && end != buf) {
|
||||||
|
dst = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
|
||||||
|
if (config & user_config) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
char buf[128] = {0};
|
||||||
|
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||||
|
char * end = nullptr;
|
||||||
|
float v = strtof(buf, &end);
|
||||||
|
if (end && end != buf) {
|
||||||
|
dst = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Sampling sequence
|
||||||
|
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
|
||||||
|
char buf[512] = {0};
|
||||||
|
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
|
||||||
|
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
|
||||||
|
if (!sampler_names.empty()) {
|
||||||
|
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
|
||||||
|
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
|
||||||
|
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct common_init_result::impl {
|
||||||
|
impl() = default;
|
||||||
|
~impl() = default;
|
||||||
|
|
||||||
|
llama_model_ptr model;
|
||||||
|
llama_context_ptr context;
|
||||||
|
|
||||||
|
std::vector<llama_adapter_lora_ptr> lora;
|
||||||
|
|
||||||
|
std::vector<common_sampler_ptr> samplers;
|
||||||
|
};
|
||||||
|
|
||||||
|
common_init_result::common_init_result(common_params & params) :
|
||||||
|
pimpl(new impl{}) {
|
||||||
auto mparams = common_model_params_to_llama(params);
|
auto mparams = common_model_params_to_llama(params);
|
||||||
|
auto cparams = common_context_params_to_llama(params);
|
||||||
|
|
||||||
|
if (params.fit_params) {
|
||||||
|
LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
|
||||||
|
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||||
|
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
|
||||||
|
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||||
if (model == NULL) {
|
if (model == NULL) {
|
||||||
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
return;
|
||||||
__func__, params.model.path.c_str());
|
|
||||||
return iparams;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pimpl->model.reset(model);
|
||||||
|
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
|
||||||
auto cparams = common_context_params_to_llama(params);
|
// updates params.sampling
|
||||||
|
// TODO: fix naming
|
||||||
|
common_init_sampler_from_model(model, params.sampling);
|
||||||
|
|
||||||
|
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||||
|
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||||
|
params.sampling.ignore_eos = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize once
|
||||||
|
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||||
|
if (llama_vocab_is_eog(vocab, i)) {
|
||||||
|
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||||
|
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.sampling.ignore_eos) {
|
||||||
|
// add EOG biases to the active set of logit biases
|
||||||
|
params.sampling.logit_bias.insert(
|
||||||
|
params.sampling.logit_bias.end(),
|
||||||
|
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
//if (params.sampling.penalty_last_n == -1) {
|
||||||
|
// LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||||
|
// params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||||
|
//}
|
||||||
|
|
||||||
|
//if (params.sampling.dry_penalty_last_n == -1) {
|
||||||
|
// LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||||
|
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||||
|
//}
|
||||||
|
|
||||||
|
pimpl->samplers.resize(cparams.n_seq_max);
|
||||||
|
|
||||||
|
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
|
||||||
|
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
|
||||||
|
}
|
||||||
|
|
||||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||||
if (lctx == NULL) {
|
if (lctx == NULL) {
|
||||||
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||||
__func__, params.model.path.c_str());
|
return;
|
||||||
llama_model_free(model);
|
|
||||||
return iparams;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pimpl->context.reset(lctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_model * common_init_result::model() {
|
||||||
|
return pimpl->model.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_context * common_init_result::context() {
|
||||||
|
return pimpl->context.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
|
||||||
|
return pimpl->samplers[seq_id].get();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||||
|
return pimpl->lora;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_init_result::free_context() {
|
||||||
|
pimpl->context.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||||
|
common_init_result_ptr res(new common_init_result(params));
|
||||||
|
|
||||||
|
llama_model * model = res->model();
|
||||||
|
if (model == NULL) {
|
||||||
|
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_context * lctx = res->context();
|
||||||
|
if (lctx == NULL) {
|
||||||
|
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
|
||||||
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
||||||
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
||||||
params.ctx_shift = false;
|
params.ctx_shift = false;
|
||||||
@@ -947,10 +1204,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
|
|
||||||
const auto cvec = common_control_vector_load(params.control_vectors);
|
const auto cvec = common_control_vector_load(params.control_vectors);
|
||||||
if (cvec.n_embd == -1) {
|
if (cvec.n_embd == -1) {
|
||||||
llama_free(lctx);
|
return res;
|
||||||
llama_model_free(model);
|
|
||||||
|
|
||||||
return iparams;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int err = llama_apply_adapter_cvec(
|
int err = llama_apply_adapter_cvec(
|
||||||
@@ -961,10 +1215,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
params.control_vector_layer_start,
|
params.control_vector_layer_start,
|
||||||
params.control_vector_layer_end);
|
params.control_vector_layer_end);
|
||||||
if (err) {
|
if (err) {
|
||||||
llama_free(lctx);
|
return res;
|
||||||
llama_model_free(model);
|
|
||||||
|
|
||||||
return iparams;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -988,10 +1239,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
llama_free(lctx);
|
return res;
|
||||||
llama_model_free(model);
|
|
||||||
|
|
||||||
return iparams;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1001,9 +1249,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||||
if (lora == nullptr) {
|
if (lora == nullptr) {
|
||||||
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
||||||
llama_free(lctx);
|
return res;
|
||||||
llama_model_free(model);
|
|
||||||
return iparams;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
char buf[1024];
|
char buf[1024];
|
||||||
@@ -1012,43 +1258,13 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
la.task_name = buf;
|
la.task_name = buf;
|
||||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
||||||
la.prompt_prefix = buf;
|
la.prompt_prefix = buf;
|
||||||
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
|
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!params.lora_init_without_apply) {
|
if (!params.lora_init_without_apply) {
|
||||||
common_set_adapter_lora(lctx, params.lora_adapters);
|
common_set_adapter_lora(lctx, params.lora_adapters);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
|
||||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
|
||||||
params.sampling.ignore_eos = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// initialize once
|
|
||||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
|
||||||
if (llama_vocab_is_eog(vocab, i)) {
|
|
||||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
|
||||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.sampling.ignore_eos) {
|
|
||||||
// add EOG biases to the active set of logit biases
|
|
||||||
params.sampling.logit_bias.insert(
|
|
||||||
params.sampling.logit_bias.end(),
|
|
||||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.sampling.penalty_last_n == -1) {
|
|
||||||
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
|
||||||
params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.sampling.dry_penalty_last_n == -1) {
|
|
||||||
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
|
||||||
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.warmup) {
|
if (params.warmup) {
|
||||||
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||||
|
|
||||||
@@ -1087,12 +1303,11 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
llama_set_warmup(lctx, false);
|
llama_set_warmup(lctx, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
iparams.model.reset(model);
|
return res;
|
||||||
iparams.context.reset(lctx);
|
|
||||||
|
|
||||||
return iparams;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
common_init_result::~common_init_result() = default;
|
||||||
|
|
||||||
std::string get_model_endpoint() {
|
std::string get_model_endpoint() {
|
||||||
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
|
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
|
||||||
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
|
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
|
||||||
@@ -1101,7 +1316,9 @@ std::string get_model_endpoint() {
|
|||||||
std::string model_endpoint = "https://huggingface.co/";
|
std::string model_endpoint = "https://huggingface.co/";
|
||||||
if (endpoint_env) {
|
if (endpoint_env) {
|
||||||
model_endpoint = endpoint_env;
|
model_endpoint = endpoint_env;
|
||||||
if (model_endpoint.back() != '/') model_endpoint += '/';
|
if (model_endpoint.back() != '/') {
|
||||||
|
model_endpoint += '/';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return model_endpoint;
|
return model_endpoint;
|
||||||
}
|
}
|
||||||
|
|||||||
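A minimal sketch, not part of the diff above, of how the new filesystem helpers added in this file might be combined; only the fs_is_directory, fs_list and fs_validate_filename signatures come from the change, while the directory purpose, the function name list_presets and the printing are illustrative assumptions:

// Hypothetical caller: enumerate files in a directory using the new helpers.
#include "common.h"

#include <cstdio>
#include <string>
#include <vector>

static void list_presets(const std::string & dir) {
    if (!fs_is_directory(dir)) {
        return;
    }
    const std::vector<common_file_info> entries = fs_list(dir, /*include_directories=*/false);
    for (const auto & e : entries) {
        // allow_subdirs = false: names containing path separators are rejected
        if (fs_validate_filename(e.name, /*allow_subdirs=*/false)) {
            printf("%s (%zu bytes)\n", e.name.c_str(), e.size);
        }
    }
}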
llama/llama.cpp/common/common.h (vendored): 125 changed lines

@@ -2,17 +2,19 @@

 #pragma once

+#include "ggml-opt.h"
+#include "llama-cpp.h"
+
 #include <set>
 #include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
 #include <map>
-#include <sstream>
-#include <cmath>

-#include "ggml-opt.h"
-#include "llama-cpp.h"
+#if defined(_WIN32) && !defined(_WIN32_WINNT)
+#define _WIN32_WINNT 0x0A00
+#endif

 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -28,7 +30,14 @@
     fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
 } while(0)

-#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
+struct common_time_meas {
+    common_time_meas(int64_t & t_acc, bool disable = false);
+    ~common_time_meas();
+
+    const int64_t t_start_us;
+
+    int64_t & t_acc;
+};

 struct common_adapter_lora_info {
     std::string path;
@@ -73,7 +82,8 @@ int32_t cpu_get_num_math();
 enum llama_example {
     LLAMA_EXAMPLE_COMMON,
     LLAMA_EXAMPLE_SPECULATIVE,
-    LLAMA_EXAMPLE_MAIN,
+    LLAMA_EXAMPLE_COMPLETION,
+    LLAMA_EXAMPLE_CLI,
     LLAMA_EXAMPLE_EMBEDDING,
     LLAMA_EXAMPLE_PERPLEXITY,
     LLAMA_EXAMPLE_RETRIEVAL,
@@ -89,6 +99,7 @@ enum llama_example {
     LLAMA_EXAMPLE_TTS,
     LLAMA_EXAMPLE_DIFFUSION,
     LLAMA_EXAMPLE_FINETUNE,
+    LLAMA_EXAMPLE_FIT_PARAMS,

     LLAMA_EXAMPLE_COUNT,
 };
@@ -133,6 +144,22 @@ struct common_grammar_trigger {
     llama_token token = LLAMA_TOKEN_NULL;
 };

+enum common_params_sampling_config : uint64_t {
+    COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS        = 1 << 0,
+    COMMON_PARAMS_SAMPLING_CONFIG_TOP_K           = 1 << 1,
+    COMMON_PARAMS_SAMPLING_CONFIG_TOP_P           = 1 << 2,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIN_P           = 1 << 3,
+    COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
+    COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD   = 1 << 5,
+    COMMON_PARAMS_SAMPLING_CONFIG_TEMP            = 1 << 6,
+    COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N  = 1 << 7,
+    COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT  = 1 << 8,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT        = 1 << 9,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU    = 1 << 10,
+    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA    = 1 << 11,
+};
+
+
 // sampling parameters
 struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -165,8 +192,9 @@ struct common_params_sampling {
     bool no_perf = false; // disable performance metrics
     bool timing_per_token = false;

-    std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
+    uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
+
+    std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY

     std::vector<enum common_sampler_type> samplers = {
         COMMON_SAMPLER_TYPE_PENALTIES,
@@ -188,6 +216,10 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

+    bool has_logit_bias() const {
+        return !logit_bias.empty();
+    }
+
     // print the parameters into a string
     std::string print() const;
 };
@@ -198,6 +230,7 @@ struct common_params_model {
     std::string hf_repo = "";     // HF repo     // NOLINT
     std::string hf_file = "";     // HF file     // NOLINT
     std::string docker_repo = ""; // Docker repo // NOLINT
+    std::string name = "";        // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
 };

 struct common_params_speculative {
@@ -274,8 +307,8 @@ struct lr_opt {
 struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);

 struct common_params {
-    int32_t n_predict = -1;   // new tokens to predict
-    int32_t n_ctx = 4096;     // context size
+    int32_t n_predict = -1;   // max. number of new tokens to predict, -1 == no limit
+    int32_t n_ctx = 0;        // context size, 0 == context the model was trained with
     int32_t n_batch = 2048;   // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch = 512;   // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep = 0;       // number of tokens to keep from initial prompt
@@ -296,9 +329,12 @@ struct common_params {
     // offload params
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading

     int32_t n_gpu_layers = -1;     // number of layers to store in VRAM (-1 - use default)
     int32_t main_gpu = 0;          // the GPU that is used for scratch and small tensors
     float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+    bool fit_params = true;                       // whether to fit unset model/context parameters to free device memory
+    size_t fit_params_target = 1024 * 1024*1024;  // margin per device in bytes for fitting parameters to free memory
+    int32_t fit_params_min_ctx = 4096;            // minimum context size to set when trying to reduce memory use

     enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs

@@ -344,7 +380,7 @@ struct common_params {

     std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale

-    int32_t verbosity = 0;
+    int32_t verbosity = 3; // LOG_LEVEL_INFO
     int32_t control_vector_layer_start = -1; // layer range for control vector
     int32_t control_vector_layer_end = -1;   // layer range for control vector
     bool offline = false;
@@ -378,6 +414,7 @@ struct common_params {
     bool simple_io = false;     // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true;  // insert new sequences for decoding on-the-fly
     bool no_perf = false;       // disable performance metrics
+    bool show_timings = true;   // show timing information on CLI
     bool ctx_shift = false;     // context shift on infinite text generation
     bool swa_full = false;      // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     bool kv_unified = false;    // enable unified KV cache
@@ -406,6 +443,8 @@ struct common_params {
     bool mmproj_use_gpu = true;     // use GPU for multimodal model
     bool no_mmproj = false;         // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
+    int image_min_tokens = -1;
+    int image_max_tokens = -1;

     // finetune
     struct lr_opt lr;
@@ -432,7 +471,7 @@ struct common_params {
     std::string public_path = "";   // NOLINT
     std::string api_prefix = "";    // NOLINT
     std::string chat_template = ""; // NOLINT
-    bool use_jinja = false;         // NOLINT
+    bool use_jinja = true;          // NOLINT
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
     int reasoning_budget = -1;
@@ -451,14 +490,22 @@ struct common_params {
     bool endpoint_props = false; // only control POST requests, not GET
     bool endpoint_metrics = false;

+    // router server configs
+    std::string models_dir = "";    // directory containing models for the router server
+    std::string models_preset = ""; // directory containing model presets for the router server
+    int models_max = 4;             // maximum number of models to load simultaneously
+    bool models_autoload = true;    // automatically load models when requested via the router server
+
     bool log_json = false;

     std::string slot_save_path;
+    std::string media_path; // path to directory for loading media files

     float slot_prompt_similarity = 0.1f;

     // batched-bench params
     bool is_pp_shared = false;
+    bool is_tg_separate = false;

     std::vector<int32_t> n_pp;
     std::vector<int32_t> n_tg;
@@ -505,6 +552,10 @@ struct common_params {
     // return false from callback to abort model loading or true to continue
     llama_progress_callback load_progress_callback = NULL;
     void * load_progress_callback_user_data = NULL;
+
+    bool has_speculative() const {
+        return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
+    }
 };

 // call once at the start of a program if it uses libcommon
@@ -599,25 +650,55 @@ std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
 // Filesystem utils
 //

-bool fs_validate_filename(const std::string & filename);
+bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
 bool fs_create_directory_with_parents(const std::string & path);
+bool fs_is_directory(const std::string & path);

 std::string fs_get_cache_directory();
 std::string fs_get_cache_file(const std::string & filename);

+struct common_file_info {
+    std::string path;
+    std::string name;
+    size_t size = 0; // in bytes
+    bool is_dir = false;
+};
+std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
+
+//
+// TTY utils
+//
+
+// Auto-detect if colors can be enabled based on terminal and environment
+bool tty_can_use_colors();
+
 //
 // Model utils
 //

-// note: defines object's lifetime
-struct common_init_result {
-    llama_model_ptr model;
-    llama_context_ptr context;
-
-    std::vector<llama_adapter_lora_ptr> lora;
+struct common_sampler;
+
+// note: defines the model, context, samplers, ets. lifetimes
+struct common_init_result {
+    common_init_result(common_params & params);
+    ~common_init_result();
+
+    llama_model * model();
+    llama_context * context();
+    common_sampler * sampler(llama_seq_id seq_id);
+
+    std::vector<llama_adapter_lora_ptr> & lora();
+
+    void free_context();
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
 };

-struct common_init_result common_init_from_params(common_params & params);
+using common_init_result_ptr = std::unique_ptr<common_init_result>;
+
+common_init_result_ptr common_init_from_params(common_params & params);

 struct llama_model_params   common_model_params_to_llama  (      common_params & params);
 struct llama_context_params common_context_params_to_llama(const common_params & params);
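A minimal usage sketch, not part of the diff above, for the reworked common_init_result API declared in this header. It assumes a common_params object has already been filled in; the function name run and the placeholder decode loop are illustrative, while the accessors and the smart-pointer return type come from the header itself:

#include "common.h"

// Hypothetical caller: load a model and fetch the per-sequence sampler.
static bool run(common_params & params) {
    common_init_result_ptr init = common_init_from_params(params);

    llama_model   * model = init->model();
    llama_context * lctx  = init->context();
    if (model == NULL || lctx == NULL) {
        return false; // loading failed; the pimpl-owned resources clean up automatically
    }

    common_sampler * smpl = init->sampler(0); // sampler for sequence 0
    (void) smpl; // a real caller would drive its decode loop here

    init->free_context(); // optionally release the context before the model
    return true;
}

Compared with the previous struct of public smart pointers, error handling reduces to checking the accessors, as the common.cpp changes above show for the control-vector and LoRA paths.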
llama/llama.cpp/common/json-schema-to-grammar.cpp (vendored): 165 changed lines

@@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
 }

 std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
-std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
+std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
 std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
 std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
-    {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
+    {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
 };

 std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
@@ -303,8 +303,11 @@ static std::string format_literal(const std::string & literal) {
     return "\"" + escaped + "\"";
 }

-class SchemaConverter {
+std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
+
+class common_schema_converter {
 private:
+    friend class common_schema_info;
     friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
     std::function<json(const std::string &)> _fetch_json;
     bool _dotall;
@@ -601,7 +604,10 @@ private:
     }

     std::string _resolve_ref(const std::string & ref) {
-        std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
+        auto it = ref.find('#');
+        std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
+        static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
+        std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
         if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
             _refs_being_resolved.insert(ref);
             json resolved = _refs[ref];
@@ -724,7 +730,7 @@ private:
     }

 public:
-    SchemaConverter(
+    common_schema_converter(
         const std::function<json(const std::string &)> & fetch_json,
         bool dotall)
         : _fetch_json(fetch_json), _dotall(dotall)
@@ -774,11 +780,24 @@ public:
             std::vector<std::string> tokens = string_split(pointer, "/");
             for (size_t i = 1; i < tokens.size(); ++i) {
                 std::string sel = tokens[i];
-                if (target.is_null() || !target.contains(sel)) {
+                if (target.is_object() && target.contains(sel)) {
+                    target = target[sel];
+                } else if (target.is_array()) {
+                    size_t sel_index;
+                    try {
+                        sel_index = std::stoul(sel);
+                    } catch (const std::invalid_argument & e) {
+                        sel_index = target.size();
+                    }
+                    if (sel_index >= target.size()) {
+                        _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
+                        return;
+                    }
+                    target = target[sel_index];
+                } else {
                     _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
                     return;
                 }
-                target = target[sel];
             }
             _refs[ref] = target;
         }
@@ -956,7 +975,7 @@ public:

     void check_errors() {
         if (!_errors.empty()) {
-            throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
+            throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
         }
         if (!_warnings.empty()) {
             fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
@@ -972,6 +991,134 @@ public:
     }
 };

+// common_schema_info implementation (pimpl)
+
+common_schema_info::common_schema_info()
+    : impl_(std::make_unique<common_schema_converter>(
+        [](const std::string &) { return json(); },
+        false)) {}
+
+common_schema_info::~common_schema_info() = default;
+
+common_schema_info::common_schema_info(common_schema_info &&) noexcept = default;
+common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default;
+
+void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) {
+    impl_->resolve_refs(schema, "");
+}
+
+// Determines if a JSON schema can resolve to a string type through any path.
+// Some models emit raw string values rather than JSON-encoded strings for string parameters.
+// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns
+// true, allowing callers to handle the value as a raw string for simplicity.
+bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) {
+    std::unordered_set<std::string> visited_refs;
+
+    std::function<bool(const json &)> check = [&](const json & s) -> bool {
+        if (!s.is_object()) {
+            return false;
+        }
+
+        // Handle $ref
+        if (s.contains("$ref")) {
+            const std::string & ref = s["$ref"];
+            if (visited_refs.find(ref) != visited_refs.end()) {
+                // Circular reference, assume not a string to be safe
+                return false;
+            }
+            visited_refs.insert(ref);
+            auto it = impl_->_refs.find(ref);
+            if (it != impl_->_refs.end()) {
+                return check(it->second);
+            }
+            return false;
+        }
+
+        // Check type field
+        if (s.contains("type")) {
+            const json & schema_type = s["type"];
+            if (schema_type.is_string()) {
+                if (schema_type == "string") {
+                    return true;
+                }
+            } else if (schema_type.is_array()) {
+                // Type can be an array like ["string", "null"]
+                for (const auto & t : schema_type) {
+                    if (t == "string") {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        // Check oneOf/anyOf - if any alternative can be a string
+        if (s.contains("oneOf")) {
+            for (const auto & alt : s["oneOf"]) {
+                if (check(alt)) {
+                    return true;
+                }
+            }
+        }
+        if (s.contains("anyOf")) {
+            for (const auto & alt : s["anyOf"]) {
+                if (check(alt)) {
+                    return true;
+                }
+            }
+        }
+
+        // Check allOf - all components must be compatible with string type
+        if (s.contains("allOf")) {
+            bool all_string = true;
+            for (const auto & component : s["allOf"]) {
+                if (!check(component)) {
+                    all_string = false;
+                    break;
+                }
+            }
+            if (all_string) {
+                return true;
+            }
+        }
+
+        // Check const - if the constant value is a string
+        if (s.contains("const")) {
+            if (s["const"].is_string()) {
+                return true;
+            }
+        }
+
+        // Check enum - if any enum value is a string
+        if (s.contains("enum")) {
+            for (const auto & val : s["enum"]) {
+                if (val.is_string()) {
+                    return true;
+                }
+            }
+        }
+
+        // String-specific keywords imply string type
+        if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) {
+            return true;
+        }
+
+        // Check format - many formats imply string
+        if (s.contains("format")) {
+            const std::string & fmt = s["format"];
+            if (fmt == "date" || fmt == "time" || fmt == "date-time" ||
+                fmt == "uri" || fmt == "email" || fmt == "hostname" ||
+                fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" ||
+                fmt.find("uuid") == 0) {
+                return true;
+            }
+        }
+
+        return false;
+    };
+
+    return check(schema);
+}
+
 std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
 #ifdef LLAMA_USE_LLGUIDANCE
     if (!force_gbnf) {
@@ -988,7 +1135,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
 }

 std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
-    SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
+    common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall);
     common_grammar_builder builder {
         /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
             return converter._add_rule(name, rule);
llama/llama.cpp/common/json-schema-to-grammar.h (vendored): 22 changed lines

@@ -3,11 +3,31 @@
 #include <nlohmann/json_fwd.hpp>

 #include <functional>
+#include <memory>
 #include <string>

 std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
     bool force_gbnf = false);

+class common_schema_converter;
+
+// Probes a JSON schema to extract information about its structure and type constraints.
+class common_schema_info {
+    std::unique_ptr<common_schema_converter> impl_;
+
+public:
+    common_schema_info();
+    ~common_schema_info();
+
+    common_schema_info(const common_schema_info &) = delete;
+    common_schema_info & operator=(const common_schema_info &) = delete;
+    common_schema_info(common_schema_info &&) noexcept;
+    common_schema_info & operator=(common_schema_info &&) noexcept;
+
+    void resolve_refs(nlohmann::ordered_json & schema);
+    bool resolves_to_string(const nlohmann::ordered_json & schema);
+};
+
 struct common_grammar_builder {
     std::function<std::string(const std::string &, const std::string &)> add_rule;
     std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
@@ -18,4 +38,6 @@ struct common_grammar_options {
     bool dotall = false;
 };

+std::string gbnf_format_literal(const std::string & literal);
+
 std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
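A small sketch, not part of the diff above, of how the new common_schema_info probe might be called. The example schema and the function name parameter_accepts_raw_string are made up; only the class interface comes from the header above:

#include "json-schema-to-grammar.h"

#include <nlohmann/json.hpp>

static bool parameter_accepts_raw_string() {
    // A schema that permits a string through an anyOf branch.
    nlohmann::ordered_json schema = nlohmann::ordered_json::parse(R"({
        "anyOf": [
            { "type": "string" },
            { "type": "integer" }
        ]
    })");

    common_schema_info info;
    info.resolve_refs(schema);              // flatten any $ref indirections first
    return info.resolves_to_string(schema); // true here: one branch is a string
}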
llama/llama.cpp/common/log.cpp (vendored): 54 changed lines

@@ -1,3 +1,4 @@
+#include "common.h"
 #include "log.h"

 #include <chrono>
@@ -26,30 +27,6 @@ void common_log_set_verbosity_thold(int verbosity) {
     common_log_verbosity_thold = verbosity;
 }

-// Auto-detect if colors should be enabled based on terminal and environment
-static bool common_log_should_use_colors_auto() {
-    // Check NO_COLOR environment variable (https://no-color.org/)
-    if (const char * no_color = std::getenv("NO_COLOR")) {
-        if (no_color[0] != '\0') {
-            return false;
-        }
-    }
-
-    // Check TERM environment variable
-    if (const char * term = std::getenv("TERM")) {
-        if (std::strcmp(term, "dumb") == 0) {
-            return false;
-        }
-    }
-
-    // Check if stdout and stderr are connected to a terminal
-    // We check both because log messages can go to either
-    bool stdout_is_tty = isatty(fileno(stdout));
-    bool stderr_is_tty = isatty(fileno(stderr));
-
-    return stdout_is_tty || stderr_is_tty;
-}
-
 static int64_t t_us() {
     return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
 }
@@ -391,7 +368,7 @@ struct common_log * common_log_main() {
     static std::once_flag init_flag;
     std::call_once(init_flag, [&]() {
         // Set default to auto-detect colors
-        log.set_colors(common_log_should_use_colors_auto());
+        log.set_colors(tty_can_use_colors());
     });

     return &log;
@@ -422,7 +399,7 @@ void common_log_set_file(struct common_log * log, const char * file) {

 void common_log_set_colors(struct common_log * log, log_colors colors) {
     if (colors == LOG_COLORS_AUTO) {
-        log->set_colors(common_log_should_use_colors_auto());
+        log->set_colors(tty_can_use_colors());
         return;
     }

@@ -442,3 +419,28 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
 void common_log_set_timestamps(struct common_log * log, bool timestamps) {
     log->set_timestamps(timestamps);
 }
+
+void common_log_flush(struct common_log * log) {
+    log->pause();
+    log->resume();
+}
+
+static int common_get_verbosity(enum ggml_log_level level) {
+    switch (level) {
+        case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
+        case GGML_LOG_LEVEL_INFO:  return LOG_LEVEL_INFO;
+        case GGML_LOG_LEVEL_WARN:  return LOG_LEVEL_WARN;
+        case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR;
+        case GGML_LOG_LEVEL_CONT:  return LOG_LEVEL_INFO; // same as INFO
+        case GGML_LOG_LEVEL_NONE:
+        default:
+            return LOG_LEVEL_OUTPUT;
+    }
+}
+
+void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
+    auto verbosity = common_get_verbosity(level);
+    if (verbosity <= common_log_verbosity_thold) {
+        common_log_add(common_log_main(), level, "%s", text);
+    }
+}
34 llama/llama.cpp/common/log.h vendored
@@ -21,8 +21,14 @@
 # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
 #endif
 
-#define LOG_DEFAULT_DEBUG 1
-#define LOG_DEFAULT_LLAMA 0
+#define LOG_LEVEL_DEBUG  4
+#define LOG_LEVEL_INFO   3
+#define LOG_LEVEL_WARN   2
+#define LOG_LEVEL_ERROR  1
+#define LOG_LEVEL_OUTPUT 0 // output data from tools
+
+#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG
+#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO
 
 enum log_colors {
     LOG_COLORS_AUTO = -1,
@@ -36,6 +42,8 @@ extern int common_log_verbosity_thold;
 
 void common_log_set_verbosity_thold(int verbosity); // not thread-safe
 
+void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
+
 // the common_log uses an internal worker thread to print/write log messages
 // when the worker thread is paused, incoming log messages are discarded
 struct common_log;
@@ -65,16 +73,18 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch
 //  0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
 //  0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
 //
-// I - info    (stdout, V = 0)
-// W - warning (stderr, V = 0)
-// E - error   (stderr, V = 0)
 // D - debug   (stderr, V = LOG_DEFAULT_DEBUG)
+// I - info    (stdout, V = LOG_DEFAULT_INFO)
+// W - warning (stderr, V = LOG_DEFAULT_WARN)
+// E - error   (stderr, V = LOG_DEFAULT_ERROR)
+// O - output  (stdout, V = LOG_DEFAULT_OUTPUT)
 //
 
 void common_log_set_file      (struct common_log * log, const char * file); // not thread-safe
 void common_log_set_colors    (struct common_log * log, log_colors colors); // not thread-safe
 void common_log_set_prefix    (struct common_log * log, bool prefix); // whether to output prefix to each log
 void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix
+void common_log_flush         (struct common_log * log); // flush all pending log messages
 
 // helper macros for logging
 // use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
@@ -93,14 +103,14 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps); // w
     } \
 } while (0)
 
-#define LOG(...)             LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__)
+#define LOG(...)             LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__)
 #define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
 
-#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO,  0, __VA_ARGS__)
-#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN,  0, __VA_ARGS__)
-#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
-#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
-#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  0, __VA_ARGS__)
+#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__)
+#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO,  LOG_LEVEL_INFO,  __VA_ARGS__)
+#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN,  LOG_LEVEL_WARN,  __VA_ARGS__)
+#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__)
+#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  LOG_LEVEL_INFO,  __VA_ARGS__) // same as INFO
 
 #define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
 #define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
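Note (illustrative, not part of the diff): the log.cpp/log.h changes above replace the old 0/1 verbosity scheme with named per-level constants and add common_log_flush() plus a default GGML log callback. A minimal sketch of how the reworked macros are meant to be used, assuming only the declarations shown above:

    // minimal sketch, assuming the macros/functions declared in common/log.h above
    #include "log.h"

    static void report(int n_tokens) {
        common_log_set_verbosity_thold(LOG_LEVEL_INFO);  // keep OUTPUT..INFO, filter DEBUG

        LOG("%d\n", n_tokens);                       // LOG_LEVEL_OUTPUT (0) - always printed
        LOG_INF("processed %d tokens\n", n_tokens);  // LOG_LEVEL_INFO   (3) - printed
        LOG_DBG("internal state dump\n");            // LOG_LEVEL_DEBUG  (4) - filtered out here

        common_log_flush(common_log_main());         // drain pending messages from the worker thread
    }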
249 llama/llama.cpp/common/sampling.cpp vendored
@@ -3,9 +3,10 @@
 #include "common.h"
 #include "log.h"
 
-#include <cmath>
-#include <unordered_map>
 #include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <unordered_map>
 
 // the ring buffer works similarly to std::deque, but with a fixed capacity
 // TODO: deduplicate with llama-impl.h
@@ -103,15 +104,22 @@ struct ring_buffer {
 struct common_sampler {
     common_params_sampling params;
 
-    struct llama_sampler * grmr;
     struct llama_sampler * chain;
 
+    bool grammar;
+
     ring_buffer<llama_token> prev;
 
     std::vector<llama_token_data> cur;
 
     llama_token_data_array cur_p;
 
+    void reset() {
+        prev.clear();
+
+        llama_sampler_reset(chain);
+    }
+
     void set_logits(struct llama_context * ctx, int idx) {
         const auto * logits = llama_get_logits_ith(ctx, idx);
 
@@ -128,6 +136,12 @@ struct common_sampler {
 
         cur_p = { cur.data(), cur.size(), -1, false };
     }
+
+    common_time_meas tm() {
+        return common_time_meas(t_total_us, params.no_perf);
+    }
+
+    mutable int64_t t_total_us = 0;
 };
 
 std::string common_params_sampling::print() const {
@@ -153,10 +167,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     lparams.no_perf = params.no_perf;
 
-    struct llama_sampler * grmr;
+    llama_sampler * chain = llama_sampler_chain_init(lparams);
+
+    bool grammar = false;
+    std::vector<llama_sampler *> samplers;
+
     if (params.grammar.compare(0, 11, "%llguidance") == 0) {
 #ifdef LLAMA_USE_LLGUIDANCE
-        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
+        samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
+        grammar = true;
 #else
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
@@ -203,30 +222,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             trigger_patterns_c.push_back(regex.c_str());
         }
 
-        grmr = params.grammar_lazy
-             ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
-                   trigger_patterns_c.data(), trigger_patterns_c.size(),
-                   trigger_tokens.data(), trigger_tokens.size())
-             : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
-        if (!grmr) {
-            return nullptr;
+        if (!params.grammar.empty()) {
+            if (params.grammar_lazy) {
+                samplers.push_back(
+                    llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                        trigger_patterns_c.data(), trigger_patterns_c.size(),
+                        trigger_tokens.data(), trigger_tokens.size()));
+            } else {
+                samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+            }
+
+            grammar = true;
         }
     }
 
-    auto * result = new common_sampler {
-        /* .params = */ params,
-        /* .grmr   = */ grmr,
-        /* .chain  = */ llama_sampler_chain_init(lparams),
-        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-        /* .cur    = */ {},
-        /* .cur_p  = */ {},
-    };
-
-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_logit_bias(
-                llama_vocab_n_tokens(vocab),
-                params.logit_bias.size(),
-                params.logit_bias.data()));
+    if (params.has_logit_bias()) {
+        samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
+    }
 
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
@@ -239,58 +251,70 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                     c_breakers.push_back(str.c_str());
                 }
 
-                llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
             }
                 break;
             case COMMON_SAMPLER_TYPE_TOP_K:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+                samplers.push_back(llama_sampler_init_top_k (params.top_k));
                 break;
            case COMMON_SAMPLER_TYPE_TOP_P:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
+                samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
                 break;
            case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
+                samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                 break;
            case COMMON_SAMPLER_TYPE_MIN_P:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
+                samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
                 break;
            case COMMON_SAMPLER_TYPE_XTC:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                 break;
            case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
+                samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
                 break;
            case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
            case COMMON_SAMPLER_TYPE_INFILL:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
+                samplers.push_back(llama_sampler_init_infill (vocab));
                 break;
            case COMMON_SAMPLER_TYPE_PENALTIES:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                 break;
            default:
                 GGML_ASSERT(false && "unknown sampler type");
         }
     }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+        samplers.push_back(llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
    } else if (params.mirostat == 2) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
    } else {
         GGML_ASSERT(false && "unknown mirostat version");
     }
 
+    for (auto * smpl : samplers) {
+        llama_sampler_chain_add(chain, smpl);
+    }
+
+    auto * result = new common_sampler {
+        /* .params  = */ params,
+        /* .chain   = */ chain,
+        /* .grammar = */ grammar,
+        /* .prev    = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur     = */ {},
+        /* .cur_p   = */ {},
+    };
+
     return result;
 }
 
 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
-        llama_sampler_free(gsmpl->grmr);
-
         llama_sampler_free(gsmpl->chain);
 
         delete gsmpl;
@@ -298,91 +322,117 @@ void common_sampler_free(struct common_sampler * gsmpl) {
 }
 
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
-    if (accept_grammar) {
-        llama_sampler_accept(gsmpl->grmr, token);
-    }
+    const auto tm = gsmpl->tm();
 
-    llama_sampler_accept(gsmpl->chain, token);
+    if (gsmpl->grammar) {
+        const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
+
+        for (int i = 0; i < n_smpl; i++) {
+            auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+
+            // the grammar sampler is always the first one
+            if (i == 0) {
+                if (accept_grammar) {
+                    llama_sampler_accept(smpl, token);
+                }
+            } else {
+                llama_sampler_accept(smpl, token);
+            }
+        }
+    } else {
+        llama_sampler_accept(gsmpl->chain, token);
+    }
 
     gsmpl->prev.push_back(token);
 }
 
 void common_sampler_reset(struct common_sampler * gsmpl) {
-    llama_sampler_reset(gsmpl->grmr);
-
-    llama_sampler_reset(gsmpl->chain);
+    gsmpl->reset();
 }
 
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
         /* .params  = */ gsmpl->params,
-        /* .grmr    = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain   = */ llama_sampler_clone(gsmpl->chain),
+        /* .chain   = */ llama_sampler_clone(gsmpl->chain),
+        /* .grammar = */ gsmpl->grammar,
         /* .prev    = */ gsmpl->prev,
         /* .cur     = */ gsmpl->cur,
        /* .cur_p   = */ gsmpl->cur_p,
     };
 }
 
 void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
     // TODO: measure grammar performance
 
+    const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;
+
+    llama_perf_sampler_data data_smpl;
+    llama_perf_context_data data_ctx;
+
+    memset(&data_smpl, 0, sizeof(data_smpl));
+    memset(&data_ctx,  0, sizeof(data_ctx));
+
     if (gsmpl) {
-        llama_perf_sampler_print(gsmpl->chain);
+        auto & data = data_smpl;
+
+        data = llama_perf_sampler(gsmpl->chain);
+
+        // note: the sampling time includes the samplers time + extra time spent in common/sampling
+        LOG_INF("%s:    sampling time = %10.2f ms\n", __func__, t_sampling_ms);
+        LOG_INF("%s:    samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
     }
 
     if (ctx) {
-        llama_perf_context_print(ctx);
+        auto & data = data_ctx;
+
+        data = llama_perf_context(ctx);
+
+        const double t_end_ms = 1e-3 * ggml_time_us();
+
+        const double t_total_ms = t_end_ms - data.t_start_ms;
+        const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
+        const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;
+
+        LOG_INF("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
+        LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+                __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
+        LOG_INF("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+                __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
+        LOG_INF("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
+        LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
+        LOG_INF("%s:    graphs reused = %10d\n", __func__, data.n_reused);
+
         llama_memory_breakdown_print(ctx);
     }
 }
 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
-    gsmpl->set_logits(ctx, idx);
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
+    return gsmpl->chain;
+}
+
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
+    llama_synchronize(ctx);
+
+    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
+    const auto tm = gsmpl->tm();
+
+    llama_token id = LLAMA_TOKEN_NULL;
 
-    auto & grmr  = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
-    if (grammar_first) {
-        llama_sampler_apply(grmr, &cur_p);
-    }
+    gsmpl->set_logits(ctx, idx);
 
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
-    const llama_token id = cur_p.data[cur_p.selected].id;
+    id = cur_p.data[cur_p.selected].id;
 
-    if (grammar_first) {
-        return id;
-    }
-
-    // check if it the sampled token fits the grammar
-    {
-        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
-
-        llama_sampler_apply(grmr, &single_token_data_array);
-
-        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
-        if (is_valid) {
-            return id;
-        }
-    }
-
-    // resampling:
-    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
-    gsmpl->set_logits(ctx, idx);
-
-    llama_sampler_apply(grmr,  &cur_p);
-    llama_sampler_apply(chain, &cur_p);
-
-    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
-
-    return cur_p.data[cur_p.selected].id;
+    return id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
@@ -390,7 +440,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
 
     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -402,7 +452,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }
 
     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -412,13 +462,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
 
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
@@ -428,6 +478,8 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
 // helpers
 
 llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
+    const auto tm = gsmpl->tm();
+
     auto * res = &gsmpl->cur_p;
 
     if (do_sort && !res->sorted) {
@@ -461,7 +513,8 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) {
 
     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
+        result += std::string("-> ");
+        result += std::string(llama_sampler_name(smpl)) + " ";
     }
 
     return result;
17 llama/llama.cpp/common/sampling.h vendored
@@ -48,6 +48,8 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
 // arguments can be nullptr to skip printing
 void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
 
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
+
 // extended sampling implementation:
 //
 // - set logits
@@ -55,10 +57,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 // - check if the token fits the grammar (if any)
 // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
 //
-// if grammar_first is true, the grammar is applied before the samplers (slower)
-// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
-//
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
 
 // generalized version of common_sampler_sample
 //
@@ -76,10 +75,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
 
 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
@@ -107,3 +106,9 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
 
 llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
     const char * grammar_kind, const char * grammar_data);
+
+struct common_sampler_deleter {
+    void operator()(common_sampler * s) { common_sampler_free(s); }
+};
+
+typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;
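Note (illustrative, not part of the diff): with the grammar_first parameter gone, the grammar sampler is simply the first entry of the chain, so callers just sample and accept. A minimal sketch against the declarations above; the helper name next_token is hypothetical:

    // minimal sketch, assuming the API declared in common/sampling.h above
    static llama_token next_token(common_sampler * gsmpl, llama_context * ctx) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, /*idx =*/ -1);
        common_sampler_accept(gsmpl, id, /*accept_grammar =*/ true);
        return id;
    }

    // ownership can now be expressed with the new RAII typedef:
    //     common_sampler_ptr smpl(common_sampler_init(model, params));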
47 llama/llama.cpp/include/llama.h vendored
@@ -83,6 +83,7 @@ extern "C" {
         LLAMA_ROPE_TYPE_NORM   = 0,
         LLAMA_ROPE_TYPE_NEOX   = GGML_ROPE_TYPE_NEOX,
         LLAMA_ROPE_TYPE_MROPE  = GGML_ROPE_TYPE_MROPE,
+        LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE,
         LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
     };
 
@@ -245,6 +246,21 @@ extern "C" {
         LLAMA_KV_OVERRIDE_TYPE_STR,
     };
 
+    enum llama_model_meta_key {
+        LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE,
+        LLAMA_MODEL_META_KEY_SAMPLING_TOP_K,
+        LLAMA_MODEL_META_KEY_SAMPLING_TOP_P,
+        LLAMA_MODEL_META_KEY_SAMPLING_MIN_P,
+        LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY,
+        LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD,
+        LLAMA_MODEL_META_KEY_SAMPLING_TEMP,
+        LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N,
+        LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT,
+        LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT,
+        LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU,
+        LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA,
+    };
+
     struct llama_model_kv_override {
         enum llama_model_kv_override_type tag;
 
@@ -297,6 +313,7 @@ extern "C" {
         bool check_tensors;   // validate model tensor data
         bool use_extra_bufts; // use extra buffer types (used for weight repacking)
         bool no_host;         // bypass host buffer allowing extra buffers to be used
+        bool no_alloc;        // only load metadata and simulate memory allocations
     };
 
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
@@ -450,17 +467,35 @@ extern "C" {
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
 
+    // fits mparams and cparams to free device memory (assumes system memory is unlimited)
+    // returns true if the parameters could be successfully modified to fit device memory
+    // this function is NOT thread safe because it modifies the global llama logger state
+    LLAMA_API bool llama_params_fit(
+            const char * path_model,
+            struct llama_model_params * mparams,
+            struct llama_context_params * cparams,
+            float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
+            struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
+            size_t margin, // margin of memory to leave per device in bytes
+            uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
+            enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
+
     LLAMA_API int64_t llama_time_us(void);
 
     LLAMA_API size_t llama_max_devices(void);
     LLAMA_API size_t llama_max_parallel_sequences(void);
+    LLAMA_API size_t llama_max_tensor_buft_overrides(void);
 
     LLAMA_API bool llama_supports_mmap       (void);
     LLAMA_API bool llama_supports_mlock      (void);
     LLAMA_API bool llama_supports_gpu_offload(void);
     LLAMA_API bool llama_supports_rpc        (void);
 
+    // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
+    // In some cases the requested values via llama_context_params may differ from the actual values used by the context
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ctx_seq  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
@@ -481,6 +516,7 @@ extern "C" {
 
     LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);
@@ -512,6 +548,9 @@ extern "C" {
     // Get the number of metadata key/value pairs
     LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
 
+    // Get sampling metadata key name. Returns nullptr if the key is invalid
+    LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key);
+
     // Get metadata key name by index
     LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
 
@@ -584,7 +623,7 @@ extern "C" {
     LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
 
     // Manually free a LoRA adapter
-    // Note: loaded adapters will be free when the associated model is deleted
+    // NOTE: loaded adapters will be free when the associated model is deleted
     LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
 
     // Get the invocation tokens if the current lora is an alora
@@ -1110,8 +1149,6 @@ extern "C" {
    //            // sample from the logits of the last token in the batch
    //            const llama_token id = llama_sampler_sample(smpl, ctx, -1);
    //
-   //            // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
-   //            llama_sampler_accept(smpl, id);
    //            ...
    //        }
    //
@@ -1332,7 +1369,9 @@ extern "C" {
 
     // Set callback for all future logging events.
     // If this is not called, or NULL is supplied, everything is output on stderr.
-    LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
+    // The logger state is global so these functions are NOT thread safe.
+    LLAMA_API void llama_log_get(ggml_log_callback * log_callback, void ** user_data);
+    LLAMA_API void llama_log_set(ggml_log_callback   log_callback, void *  user_data);
 
     //
     // Performance utils
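Note (illustrative, not part of the diff): a minimal sketch of calling the new llama_params_fit() declared above before loading a model. The buffer sizes follow the comments in the declaration; the margin and n_ctx_min values are arbitrary example choices.

    // minimal sketch, assuming the llama_params_fit() declaration above
    #include "llama.h"
    #include <vector>

    static bool fit_params(const char * path, llama_model_params & mparams, llama_context_params & cparams) {
        std::vector<float> tensor_split(llama_max_devices());
        std::vector<llama_model_tensor_buft_override> overrides(llama_max_tensor_buft_overrides());

        return llama_params_fit(path, &mparams, &cparams,
                                tensor_split.data(), overrides.data(),
                                /*margin    =*/ 512u*1024*1024, // leave ~512 MiB free per device
                                /*n_ctx_min =*/ 4096,
                                GGML_LOG_LEVEL_INFO);
    }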
4039 llama/llama.cpp/src/llama-arch.cpp vendored
(file diff suppressed because it is too large)
42 llama/llama.cpp/src/llama-arch.h vendored
@@ -3,6 +3,7 @@
 #include "ggml.h" // ggml_op
 
 #include <string>
+#include <set>
 
 //
 // gguf constants (sync with gguf.py)
@@ -36,6 +37,9 @@ enum llm_arch {
     LLM_ARCH_QWEN2VL,
     LLM_ARCH_QWEN3,
     LLM_ARCH_QWEN3MOE,
+    LLM_ARCH_QWEN3NEXT,
+    LLM_ARCH_QWEN3VL,
+    LLM_ARCH_QWEN3VLMOE,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
@@ -76,6 +80,7 @@ enum llm_arch {
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_NEMOTRON_H,
+    LLM_ARCH_NEMOTRON_H_MOE,
     LLM_ARCH_EXAONE,
     LLM_ARCH_EXAONE4,
     LLM_ARCH_RWKV6,
@@ -93,6 +98,7 @@ enum llm_arch {
     LLM_ARCH_BAILINGMOE2,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
+    LLM_ARCH_AFMOE,
     LLM_ARCH_ERNIE4_5,
     LLM_ARCH_ERNIE4_5_MOE,
     LLM_ARCH_HUNYUAN_MOE,
@@ -108,6 +114,11 @@ enum llm_arch {
     LLM_ARCH_SEED_OSS,
     LLM_ARCH_GROVEMOE,
     LLM_ARCH_APERTUS,
+    LLM_ARCH_MINIMAX_M2,
+    LLM_ARCH_COGVLM,
+    LLM_ARCH_RND1,
+    LLM_ARCH_PANGU_EMBED,
+    LLM_ARCH_MISTRAL3,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -117,6 +128,18 @@ enum llm_kv {
     LLM_KV_GENERAL_QUANTIZATION_VERSION,
     LLM_KV_GENERAL_ALIGNMENT,
     LLM_KV_GENERAL_FILE_TYPE,
+    LLM_KV_GENERAL_SAMPLING_SEQUENCE,
+    LLM_KV_GENERAL_SAMPLING_TOP_K,
+    LLM_KV_GENERAL_SAMPLING_TOP_P,
+    LLM_KV_GENERAL_SAMPLING_MIN_P,
+    LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY,
+    LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD,
+    LLM_KV_GENERAL_SAMPLING_TEMP,
+    LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N,
+    LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT,
+    LLM_KV_GENERAL_SAMPLING_MIROSTAT,
+    LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU,
+    LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA,
     LLM_KV_GENERAL_NAME,
     LLM_KV_GENERAL_AUTHOR,
     LLM_KV_GENERAL_VERSION,
@@ -150,6 +173,7 @@ enum llm_kv {
     LLM_KV_EXPERTS_PER_GROUP,
     LLM_KV_MOE_EVERY_N_LAYERS,
     LLM_KV_NEXTN_PREDICT_LAYERS,
+    LLM_KV_NUM_DEEPSTACK_LAYERS,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
@@ -188,6 +212,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_OUTPUT_SCALE,
     LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
+    LLM_KV_ATTENTION_TEMPERATURE_SCALE,
     LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
@@ -294,6 +319,7 @@ enum llm_tensor {
     LLM_TENSOR_DENSE_3_OUT,
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
+    LLM_TENSOR_OUTPUT_NORM_LFM2, // fix for wrong tensor name
     LLM_TENSOR_ROPE_FREQS,
     LLM_TENSOR_ROPE_FACTORS_LONG,
     LLM_TENSOR_ROPE_FACTORS_SHORT,
@@ -308,6 +334,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_POST_NORM,
     LLM_TENSOR_ATTN_ROT_EMBD,
     LLM_TENSOR_ATTN_SINKS,
+    LLM_TENSOR_ATTN_GATE,
     LLM_TENSOR_FFN_GATE_INP,
     LLM_TENSOR_FFN_GATE_INP_SHEXP,
     LLM_TENSOR_FFN_NORM,
@@ -357,11 +384,13 @@ enum llm_tensor {
     LLM_TENSOR_SSM_DT,
     LLM_TENSOR_SSM_DT_NORM,
     LLM_TENSOR_SSM_A,
+    LLM_TENSOR_SSM_A_NOSCAN, // qwen3next special case with MUL instead of SSM_SCAN
     LLM_TENSOR_SSM_B_NORM,
     LLM_TENSOR_SSM_C_NORM,
     LLM_TENSOR_SSM_D,
     LLM_TENSOR_SSM_NORM,
     LLM_TENSOR_SSM_OUT,
+    LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
     LLM_TENSOR_TIME_MIX_W0,
     LLM_TENSOR_TIME_MIX_W1,
     LLM_TENSOR_TIME_MIX_W2,
@@ -458,6 +487,11 @@ enum llm_tensor {
     LLM_TENSOR_SHORTCONV_CONV,
     LLM_TENSOR_SHORTCONV_INPROJ,
     LLM_TENSOR_SHORTCONV_OUTPROJ,
+    LLM_TENSOR_VISEXP_ATTN_QKV,
+    LLM_TENSOR_VISEXP_ATTN_OUT,
+    LLM_TENSOR_VISEXP_FFN_GATE,
+    LLM_TENSOR_VISEXP_FFN_DOWN,
+    LLM_TENSOR_VISEXP_FFN_UP,
     LLM_TENSOR_NEXTN_EH_PROJ,
     LLM_TENSOR_NEXTN_EMBED_TOKENS,
     LLM_TENSOR_NEXTN_ENORM,
@@ -497,6 +531,10 @@ struct LLM_TN_IMPL {
     const int bid;
     const int xid;
 
+    const std::set<llm_tensor> model_tensors;
+
+    LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid);
+
     std::string str() const;
 
     operator std::string() const {
@@ -518,11 +556,11 @@ struct LLM_TN {
     llm_arch arch;
 
     LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
-        return { arch, tensor, suffix, bid, xid };
+        return LLM_TN_IMPL(arch, tensor, suffix, bid, xid);
     }
 
     LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const {
-        return { arch, tensor, nullptr, bid, xid };
+        return LLM_TN_IMPL(arch, tensor, nullptr, bid, xid);
     }
 };
 
110
llama/llama.cpp/src/llama-batch.cpp
vendored
110
llama/llama.cpp/src/llama-batch.cpp
vendored
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
|
|||||||
/*.n_seq_tokens =*/ (uint32_t) 1,
|
/*.n_seq_tokens =*/ (uint32_t) 1,
|
||||||
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
|
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
|
||||||
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
|
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
|
||||||
|
/*.n_pos =*/ n_pos_per_embd,
|
||||||
/*.token =*/ batch.token,
|
/*.token =*/ batch.token,
|
||||||
/*.embd =*/ batch.embd,
|
/*.embd =*/ batch.embd,
|
||||||
/*.pos =*/ batch.pos,
|
/*.pos =*/ batch.pos,
|
||||||
@@ -251,46 +252,72 @@ bool llama_batch_allocr::init(
|
|||||||
// consistency checks
|
// consistency checks
|
||||||
//
|
//
|
||||||
|
|
||||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
if (n_pos_per_embd > 1) {
|
||||||
if (seq_pos[s].empty()) {
|
// M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
|
||||||
continue;
|
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||||
}
|
if (seq_pos[s].empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
||||||
|
|
||||||
if (p0 >= 0) {
|
|
||||||
bool ok = true;
|
|
||||||
|
|
||||||
if (batch.token) {
|
if (batch.token) {
|
||||||
|
if (p0 >= 0 && p0 >= seq_pos_min(s)) {
|
||||||
|
LLAMA_LOG_ERROR(
|
||||||
|
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||||
|
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||||
|
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||||
|
" for M-RoPE, it is required that the position satisfies: X < Y\n",
|
||||||
|
__func__, s, s, p0, s, seq_pos_min(s));
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// embedding inputs can have overlapping positions
|
||||||
|
if (p0 >= 0 && p0 > seq_pos_min(s)) {
|
||||||
|
LLAMA_LOG_ERROR(
|
||||||
|
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||||
|
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||||
|
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||||
|
" for M-RoPE, it is required that the position satisfies: X <= Y\n",
|
||||||
|
__func__, s, s, p0, s, seq_pos_min(s));
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||||
|
if (seq_pos[s].empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
||||||
|
|
||||||
|
if (p0 >= 0) {
|
||||||
|
bool ok = true;
|
||||||
|
|
||||||
if (seq_pos_min(s) != p0 + 1) {
|
if (seq_pos_min(s) != p0 + 1) {
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
assert(batch.embd);
|
|
||||||
|
|
||||||
// for embeddings (typically used as vision input), we allow them to have repeating positions
|
if (!ok) {
|
||||||
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
|
LLAMA_LOG_ERROR(
|
||||||
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
|
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||||
ok = false;
|
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||||
|
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||||
|
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
||||||
|
__func__, s, s, p0, s, seq_pos_min(s));
|
||||||
|
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ok) {
|
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
||||||
LLAMA_LOG_ERROR(
|
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
|
||||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
|
||||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
|
||||||
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
|
||||||
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
|
||||||
__func__, s, s, p0, s, seq_pos_min(s));
|
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
|
||||||
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (memory) {
|
if (memory) {
|
||||||
@@ -389,6 +416,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
|
|||||||
/*.n_seq_tokens =*/ n_seq_tokens,
|
/*.n_seq_tokens =*/ n_seq_tokens,
|
||||||
/*.n_seqs =*/ n_seqs,
|
/*.n_seqs =*/ n_seqs,
|
||||||
/*.n_seqs_unq =*/ n_seqs,
|
/*.n_seqs_unq =*/ n_seqs,
|
||||||
|
/*.n_pos =*/ n_pos_per_embd,
|
||||||
|
|
||||||
/*.token =*/ udata->token.data(),
|
/*.token =*/ udata->token.data(),
|
||||||
/*.embd =*/ nullptr,
|
/*.embd =*/ nullptr,
|
||||||
@@ -655,10 +683,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|||||||
|
|
||||||
auto udata = std::make_shared<llama_ubatch::data_t>();
|
auto udata = std::make_shared<llama_ubatch::data_t>();
|
||||||
|
|
||||||
const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
|
|
||||||
|
|
||||||
const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
|
const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
|
||||||
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
|
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd;
|
||||||
|
|
||||||
udata->token .resize(n_tokens);
|
udata->token .resize(n_tokens);
|
||||||
udata->embd .resize(n_embd_all);
|
udata->embd .resize(n_embd_all);
|
||||||
@@ -669,6 +695,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|||||||
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
|
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
|
||||||
udata->output .resize(n_tokens);
|
udata->output .resize(n_tokens);
|
||||||
|
|
||||||
|
udata->seq_id_data.reserve(n_tokens);
|
||||||
|
|
||||||
seq_set_t seq_set_unq;
|
seq_set_t seq_set_unq;
|
||||||
|
|
||||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||||
@@ -680,16 +708,23 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|||||||
memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
|
memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int j = 0; j < n_pos_cur; ++j) {
|
for (size_t j = 0; j < (size_t)n_pos_per_embd; ++j) {
|
||||||
udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
|
// if we are using M-RoPE
|
||||||
|
// if the current batch is text, we need to broadcast the same position across all RoPE sections
|
||||||
|
// otherwise, the input batch is image embeddings, we copy the positions as-is
|
||||||
|
// if we are not using M-RoPE, there is only one position per token (this loop runs only once)
|
||||||
|
size_t src_off = batch.token ? 0 : j*batch.n_tokens;
|
||||||
|
udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]];
|
||||||
}
|
}
|
||||||
|
|
||||||
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
|
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
|
||||||
udata->seq_id[i] = batch.seq_id[idxs[i]];
|
|
||||||
udata->output[i] = batch.logits[idxs[i]];
|
udata->output[i] = batch.logits[idxs[i]];
|
||||||
|
|
||||||
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
|
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
|
||||||
seq_set_unq.set(udata->seq_id[i][s]);
|
const llama_seq_id seq_id = batch.seq_id[idxs[i]][s];
|
||||||
|
|
||||||
|
udata->seq_id_data.push_back(seq_id);
|
||||||
|
seq_set_unq.set(seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (udata->output[i]) {
|
if (udata->output[i]) {
|
||||||
@@ -697,6 +732,12 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_seq_id * seq_id_ptr = udata->seq_id_data.data();
|
||||||
|
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||||
|
udata->seq_id[i] = seq_id_ptr;
|
||||||
|
seq_id_ptr += udata->n_seq_id[i];
|
||||||
|
}
|
||||||
|
|
||||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||||
if (seq_set_unq.test(s)) {
|
if (seq_set_unq.test(s)) {
|
||||||
udata->seq_idx[s] = udata->seq_id_unq.size();
|
udata->seq_idx[s] = udata->seq_id_unq.size();
|
||||||
@@ -710,6 +751,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|||||||
/*.n_seq_tokens =*/ n_tokens/n_seqs,
|
/*.n_seq_tokens =*/ n_tokens/n_seqs,
|
||||||
/*.n_seqs =*/ n_seqs,
|
/*.n_seqs =*/ n_seqs,
|
||||||
/*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
|
/*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
|
||||||
|
/*.n_pos =*/ n_pos_per_embd,
|
||||||
|
|
||||||
/*.token =*/ batch.token ? udata->token.data() : nullptr,
|
/*.token =*/ batch.token ? udata->token.data() : nullptr,
|
||||||
/*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
|
/*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
|
||||||
|
|||||||
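Note on the change above: ubatch_add now copies every per-token sequence id into the owned seq_id_data vector first, and only afterwards points seq_id[i] into it, instead of borrowing pointers from the caller's batch. The following standalone C++ sketch is illustrative only (packed_seq_ids and pack are invented names, not part of llama.cpp); it shows the same two-pass pattern and why the pointer fix-up has to wait until the flat vector has stopped growing, since a reallocation during the first pass would invalidate earlier pointers.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    using seq_id_t = int32_t;

    struct packed_seq_ids {
        std::vector<int32_t>    n_seq_id;     // [n_tokens] number of ids per token
        std::vector<seq_id_t *> seq_id;       // [n_tokens] pointers into seq_id_data
        std::vector<seq_id_t>   seq_id_data;  // flat storage owned by this struct
    };

    static packed_seq_ids pack(const std::vector<std::vector<seq_id_t>> & per_token) {
        packed_seq_ids out;
        out.n_seq_id.reserve(per_token.size());
        out.seq_id.resize(per_token.size(), nullptr);

        // pass 1: append all ids into the flat buffer (it may reallocate while growing)
        for (const auto & ids : per_token) {
            out.n_seq_id.push_back((int32_t) ids.size());
            out.seq_id_data.insert(out.seq_id_data.end(), ids.begin(), ids.end());
        }

        // pass 2: now that seq_id_data is stable, point each token at its slice
        seq_id_t * ptr = out.seq_id_data.data();
        for (size_t i = 0; i < per_token.size(); ++i) {
            out.seq_id[i] = ptr;
            ptr += out.n_seq_id[i];
        }
        return out;
    }

    int main() {
        const auto packed = pack({{0}, {0, 1}, {1}});
        for (size_t i = 0; i < packed.seq_id.size(); ++i) {
            std::printf("token %zu:", i);
            for (int32_t s = 0; s < packed.n_seq_id[i]; ++s) {
                std::printf(" %d", packed.seq_id[i][s]);
            }
            std::printf("\n");
        }
        return 0;
    }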

19  llama/llama.cpp/src/llama-batch.h  vendored
@@ -17,6 +17,16 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }

+    // typical for M-RoPE cases:
+    // 0 - sequantial position of the tokens/embeddings in the sequence
+    // 1 - y position in the image
+    // 2 - x position in the image
+    // 3 - other
+    bool is_pos_2d() const {
+        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
+        return n_pos >= 3;
+    }
+
     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
                            //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
@@ -25,6 +35,7 @@ struct llama_ubatch {
     uint32_t n_seq_tokens; // tokens per sequence set
     uint32_t n_seqs;       // sequence sets in the ubatch
     uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+    uint32_t n_pos;        // number of position inputs for each token/embedding

     // seq_id_unq: unique sequence ids in the ubatch
     // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -33,7 +44,7 @@ struct llama_ubatch {
     //                            // size               | idx | val
     llama_token  *  token;        // [n_tokens]         | i   | id, token
     float        *  embd;         // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;          // [n_tokens]         | i   | pos
+    llama_pos    *  pos;          // [n_tokens*n_pos]   | i   | pos
     int32_t      *  n_seq_id;     // [n_tokens]         | i   | -
     llama_seq_id ** seq_id;       // [n_tokens]         | s   | s0, s1, seq_id
     llama_seq_id *  seq_id_unq;   // [n_seqs_unq]       | s   | seq_id
@@ -45,13 +56,15 @@ struct llama_ubatch {
         std::vector<float>          embd;
         std::vector<llama_pos>      pos;
         std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id *> seq_id; // these point into the seq_id_data below
         std::vector<llama_seq_id>   seq_id_unq;
         std::vector<int32_t>        seq_idx;
         std::vector<int8_t>         output;
+
+        std::vector<llama_seq_id>   seq_id_data;
     };

-    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    // the llama_ubatch pointers above point to this data if set. otherwise - point to external non-owning data
     std::shared_ptr<data_t> data;
 };

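The new n_pos field above implies that the pos buffer is laid out section-major, [n_tokens*n_pos], with section j of token i stored at pos[j*n_tokens + i]. Below is a minimal standalone sketch of that indexing, assuming n_pos = 4 (sequential, y, x, other) as in the comment; the token indices and position values are made up for illustration and this is not the vendored code.

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_tokens = 4;
        const int n_pos    = 4; // 0: sequential, 1: y, 2: x, 3: other

        std::vector<int> pos(n_tokens * n_pos, 0);

        // text token at index 0, sequence position 7: broadcast across all sections
        for (int j = 0; j < n_pos; ++j) {
            pos[j*n_tokens + 0] = 7;
        }

        // image patch at index 1: sequential position 8, located at (y=2, x=5)
        pos[0*n_tokens + 1] = 8;
        pos[1*n_tokens + 1] = 2;
        pos[2*n_tokens + 1] = 5;
        pos[3*n_tokens + 1] = 0;

        for (int i = 0; i < 2; ++i) {
            std::printf("token %d:", i);
            for (int j = 0; j < n_pos; ++j) {
                std::printf(" %d", pos[j*n_tokens + i]);
            }
            std::printf("\n");
        }
        return 0;
    }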

32  llama/llama.cpp/src/llama-chat.cpp  vendored
@@ -73,6 +73,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "kimi-k2",            LLM_CHAT_TEMPLATE_KIMI_K2     },
     { "seed_oss",           LLM_CHAT_TEMPLATE_SEED_OSS    },
     { "grok-2",             LLM_CHAT_TEMPLATE_GROK_2      },
+    { "pangu-embedded",     LLM_CHAT_TEMPLATE_PANGU_EMBED },
 };

 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -213,6 +214,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_SEED_OSS;
     } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
         return LLM_CHAT_TEMPLATE_GROK_2;
+    } else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) {
+        return LLM_CHAT_TEMPLATE_PANGU_EMBED;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -813,6 +816,35 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "Assistant:";
         }
+    }else if (tmpl == LLM_CHAT_TEMPLATE_PANGU_EMBED) {
+        // [unused9]系统:xxx[unused10]
+        // [unused9]用户:xxx[unused10]
+        // [unused9]助手:xxx[unused10]
+        // ...
+        for (size_t i = 0; i < chat.size(); ++i) {
+            const auto & msg = chat[i];
+            const std::string & role = msg->role;
+            const std::string & content = msg->content;
+
+            if (i == 0 && role != "system") {
+                ss << "[unused9]系统:[unused10]";
+            }
+
+            if (role == "system") {
+                ss << "[unused9]系统:" << content << "[unused10]";
+            } else if (role == "user") {
+                ss << "[unused9]用户:" << content << "[unused10]";
+            } else if (role == "assistant") {
+                ss << "[unused9]助手:" << content << "[unused10]";
+            } else if (role == "tool") {
+                ss << "[unused9]工具:" << content << "[unused10]";
+            } else if (role == "function") {
+                ss << "[unused9]方法:" << content << "[unused10]";
+            }
+        }
+        if (add_ass) {
+            ss << "[unused9]助手:";
+        }
     } else {
         // template not supported
         return -1;
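For reference, the branch added above renders each message as [unused9]<role>:<content>[unused10] and appends an assistant header when add_ass is set. Here is a hedged standalone sketch of the same formatting (render_pangu and msg are illustrative names, not the vendored API), with the expected output for a single user message shown in a comment.

    #include <cstdio>
    #include <string>
    #include <vector>

    struct msg { std::string role, content; };

    static std::string render_pangu(const std::vector<msg> & chat, bool add_ass) {
        std::string ss;
        for (size_t i = 0; i < chat.size(); ++i) {
            const auto & m = chat[i];
            // a default system header is emitted when the chat does not start with one
            if (i == 0 && m.role != "system") {
                ss += "[unused9]系统:[unused10]";
            }
            if (m.role == "system") {
                ss += "[unused9]系统:" + m.content + "[unused10]";
            } else if (m.role == "user") {
                ss += "[unused9]用户:" + m.content + "[unused10]";
            } else if (m.role == "assistant") {
                ss += "[unused9]助手:" + m.content + "[unused10]";
            }
        }
        if (add_ass) {
            ss += "[unused9]助手:";
        }
        return ss;
    }

    int main() {
        const std::string prompt = render_pangu({{"user", "hello"}}, /*add_ass=*/true);
        // expected: [unused9]系统:[unused10][unused9]用户:hello[unused10][unused9]助手:
        std::printf("%s\n", prompt.c_str());
        return 0;
    }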

1  llama/llama.cpp/src/llama-chat.h  vendored
@@ -53,6 +53,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_SEED_OSS,
     LLM_CHAT_TEMPLATE_GROK_2,
+    LLM_CHAT_TEMPLATE_PANGU_EMBED,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };


194  llama/llama.cpp/src/llama-context.cpp  vendored
@@ -1,5 +1,6 @@
 #include "llama-context.h"

+#include "llama-arch.h"
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-io.h"
@@ -8,6 +9,7 @@
 #include "llama-model.h"

 #include <cinttypes>
+#include <cmath>
 #include <cstring>
 #include <limits>
 #include <stdexcept>
@@ -21,6 +23,8 @@ llama_context::llama_context(
         llama_context_params params) :
     model(model),
     balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
+    // TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
+    //      may need to be backend-dependent
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

     t_start_us = model.t_start_us;
@@ -69,6 +73,43 @@ llama_context::llama_context(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }

+    if (cparams.yarn_ext_factor != 0) {
+        static auto get_mscale = [](float scale, float mscale) {
+            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+        };
+
+        const float factor = 1.0f / cparams.rope_freq_scale;
+
+        // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+        if (hparams.rope_yarn_log_mul != 0.0f) {
+            // note: here we assume `mscale == 1.0f`
+            // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
+            float mscale = 1.0f;
+            const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+            // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+            // special-case DEEPSEEK v2:
+            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+            if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+                mscale = mscale_all_dims;
+            }
+
+            cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+            LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+                    __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
+        } else {
+            cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
+        }
+
+        // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
+        // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+        //
+        // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+        //      https://github.com/ggml-org/llama.cpp/pull/17945
+        cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
+    }
+
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;

     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -90,14 +131,6 @@ llama_context::llama_context(
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

-    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
-    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
-    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
-        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
-        cparams.n_batch = GGML_KQ_MASK_PAD;
-    }
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

     cparams.op_offload = params.op_offload;
@@ -112,11 +145,28 @@ llama_context::llama_context(
         }
     }

-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+
+    if (cparams.kv_unified) {
+        cparams.n_ctx_seq = cparams.n_ctx;
+    } else {
+        cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+        cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
+
+        if (cparams.n_ctx_seq == 0) {
+            throw std::runtime_error("n_ctx_seq == 0");
+        }
+
+        if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
+            cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
+            LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
+        }
+    }

     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
     LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -125,14 +175,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

-    if (n_ctx_per_seq < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }

-    if (n_ctx_per_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }

     if (!hparams.vocab_only) {
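A small worked example of the n_ctx_seq derivation added above, using made-up numbers: n_ctx is padded up to a multiple of 256, split across sequences, padded again, and then n_ctx is adjusted so that n_ctx == n_ctx_seq * n_seq_max holds. This is a standalone sketch (pad_to mirrors what GGML_PAD does), not the vendored code.

    #include <cstdint>
    #include <cstdio>

    static uint32_t pad_to(uint32_t x, uint32_t mult) {
        return ((x + mult - 1) / mult) * mult;
    }

    int main() {
        uint32_t n_ctx           = 10000; // assumed requested context
        const uint32_t n_seq_max = 3;
        const bool kv_unified    = false;

        n_ctx = pad_to(n_ctx, 256); // 10240

        uint32_t n_ctx_seq = 0;
        if (kv_unified) {
            n_ctx_seq = n_ctx;
        } else {
            n_ctx_seq = pad_to(n_ctx / n_seq_max, 256); // 3413 -> 3584
            if (n_ctx != n_ctx_seq * n_seq_max) {
                n_ctx = n_ctx_seq * n_seq_max;          // 10752
            }
        }

        std::printf("n_ctx = %u, n_ctx_seq = %u\n", n_ctx, n_ctx_seq);
        return 0;
    }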
@@ -208,6 +258,7 @@ llama_context::llama_context(

         backend_buft.clear();
         backend_ptrs.clear();
+        backend_buf_exp_size.clear();

         for (auto & backend : backends) {
             auto * buft = ggml_backend_get_default_buffer_type(backend.get());
@@ -224,11 +275,15 @@ llama_context::llama_context(

             backend_buft.push_back(buft);
             backend_ptrs.push_back(backend.get());
+            backend_buf_exp_size.push_back(0);
         }

         LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());

-        const size_t max_nodes = this->graph_max_nodes();
+        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        const size_t max_nodes = this->graph_max_nodes(n_tokens);

         LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);

@@ -268,9 +323,7 @@ llama_context::llama_context(
         if (pipeline_parallel) {
             LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
         }
-    }

-    if (!hparams.vocab_only) {
         llama_memory_context_ptr mctx;
         if (memory) {
             LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
@@ -282,9 +335,6 @@ llama_context::llama_context(

         cross.v_embd.clear();

-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
         // avoid reserving graphs with zero outputs - assume one output per sequence
         n_outputs = n_seqs;

@@ -341,9 +391,17 @@ llama_context::llama_context(

         // reserve pp (prompt processing) graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
+                model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
             if (!gf) {
-                throw std::runtime_error("failed to allocate compute pp buffers");
+                if (pipeline_parallel) {
+                    LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
+                    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
+                    gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+                }
+                if (!gf) {
+                    throw std::runtime_error("failed to allocate compute pp buffers");
+                }
             }

             n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
@@ -352,7 +410,7 @@ llama_context::llama_context(

         // reserve with tg (token generation) graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
+            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -367,7 +425,7 @@ llama_context::llama_context(
             //
             // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
             //
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -376,11 +434,13 @@ llama_context::llama_context(
         for (size_t i = 0; i < backend_ptrs.size(); ++i) {
             ggml_backend_t backend = backend_ptrs[i];
             ggml_backend_buffer_type_t buft = backend_buft[i];
-            size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
-            if (size > 1) {
+            if (!model.hparams.no_alloc) {
+                backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+            }
+            if (backend_buf_exp_size[i] > 1) {
                 LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
                         ggml_backend_buft_name(buft),
-                        size / 1024.0 / 1024.0);
+                        backend_buf_exp_size[i] / 1024.0 / 1024.0);
             }
         }

@@ -399,6 +459,23 @@ llama_context::llama_context(
 }

 llama_context::~llama_context() {
+    // FIXME this currently results in a use-after-free bug if the model is freed before the context
+    // if (!model.hparams.no_alloc) {
+    //     for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+    //         ggml_backend_t backend = backend_ptrs[i];
+    //         ggml_backend_buffer_type_t buft = backend_buft[i];
+
+    //         const size_t size_exp = backend_buf_exp_size[i];
+    //         const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+    //         if (size_exp == size_act) {
+    //             LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n",
+    //                 __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
+    //         } else {
+    //             LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n",
+    //                 __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0));
+    //         }
+    //     }
+    // }
     ggml_opt_free(opt_ctx);
 }

@@ -448,8 +525,8 @@ uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }

-uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.n_ctx / cparams.n_seq_max;
+uint32_t llama_context::n_ctx_seq() const {
+    return cparams.n_ctx_seq;
 }

 uint32_t llama_context::n_batch() const {
@@ -518,7 +595,7 @@ bool llama_context::memory_update(bool optimize) {
         throw std::runtime_error("failed to initialize memory context");
     }

-    const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+    const uint32_t n_seqs = cparams.n_seq_max;
     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -803,7 +880,7 @@ int llama_context::encode(const llama_batch & batch_inp) {

     const auto & hparams = model.hparams;

-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd = hparams.n_embd_inp();
     const int64_t n_vocab = model.vocab.n_tokens();

     // note: during encode, we always pass the full sequence starting from pos = 0
@@ -972,7 +1049,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;

     const int64_t n_vocab = vocab.n_tokens();
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd = hparams.n_embd_inp();

     const bool output_all = false;

@@ -1223,7 +1300,7 @@ int llama_context::decode(const llama_batch & batch_inp) {

     // make the outputs have the same order they had in the user-provided batch
     // note: this is mostly relevant for recurrent models atm
-    if (!sorted_output) {
+    if (!sorted_output && n_outputs > 1) {
         GGML_ASSERT((size_t) n_outputs == out_ids.size());

         // TODO: is there something more efficient which also minimizes swaps?
@@ -1300,6 +1377,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
     LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
+    synchronize();
     buf_output = nullptr;
     logits = nullptr;
     embd = nullptr;
@@ -1360,7 +1438,10 @@ void llama_context::output_reorder() {
 // graph
 //

-uint32_t llama_context::graph_max_nodes() const {
+uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
+    if (model.arch == LLM_ARCH_QWEN3NEXT) {
+        return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
+    }
     return std::max<uint32_t>(1024u, 8u*model.n_tensors());
 }

@@ -1368,7 +1449,8 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
     return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }

-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
+ggml_cgraph * llama_context::graph_reserve(
+        uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
     GGML_ASSERT(n_outputs >= 1);

@@ -1405,8 +1487,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u

     // initialize scheduler with the specified graph
     if (split_only) {
-        ggml_backend_sched_split_graph(sched.get(), gf);
+        if (sizes) {
+            ggml_backend_sched_reserve_size(sched.get(), gf, sizes);
+        } else {
+            ggml_backend_sched_split_graph(sched.get(), gf);
+        }
     } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+        GGML_ASSERT(!sizes);
         LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         return nullptr;
     }
@@ -2028,15 +2115,26 @@ void llama_context::perf_reset() {

 std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
-    for (const auto & buft_size : model.memory_breakdown()) {
-        ret[buft_size.first].model += buft_size.second;
+    for (const auto & [buft, size] : model.memory_breakdown()) {
+        ret[buft].model += size;
     }
-    for (const auto & buft_size : memory->memory_breakdown()) {
-        ret[buft_size.first].context += buft_size.second;
+    if (memory) {
+        for (const auto & [buft, size] : memory->memory_breakdown()) {
+            ret[buft].context += size;
+        }
     }
-    for (const auto & backend_ptr : backends) {
-        ggml_backend_t backend = backend_ptr.get();
-        ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
+    if (model.hparams.no_alloc) {
+        for (size_t i = 0; i < backends.size(); ++i) {
+            ggml_backend_t backend = backends[i].get();
+            ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
+            ret[buft].compute += backend_buf_exp_size[i];
+        }
+    } else {
+        for (const auto & backend_ptr : backends) {
+            ggml_backend_t backend = backend_ptr.get();
+            ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend);
+            ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
+        }
     }
     return ret;
 }
@@ -2129,7 +2227,7 @@ void llama_context::opt_epoch_iter(
         batch.logits [pos_batch] = true;
     }

-    if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
+    if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return;
     }
@@ -2377,6 +2475,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }

+uint32_t llama_n_ctx_seq(const llama_context * ctx) {
+    return ctx->n_ctx_seq();
+}
+
 uint32_t llama_n_batch(const llama_context * ctx) {
     return ctx->n_batch();
 }
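To make the yarn_attn_factor change above concrete, here is a standalone numeric sketch of the same formula with assumed sample values (rope_freq_scale = 0.25, rope_yarn_log_mul = 0.1); it is not the vendored code, and the chosen constants are only an assumption for illustration.

    #include <cmath>
    #include <cstdio>

    // mscale(s, m) = 0.1*m*ln(s) + 1 for s > 1, otherwise 1
    static float get_mscale(float scale, float mscale) {
        return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
    }

    int main() {
        const float rope_freq_scale = 0.25f; // assumed: 4x context extension
        const float mscale          = 1.0f;  // assumed
        const float mscale_all_dims = 0.1f;  // assumed rope_yarn_log_mul

        const float factor = 1.0f / rope_freq_scale;

        // ratio of the two mscale corrections, then cancel the 1 + 0.1*ln(factor)
        // term that the rope op applies on its own when yarn_ext_factor != 0
        float yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
        yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));

        std::printf("factor = %.3f, yarn_attn_factor = %.4f\n", factor, yarn_attn_factor);
        return 0;
    }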

22  llama/llama.cpp/src/llama-context.h  vendored
@@ -26,6 +26,10 @@ struct llama_memory_breakdown_data {
     size_t model = 0;   // memory allocated for the model
     size_t context = 0; // memory allocated for the context
     size_t compute = 0; // memory allocated for temporary compute buffers
+
+    size_t total() const {
+        return model + context + compute;
+    }
 };

 struct llama_context {
@@ -43,11 +47,11 @@ struct llama_context {

     ggml_backend_sched_t get_sched() const;

     uint32_t n_ctx() const;
-    uint32_t n_ctx_per_seq() const;
+    uint32_t n_ctx_seq() const;
     uint32_t n_batch() const;
     uint32_t n_ubatch() const;
     uint32_t n_seq_max() const;

     uint32_t n_threads() const;
     uint32_t n_threads_batch() const;
@@ -197,7 +201,7 @@ private:
     //

 public:
-    uint32_t graph_max_nodes() const;
+    uint32_t graph_max_nodes(uint32_t n_tokens) const;

     // can reuse the llm_graph_result instance of the context (for example to update a memory module)
     llm_graph_result * get_gf_res_reserve() const;
@@ -206,7 +210,8 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);

     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
+    ggml_cgraph * graph_reserve(
+            uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);

 private:
     llm_graph_params graph_params(
@@ -281,9 +286,10 @@ private:

     std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;

-    // buffer types used for the compute buffer of each backend
+    // pointers and buffer types used for the compute buffer of each backend
     std::vector<ggml_backend_t> backend_ptrs;
     std::vector<ggml_backend_buffer_type_t> backend_buft;
+    std::vector<size_t> backend_buf_exp_size; // expected buffer sizes

     llm_graph_result_ptr gf_res_prev;
     llm_graph_result_ptr gf_res_reserve;

1  llama/llama.cpp/src/llama-cparams.h  vendored
@@ -8,6 +8,7 @@

 struct llama_cparams {
     uint32_t n_ctx;     // context size used during inference
+    uint32_t n_ctx_seq; // context for a single sequence
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;

291  llama/llama.cpp/src/llama-grammar.cpp  vendored
@@ -6,8 +6,10 @@

 #include <cmath>
 #include <algorithm>
+#include <cstdint>
 #include <stdexcept>

+#define MAX_REPETITION_THRESHOLD 2000
 //
 // helpers
 //
@@ -179,6 +181,52 @@ static std::pair<uint32_t, const char *> parse_char(const char * src) {
     throw std::runtime_error("unexpected end of input");
 }

+static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
+    const char * pos = src;
+    if (*pos != '<') {
+        throw std::runtime_error(std::string("expecting '<' at ") + pos);
+    }
+    pos++;
+
+    // Parse <[id]>
+    if (*pos == '[') {
+        pos++;
+        const char * int_end = parse_int(pos);
+        uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
+        pos = int_end;
+        if (*pos != ']') {
+            throw std::runtime_error(std::string("expecting ']' at ") + pos);
+        }
+        pos++;
+        if (*pos != '>') {
+            throw std::runtime_error(std::string("expecting '>' at ") + pos);
+        }
+        pos++;
+        return std::make_pair(token_id, pos);
+    }
+
+    if (vocab == nullptr) {
+        throw std::runtime_error(std::string("no vocab to parse token at ") + src);
+    }
+
+    // Parse <token> and tokenize to obtain the token id
+    while (*pos != 0 && *pos != '>') {
+        pos++;
+    }
+    if (*pos != '>') {
+        throw std::runtime_error(std::string("expecting '>' at ") + pos);
+    }
+    pos++;
+
+    llama_token tokens[2];
+    int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
+    if (n_tokens != 1) {
+        // must tokenize to exactly 1 token
+        throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
+    }
+    return std::make_pair(tokens[0], pos);
+}
+
 static void print_grammar_char(FILE * file, uint32_t c) {
     if (0x20 <= c && c <= 0x7f) {
         fprintf(file, "%c", static_cast<char>(c));
@@ -210,6 +258,8 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
             case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
             case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
             case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
+            case LLAMA_GRETYPE_TOKEN:          fprintf(file, "TOKEN");          break;
+            case LLAMA_GRETYPE_TOKEN_NOT:      fprintf(file, "TOKEN_NOT");      break;
         }
         switch (elem.type) {
             case LLAMA_GRETYPE_END:
@@ -226,6 +276,17 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
                 print_grammar_char(file, elem.value);
                 fprintf(file, "\") ");
                 break;
+            case LLAMA_GRETYPE_TOKEN:
+                fprintf(file, "<[");
+                fprintf(file, "%u", elem.value);
+                fprintf(file, "]> ");
+                break;
+            case LLAMA_GRETYPE_TOKEN_NOT:
+                fprintf(file, "!");
+                fprintf(file, "<[");
+                fprintf(file, "%u", elem.value);
+                fprintf(file, "]> ");
+                break;
         }
     }
     fprintf(file, "\n");
@@ -282,6 +343,17 @@ static void print_rule(
             case LLAMA_GRETYPE_CHAR_ANY:
                 fprintf(file, ".");
                 break;
+            case LLAMA_GRETYPE_TOKEN:
+                fprintf(file, "<[");
+                fprintf(file, "%u", elem.value);
+                fprintf(file, "]> ");
+                break;
+            case LLAMA_GRETYPE_TOKEN_NOT:
+                fprintf(file, "!");
+                fprintf(file, "<[");
+                fprintf(file, "%u", elem.value);
+                fprintf(file, "]> ");
+                break;
         }
         if (is_char_element(elem)) {
             switch (rule[i + 1].type) {
@@ -345,8 +417,10 @@ const char * llama_grammar_parser::parse_sequence(
     size_t last_sym_start = rule.size();
     const char * pos = src;

-    auto handle_repetitions = [&](int min_times, int max_times) {
+    // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used
+    // (though it's technically the same as -1 now)
+    auto handle_repetitions = [&](uint64_t min_times, uint64_t max_times) {
+        bool no_max = max_times == UINT64_MAX;
         if (last_sym_start == rule.size()) {
             throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
         }
@@ -373,20 +447,20 @@ const char * llama_grammar_parser::parse_sequence(
             rule.resize(last_sym_start);
         } else {
             // Repeat the previous elements (min_times - 1) times
-            for (int i = 1; i < min_times; i++) {
+            for (uint64_t i = 1; i < min_times; i++) {
                 rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
             }
         }

         uint32_t last_rec_rule_id = 0;
-        auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+        auto n_opt = no_max ? 1 : max_times - min_times;

         llama_grammar_rule rec_rule(prev_rule);
-        for (int i = 0; i < n_opt; i++) {
+        for (uint64_t i = 0; i < n_opt; i++) {
             rec_rule.resize(prev_rule.size());
             uint32_t rec_rule_id = generate_symbol_id( rule_name);
-            if (i > 0 || max_times < 0) {
-                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
+            if (i > 0 || no_max) {
+                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, no_max ? rec_rule_id : last_rec_rule_id});
             }
             rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
             rec_rule.push_back({LLAMA_GRETYPE_END, 0});
@@ -440,6 +514,17 @@ const char * llama_grammar_parser::parse_sequence(
                 }
             }
             pos = parse_space(pos + 1, is_nested);
+        } else if (*pos == '<' || *pos == '!') { // token
+            auto type = LLAMA_GRETYPE_TOKEN;
+            if (*pos == '!') { // token inverse
+                type = LLAMA_GRETYPE_TOKEN_NOT;
+                pos++;
+            }
+            auto token_pair = parse_token(vocab, pos);
+            const char * token_end = token_pair.second;
+            last_sym_start = rule.size();
+            rule.push_back({type, token_pair.first});
+            pos = parse_space(token_end, is_nested);
         } else if (is_word_char(*pos)) { // rule reference
             const char * name_end = parse_name(pos);
             uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
@@ -478,10 +563,10 @@ const char * llama_grammar_parser::parse_sequence(
                 throw std::runtime_error(std::string("expecting an int at ") + pos);
             }
             const char * int_end = parse_int(pos);
-            int min_times = std::stoul(std::string(pos, int_end - pos));
+            uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
             pos = parse_space(int_end, is_nested);

-            int max_times = -1;
+            uint64_t max_times = UINT64_MAX; // default: no max limit

             if (*pos == '}') {
                 max_times = min_times;
@@ -502,6 +587,10 @@ const char * llama_grammar_parser::parse_sequence(
             } else {
                 throw std::runtime_error(std::string("expecting ',' at ") + pos);
             }
+            bool has_max = max_times != UINT64_MAX;
+            if (min_times > MAX_REPETITION_THRESHOLD || (has_max && max_times > MAX_REPETITION_THRESHOLD)) {
+                throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
+            }
             handle_repetitions(min_times, max_times);
         } else {
             break;
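Based on the parse_token and repetition changes above, a grammar can now reference raw token ids directly. The snippet below is a hypothetical GBNF fragment wrapped in a trivial C++ program (the token ids are made up, and the exact surface syntax is inferred from the parser code rather than from upstream documentation): <[id]> matches exactly that token, !<[id]> matches any single token except it, and {m,n} bounds above MAX_REPETITION_THRESHOLD (2000) are rejected at parse time.

    #include <cstdio>

    int main() {
        // hypothetical grammar string; in practice it would be handed to the
        // grammar-based sampler (e.g. llama_sampler_init_grammar) together with a vocab
        const char * grammar =
            "root   ::= <[128006]> header <[128007]> body{1,64}\n"
            "header ::= [a-z]+\n"
            "body   ::= !<[128009]>\n";
        std::printf("%s", grammar);
        return 0;
    }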
@@ -683,6 +772,21 @@ static bool llama_grammar_match_partial_char(
|
|||||||
return !is_positive_char;
|
return !is_positive_char;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns true iff token matches the rule at pos (regular or inverse)
|
||||||
|
// asserts that pos is pointing to a token element
|
||||||
|
static bool llama_grammar_match_token(
|
||||||
|
const llama_grammar_element * pos,
|
||||||
|
const llama_token token) {
|
||||||
|
GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
|
||||||
|
if (pos->type == LLAMA_GRETYPE_TOKEN) {
|
||||||
|
return pos->value == static_cast<uint32_t>(token);
|
||||||
|
}
|
||||||
|
if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||||
|
return pos->value != static_cast<uint32_t>(token);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// transforms a grammar pushdown stack into N possible stacks, all ending
|
// transforms a grammar pushdown stack into N possible stacks, all ending
|
||||||
// at a character range (terminal element)
|
// at a character range (terminal element)
|
||||||
static void llama_grammar_advance_stack(
|
static void llama_grammar_advance_stack(
|
||||||
@@ -730,6 +834,8 @@ static void llama_grammar_advance_stack(
|
|||||||
case LLAMA_GRETYPE_CHAR:
|
case LLAMA_GRETYPE_CHAR:
|
||||||
case LLAMA_GRETYPE_CHAR_NOT:
|
case LLAMA_GRETYPE_CHAR_NOT:
|
||||||
case LLAMA_GRETYPE_CHAR_ANY:
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
|
case LLAMA_GRETYPE_TOKEN:
|
||||||
|
case LLAMA_GRETYPE_TOKEN_NOT:
|
||||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||||
// only add the stack if it's not a duplicate of one we already have
|
// only add the stack if it's not a duplicate of one we already have
|
||||||
new_stacks.emplace_back(stack);
|
new_stacks.emplace_back(stack);
|
||||||
@@ -823,26 +929,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
|
|||||||
return grammar->stacks;
|
return grammar->stacks;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void llama_grammar_accept_chr(
|
||||||
|
struct llama_grammar & grammar,
|
||||||
|
const llama_grammar_stack & stack,
|
||||||
|
uint32_t chr,
|
||||||
|
llama_grammar_stacks & new_stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_grammar_element * pos = stack.back();
|
||||||
|
|
||||||
|
// ignore if this turns into a token
|
||||||
|
if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto match = llama_grammar_match_char(pos, chr);
|
||||||
|
if (match.first) {
|
||||||
|
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||||
|
if (!llama_grammar_is_end_of_sequence(match.second)) {
|
||||||
|
new_stack.push_back(match.second);
|
||||||
|
}
|
||||||
|
llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
|
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
|
||||||
llama_grammar_stacks stacks_new;
|
llama_grammar_stacks stacks_new;
|
||||||
stacks_new.reserve(grammar->stacks.size());
|
stacks_new.reserve(grammar->stacks.size());
|
||||||
|
|
||||||
for (const auto & stack : grammar->stacks) {
|
for (const auto & stack : grammar->stacks) {
|
||||||
if (stack.empty()) {
|
llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto match = llama_grammar_match_char(stack.back(), chr);
|
|
||||||
if (match.first) {
|
|
||||||
const llama_grammar_element * pos = match.second;
|
|
||||||
|
|
||||||
// update top of stack to next element, if any
|
|
||||||
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
|
||||||
if (!llama_grammar_is_end_of_sequence(pos)) {
|
|
||||||
new_stack.push_back(pos);
|
|
||||||
}
|
|
||||||
llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
grammar->stacks = std::move(stacks_new);
|
grammar->stacks = std::move(stacks_new);
|
||||||
@@ -867,6 +985,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|||||||
|
|
||||||
const llama_grammar_element * stack_pos = stack.back();
|
const llama_grammar_element * stack_pos = stack.back();
|
||||||
|
|
||||||
|
// if the top of the stack is a token rule, then we only need to check the token id
|
||||||
|
if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
|
||||||
|
for (const auto & tok : candidates) {
|
||||||
|
if (*tok.code_points == 0) {
|
||||||
|
// reached the end of a token consumed by char rules, reject iff it ended
|
||||||
|
// in a partial response
|
||||||
|
if (tok.partial_utf8.n_remain != 0) {
|
||||||
|
rejects.push_back(tok);
|
||||||
|
}
|
||||||
|
} else if (!llama_grammar_match_token(stack_pos, tok.id)) {
|
||||||
|
rejects.push_back(tok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rejects;
|
||||||
|
}
|
||||||
|
|
||||||
llama_grammar_candidates next_candidates;
|
llama_grammar_candidates next_candidates;
|
||||||
next_candidates.reserve(candidates.size());
|
next_candidates.reserve(candidates.size());
|
||||||
|
|
||||||
@@ -879,7 +1013,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|||||||
rejects.push_back(tok);
|
rejects.push_back(tok);
|
||||||
}
|
}
|
||||||
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
|
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
|
||||||
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
|
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
|
||||||
} else {
|
} else {
|
||||||
rejects.push_back(tok);
|
rejects.push_back(tok);
|
||||||
}
|
}
|
||||||
@@ -897,7 +1031,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|||||||
|
|
||||||
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
||||||
for (const auto & tok : next_rejects) {
|
for (const auto & tok : next_rejects) {
|
||||||
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
|
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
|
||||||
}
|
}
|
||||||
|
|
||||||
return rejects;
|
return rejects;
|
||||||
@@ -966,12 +1100,13 @@ struct llama_grammar * llama_grammar_init_impl(
        ollama_vocab,
        std::move(vec_rules),
        std::move(stacks),
        /* .partial_utf8 = */ {},
-        /* .lazy =*/ false,
+        /* .lazy = */ false,
        /* .awaiting_trigger = */ false,
        /* .trigger_buffer = */ "",
+        /* .trigger_buffer_positions = */ {},
        /* .trigger_tokens = */ {},
        /* .trigger_patterns = */ {},
    };
}

@@ -985,7 +1120,7 @@ struct llama_grammar * llama_grammar_init_impl(
    size_t num_trigger_patterns,
    const llama_token * trigger_tokens,
    size_t num_trigger_tokens) {
-    llama_grammar_parser parser;
+    llama_grammar_parser parser(vocab);

    // if there is a grammar, parse it
    // rules will be empty (default) if there are parse errors
@@ -1073,10 +1208,11 @@ struct llama_grammar * llama_grammar_init_impl(
        ollama_vocab,
        std::move(vec_rules),
        std::move(stacks),
        /* .partial_utf8 = */ {},
        /* .lazy = */ lazy,
        /* .awaiting_trigger = */ lazy,
        /* .trigger_buffer = */ "",
+        /* .trigger_buffer_positions = */ {},
        std::move(vec_trigger_tokens),
        std::move(vec_trigger_patterns),
    };
@@ -1100,6 +1236,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
        grammar.lazy,
        grammar.awaiting_trigger,
        grammar.trigger_buffer,
+        grammar.trigger_buffer_positions,
        grammar.trigger_tokens,
        grammar.trigger_patterns,
    };
@@ -1156,7 +1293,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
            cur_p->data[i].logit = -INFINITY;
        } else {
            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
-            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
        }
    }

@@ -1176,10 +1313,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
        if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
            grammar.awaiting_trigger = false;
            grammar.trigger_buffer.clear();
-            llama_grammar_accept_str(grammar, piece);
+            llama_grammar_accept_token(grammar, token, piece);
            LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
            return;
        } else {
+            auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
+            grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
            grammar.trigger_buffer += piece;

            std::smatch match;
@@ -1197,10 +1336,23 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
                if (start == std::string::npos) {
                    start = match.position(0);
                }

+                // replay tokens that overlap with [start, end)
+                for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
+                    auto [tok_start, tok_end] = tok_pos;
+                    if (tok_end <= start) {
+                        continue;
+                    }
+
+                    size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
+                    size_t piece_len = tok_end - piece_start;
+                    auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
+                    llama_grammar_accept_token(grammar, tok, tok_piece);
+                }
+
                auto constrained_str = grammar.trigger_buffer.substr(start);
-                // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                grammar.trigger_buffer.clear();
-                llama_grammar_accept_str(grammar, constrained_str);
+                grammar.trigger_buffer_positions.clear();
                LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
                return;
            }
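The replay loop added above only re-feeds the part of each buffered token piece that falls inside [start, end) of the trigger buffer, so a token that straddles the trigger boundary is replayed with a truncated piece. A small standalone illustration of that overlap arithmetic (the token ids, pieces, and trigger position below are invented for the example):

    #include <cstdio>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
        // invented buffer: "noise" + "<tool_call>" + "{\"name\":"
        std::string trigger_buffer = "noise<tool_call>{\"name\":";
        std::vector<std::pair<int, std::pair<size_t, size_t>>> positions = {
            { 11, { 0,  5} },   // token 11 produced "noise"
            { 22, { 5, 16} },   // token 22 produced "<tool_call>"
            { 33, {16, 24} },   // token 33 produced "{\"name\":"
        };

        size_t start = trigger_buffer.find("<tool_call>"); // where the trigger pattern matched

        for (const auto & [tok, pos] : positions) {
            auto [tok_start, tok_end] = pos;
            if (tok_end <= start) {
                continue; // entirely before the trigger: skipped, not replayed
            }
            size_t piece_start = tok_start < start ? start : tok_start; // clamp partial overlap
            std::string piece  = trigger_buffer.substr(piece_start, tok_end - piece_start);
            std::printf("replay token %d with piece '%s'\n", tok, piece.c_str());
        }
    }

Only the pieces at or after the match position reach the grammar, which mirrors what llama_grammar_accept_token receives during the real replay.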
@@ -1220,7 +1372,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
            GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
        }

-        llama_grammar_accept_str(grammar, piece);
+        llama_grammar_accept_token(grammar, token, piece);
    }
}

void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
@@ -1238,6 +1390,61 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
    }
}

+void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
+    // Note terminating 0 in decoded string
+    const auto decoded = decode_utf8(piece, grammar.partial_utf8);
+    const auto & code_points = decoded.first;
+
+    llama_grammar_stacks stacks_new;
+    stacks_new.reserve(grammar.stacks.size());
+
+    for (const auto & stack : grammar.stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        const llama_grammar_element * pos = stack.back();
+
+        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+            if (llama_grammar_match_token(pos, token)) {
+                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    new_stack.push_back(pos + 1);
+                }
+                llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
+            }
+        } else {
+            llama_grammar_stacks current_stacks = {stack};
+
+            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+                llama_grammar_stacks next_stacks;
+
+                for (const auto & cur_stack : current_stacks) {
+                    llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
+                }
+
+                current_stacks = std::move(next_stacks);
+                if (current_stacks.empty()) {
+                    break;
+                }
+            }
+
+            for (auto & surviving_stack : current_stacks) {
+                if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
+                    stacks_new.emplace_back(surviving_stack);
+                }
+            }
+        }
+    }
+
+    grammar.stacks = std::move(stacks_new);
+    grammar.partial_utf8 = decoded.second;
+
+    if (grammar.stacks.empty()) {
+        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
+    }
+}
+

const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
    try {
21  llama/llama.cpp/src/llama-grammar.h  vendored
@@ -47,11 +47,17 @@ enum llama_gretype {

    // any character (.)
    LLAMA_GRETYPE_CHAR_ANY = 7,
+
+    // terminal element: token (<[token-id]>)
+    LLAMA_GRETYPE_TOKEN = 8,
+
+    // inverse token (!<[token-id]>)
+    LLAMA_GRETYPE_TOKEN_NOT = 9,
};

typedef struct llama_grammar_element {
    enum llama_gretype type;
-    uint32_t value; // Unicode code point or rule ID
+    uint32_t value; // Unicode code point, rule ID, or token ID
} llama_grammar_element;

struct llama_partial_utf8 {
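Per the new comments, a token terminal is written <[token-id]> and its inverse !<[token-id]>, and the element's value field now stores the referenced token id. A minimal sketch of how such terminals could be encoded with the extended element type (the ids are placeholders, not values taken from this patch):

    // Hypothetical encodings; 128000 and 128009 are made-up token ids.
    llama_grammar_element match_tok = { LLAMA_GRETYPE_TOKEN,     128000u };
    llama_grammar_element skip_tok  = { LLAMA_GRETYPE_TOKEN_NOT, 128009u };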
@@ -63,6 +69,7 @@ struct llama_grammar_candidate {
    size_t index;
    const uint32_t * code_points;
    llama_partial_utf8 partial_utf8;
+    llama_token id;
};

using llama_grammar_rule = std::vector< llama_grammar_element>;
@@ -88,10 +95,13 @@ std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
    const llama_grammar_candidates & candidates);

struct llama_grammar_parser {
+    const llama_vocab * vocab;
    std::map<std::string, uint32_t> symbol_ids;

    llama_grammar_rules rules;

+    llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}
+
    llama_grammar_stack c_rules() const;

    uint32_t get_symbol_id(const char * src, size_t len);
@@ -123,6 +133,9 @@ struct llama_grammar_trigger_pattern {
};

struct llama_grammar {
+    // maintain a list of llama_tokens and their positions in the trigger_buffer
+    using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;
+
    // note: allow null vocab for testing (not great)
    const llama_vocab * vocab;
    const ollama_vocab * o_vocab;
@@ -139,6 +152,7 @@ struct llama_grammar {
    bool lazy = false;
    bool awaiting_trigger = false; // Initialized to true for lazy grammars only
    std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
+    std::vector<token_pos> trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found.
    std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
    std::vector<llama_grammar_trigger_pattern>
        trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
@@ -185,3 +199,8 @@ void llama_grammar_accept_impl(
void llama_grammar_accept_str(
        struct llama_grammar & grammar,
        const std::string & piece);
+
+void llama_grammar_accept_token(
+        struct llama_grammar & grammar,
+        llama_token token,
+        const std::string & piece);
127  llama/llama.cpp/src/llama-graph.cpp  vendored
@@ -71,11 +71,14 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
    if (ubatch->pos && attn_scale) {
        const int64_t n_tokens = ubatch->n_tokens;

+        GGML_ASSERT(f_attn_temp_scale != 0.0f);
+        GGML_ASSERT(n_attn_temp_floor_scale != 0);
+
        std::vector<float> attn_scale_data(n_tokens, 0.0f);
        for (int i = 0; i < n_tokens; ++i) {
            const float pos = ubatch->pos[i];
            attn_scale_data[i] = std::log(
-                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+                std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
            ) * f_attn_temp_scale + 1.0;
        }

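The scale computed above is log(floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1) * f_attn_temp_scale + 1, i.e. it grows in steps of n_attn_temp_floor_scale positions. A quick standalone evaluation with illustrative values (the constants below are examples, not the hyper-parameters of any particular model):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float f_attn_temp_scale       = 0.1f;   // example value
        const float f_attn_temp_offset      = 1.0f;   // example value
        const int   n_attn_temp_floor_scale = 8192;   // example value

        for (int pos : {0, 8191, 8192, 65536}) {
            const float scale = std::log(std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0) * f_attn_temp_scale + 1.0;
            std::printf("pos=%6d -> attn scale %.4f\n", pos, scale);
        }
    }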
@@ -251,6 +254,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
    }
}

+bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= s_copy->ne[0] == mctx->get_n_rs();
+
+    res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
+    res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
+
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
+
+    return res;
+}
+
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

@@ -382,7 +403,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
    //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
-    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;

    return res;
}
@@ -413,10 +434,10 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
    //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there

    res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
-    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;

    res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
-    res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+    res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;

    return res;
}
@@ -449,7 +470,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
            }
        }

-        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+        for (int i = n_tokens; i < n_tokens; ++i) {
            for (int j = 0; j < n_enc; ++j) {
                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
            }
@@ -458,8 +479,46 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}

void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    inp_attn->set_input(ubatch);
-    inp_rs->set_input(ubatch);
+    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+    res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
+    return res;
}

//
@@ -958,25 +1017,25 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
            // organize experts into n_expert_groups
            ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]

-            ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
+            ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
            group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]

            // get top n_group_used expert groups
            group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
            group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]

-            ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
+            ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
            cb(expert_groups, "ffn_moe_group_topk", il);

            // mask out the other groups
            selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
-            selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
+            selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
            selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
            cb(selection_probs, "ffn_moe_probs_masked", il);
        }

        // select experts
-        ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+        ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
        cb(selected_experts->src[0], "ffn_moe_argsort", il);
        cb(selected_experts, "ffn_moe_topk", il);

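For readers who do not think in ggml tensor ops, the grouped routing above can be paraphrased in plain C++ for a single token: score each expert group by the sum of its two best selection probabilities, keep the n_group_used best groups, mask every expert outside them, then take the final top-k. The sketch below uses made-up sizes and scores and standard-library sorting in place of ggml_argsort_top_k:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <functional>
    #include <utility>
    #include <vector>

    int main() {
        const int n_expert        = 8;   // example sizes, not taken from this diff
        const int n_expert_groups = 2;
        const int n_group_used    = 1;
        const int n_expert_used   = 2;
        const int n_exp_per_group = n_expert / n_expert_groups;

        std::vector<float> probs = {0.10f, 0.90f, 0.20f, 0.30f, 0.80f, 0.70f, 0.05f, 0.15f};

        // score each group by the sum of its top-2 experts (argsort top-k of 2, then sum)
        std::vector<std::pair<float, int>> group_scores;
        for (int g = 0; g < n_expert_groups; ++g) {
            std::vector<float> grp(probs.begin() + g * n_exp_per_group, probs.begin() + (g + 1) * n_exp_per_group);
            std::partial_sort(grp.begin(), grp.begin() + 2, grp.end(), std::greater<float>());
            group_scores.push_back({grp[0] + grp[1], g});
        }

        // keep the n_group_used best groups
        std::sort(group_scores.begin(), group_scores.end(), std::greater<>());
        std::vector<bool> keep_group(n_expert_groups, false);
        for (int i = 0; i < n_group_used; ++i) {
            keep_group[group_scores[i].second] = true;
        }

        // mask experts of dropped groups to -inf (the ggml_fill + ggml_set_rows step)
        std::vector<std::pair<float, int>> masked;
        for (int e = 0; e < n_expert; ++e) {
            masked.push_back({keep_group[e / n_exp_per_group] ? probs[e] : -INFINITY, e});
        }

        // final expert selection over the masked probabilities
        std::partial_sort(masked.begin(), masked.begin() + n_expert_used, masked.end(), std::greater<>());
        for (int i = 0; i < n_expert_used; ++i) {
            std::printf("selected expert %d (prob %.2f)\n", masked[i].second, masked[i].first);
        }
    }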
@@ -1006,10 +1065,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
        cb(weights_sum, "ffn_moe_weights_sum", il);

-        if (arch == LLM_ARCH_BAILINGMOE2) {
-            weights_sum = ggml_scale_bias(ctx0, weights_sum, 1.0, 1e-20);
-            cb(weights_sum, "ffn_moe_weights_sum_biased", il);
-        }
+        // Avoid division by zero, clamp to smallest number representable by F16
+        weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
+        cb(weights_sum, "ffn_moe_weights_sum_clamped", il);

        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
        cb(weights, "ffn_moe_weights_norm", il);
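The clamp bound 6.103515625e-5 is exactly 2^-14, the smallest positive normal value in IEEE half precision, so the weight normalisation stays representable even when every selected weight underflows. A one-liner to confirm the constant:

    #include <cstdio>

    int main() {
        const double fp16_min_normal = 1.0 / (1 << 14); // 2^-14
        std::printf("%.10g\n", fp16_min_normal);        // prints 6.103515625e-05
    }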
@@ -1087,6 +1145,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                    cur = ggml_relu(ctx0, cur);
                    cb(cur, "ffn_moe_relu", il);
                } break;
+            case LLM_FFN_RELU_SQR:
+                if (gate_exps) {
+                    // TODO: add support for gated squared relu
+                    GGML_ABORT("fatal error: gated squared relu not implemented");
+                } else {
+                    cur = ggml_relu(ctx0, cur);
+                    cur = ggml_sqr(ctx0, cur);
+                    cb(cur, "ffn_moe_relu_sqr", il);
+                } break;
            default:
                GGML_ABORT("fatal error");
        }
@@ -1137,7 +1204,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(

// input embeddings with optional lora
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd = hparams.n_embd_inp();

    auto inp = std::make_unique<llm_graph_input_embd>();

@@ -1201,7 +1268,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
}

ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);

    auto & cur = inp->attn_scale;

@@ -1274,7 +1341,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
    // return cur;
    //}

-    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
+    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
    const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
@@ -1468,13 +1535,13 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
    ggml_set_input(inp->self_kq_mask);

    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
        ggml_set_input(inp->self_kq_mask_swa);

        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
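Several of the hunks above stop rounding the mask's token dimension up with GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) and allocate exactly n_tokens rows instead. GGML_PAD rounds its first argument up to the next multiple of the second, so the old and new sizes only differ when n_tokens is not already a multiple of the pad; a small illustration (the pad value 64 is assumed for the example, not read from this diff):

    #include <cstdio>

    // Re-statement of GGML_PAD's round-up behaviour, for illustration only.
    static int pad_up(int x, int n) { return ((x + n - 1) / n) * n; }

    int main() {
        const int kq_mask_pad = 64; // assumed stand-in for GGML_KQ_MASK_PAD
        for (int n_tokens : {1, 17, 64, 65}) {
            std::printf("n_tokens=%3d  old rows=%3d  new rows=%3d\n",
                        n_tokens, pad_up(n_tokens, kq_mask_pad), n_tokens);
        }
    }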
@@ -1556,7 +1623,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
    ggml_set_input(inp->self_kq_mask);

    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1587,9 +1654,10 @@ ggml_tensor * llm_graph_context::build_attn(
        int il) const {
    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
+    // expand k later to enable rope fusion which directly writes into k-v cache
    ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
    ggml_build_forward_expand(gf, v_cur);
+    ggml_build_forward_expand(gf, k_cur);

    const auto * mctx_cur = inp->mctx;

@@ -1698,7 +1766,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {

    const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

-    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
    ggml_set_input(inp->cross_kq_mask);

    inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1764,7 +1832,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
    inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
    inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
    ggml_set_input(inp->self_kq_mask);

    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1778,7 +1846,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
    inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
    inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
+    inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
    ggml_set_input(inp->self_kq_mask_swa);

    inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1838,6 +1906,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
    inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);

+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
    return inp;
}

@@ -1906,10 +1977,10 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);

-    auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
+    auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());

-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);

    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
@@ -2030,7 +2101,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck

    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
+        relative_position = std::abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }
21  llama/llama.cpp/src/llama-graph.h  vendored
@@ -132,8 +132,8 @@ public:
// temperature tuning, used by llama4
class llm_graph_input_attn_temp : public llm_graph_input_i {
public:
-    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
-        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
+        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
    virtual ~llm_graph_input_attn_temp() = default;

    void set_input(const llama_ubatch * ubatch) override;
@@ -142,6 +142,7 @@ public:

    const uint32_t n_attn_temp_floor_scale;
    const float f_attn_temp_scale;
+    const float f_attn_temp_offset;
};

class llm_graph_input_pos_bucket : public llm_graph_input_i {
@@ -224,6 +225,8 @@ public:

    void set_input(const llama_ubatch * ubatch) override;

+    bool can_reuse(const llm_graph_params & params) override;
+
    ggml_tensor * s_copy; // I32 [n_rs]

    // views of s_copy, computed once per graph
@@ -232,6 +235,10 @@ public:
    ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]

    const llama_memory_recurrent_context * mctx;
+
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t rs_z;
};

class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -364,22 +371,28 @@ public:
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
    llm_graph_input_mem_hybrid(
+            const llama_cparams & cparams,
            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
            std::unique_ptr<llm_graph_input_rs> inp_rs,
            const llama_memory_hybrid_context * mctx) :
        inp_attn(std::move(inp_attn)),
        inp_rs(std::move(inp_rs)),
+        cparams(cparams),
        mctx(mctx) { }
    virtual ~llm_graph_input_mem_hybrid() = default;

    void set_input(const llama_ubatch * ubatch) override;

+    bool can_reuse(const llm_graph_params & params) override;
+
    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
    std::unique_ptr<llm_graph_input_rs> inp_rs;

    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }

+    const llama_cparams cparams;
+
    const llama_memory_hybrid_context * mctx;
};

Some files were not shown because too many files have changed in this diff.