mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-22 06:43:57 +00:00
Compare commits
214 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
501cb38b8c | ||
|
|
5994e8e8fd | ||
|
|
59e3a35203 | ||
|
|
b3e6120736 | ||
|
|
fb92b61754 | ||
|
|
8149a3c86e | ||
|
|
0cc90a8186 | ||
|
|
e42300f25b | ||
|
|
66e73809a1 | ||
|
|
c632fdbad8 | ||
|
|
517807cdf2 | ||
|
|
ead4a9a1d0 | ||
|
|
4383a3ab7a | ||
|
|
9d97e6a9f1 | ||
|
|
1081532430 | ||
|
|
59412fbb43 | ||
|
|
86834a2797 | ||
|
|
85ccf7354d | ||
|
|
30fb7e19f8 | ||
|
|
d3450dd52e | ||
|
|
4bcb04ad88 | ||
|
|
e3d5708754 | ||
|
|
4be4dc8717 | ||
|
|
109d4fc3b4 | ||
|
|
2cb0a580f3 | ||
|
|
7cce5aac76 | ||
|
|
131c496340 | ||
|
|
4ae4f47b16 | ||
|
|
073fa31df5 | ||
|
|
91fc3c48e3 | ||
|
|
6de62664d9 | ||
|
|
463a6caad8 | ||
|
|
fc5fb09f51 | ||
|
|
05ccb17c6e | ||
|
|
f804e8a460 | ||
|
|
9cfbffafc5 | ||
|
|
470d580205 | ||
|
|
b517bb1c19 | ||
|
|
e3ade453a8 | ||
|
|
048bd4472a | ||
|
|
ec8bf5e6c5 | ||
|
|
709bbb0b6d | ||
|
|
abeec240f9 | ||
|
|
df335aac09 | ||
|
|
026bc29237 | ||
|
|
883d031268 | ||
|
|
5271ff8559 | ||
|
|
d6f7233a1c | ||
|
|
8de1da4767 | ||
|
|
d925b5350c | ||
|
|
6eaf194b85 | ||
|
|
d5a0d8d904 | ||
|
|
ef7d26ba2c | ||
|
|
1a19df1f3a | ||
|
|
7ccfd97a93 | ||
|
|
c385ca8672 | ||
|
|
837379a94c | ||
|
|
a24f90604f | ||
|
|
dc5a645434 | ||
|
|
bb71654ebe | ||
|
|
d4af9f04f9 | ||
|
|
a343ae53a4 | ||
|
|
d0cf6c8281 | ||
|
|
8f4ec9ab28 | ||
|
|
dbfd7bd027 | ||
|
|
ee04dbba51 | ||
|
|
ea7657b54a | ||
|
|
2c776f0780 | ||
|
|
79f6376f5b | ||
|
|
756c78cfc7 | ||
|
|
d7f4f788d1 | ||
|
|
114c3f2265 | ||
|
|
f2e9c9aff5 | ||
|
|
aa9d889522 | ||
|
|
735c41f9ca | ||
|
|
223a619468 | ||
|
|
759dd78dd6 | ||
|
|
44bc36d063 | ||
|
|
8f14e1f5f6 | ||
|
|
203c137810 | ||
|
|
fa8be9e35c | ||
|
|
8a75e9ee15 | ||
|
|
9231379bce | ||
|
|
c7ba6128b4 | ||
|
|
8970233a2b | ||
|
|
cde948f976 | ||
|
|
7c8aba0d83 | ||
|
|
4742e12c23 | ||
|
|
2d06977ade | ||
|
|
30f8a68c4c | ||
|
|
e378e33421 | ||
|
|
fcec04bf42 | ||
|
|
ee92ca3e1d | ||
|
|
8253ad4d2b | ||
|
|
fa7776fd24 | ||
|
|
0d38b66502 | ||
|
|
e5e077b4b7 | ||
|
|
4183bb0574 | ||
|
|
ff89ba90bc | ||
|
|
6dcc5dfb9c | ||
|
|
25911a6e6b | ||
|
|
8afa6e83f2 | ||
|
|
ea85e27bbd | ||
|
|
c116a7523d | ||
|
|
3515cc377c | ||
|
|
bbf66c0b96 | ||
|
|
764be7480f | ||
|
|
b72e5adb14 | ||
|
|
80b538e312 | ||
|
|
4f8a0166cc | ||
|
|
1e6eab5c33 | ||
|
|
6c733bf0a6 | ||
|
|
3bac5cba60 | ||
|
|
4151ef8cf7 | ||
|
|
e4ff6e6c0f | ||
|
|
82da19c634 | ||
|
|
bdd9d22dfd | ||
|
|
5fc38d042f | ||
|
|
475a11d08e | ||
|
|
191d94289d | ||
|
|
802ad16ce4 | ||
|
|
5e67f4f90e | ||
|
|
e840ccb523 | ||
|
|
b4fe3adc0a | ||
|
|
d73f8aa8c3 | ||
|
|
92c2e8a56c | ||
|
|
2e3fd86d48 | ||
|
|
4261a3b0b2 | ||
|
|
acef9b4c1b | ||
|
|
9a43994c45 | ||
|
|
f8a6e88819 | ||
|
|
35fda7b4af | ||
|
|
66fb8575ce | ||
|
|
20c3266e94 | ||
|
|
34088dbcfb | ||
|
|
e41dd73705 | ||
|
|
43107b15b9 | ||
|
|
1f91cb0c8c | ||
|
|
12d8ad0d38 | ||
|
|
592d21e7db | ||
|
|
5a08b01f5b | ||
|
|
4f473e224c | ||
|
|
9d60bb44cf | ||
|
|
f371260e75 | ||
|
|
c9e6d7719e | ||
|
|
2c4ce40334 | ||
|
|
5d8c173529 | ||
|
|
44b17d2bfa | ||
|
|
4ad87b58bb | ||
|
|
3b8b692218 | ||
|
|
4129af9205 | ||
|
|
45f216a9c7 | ||
|
|
d0b32def60 | ||
|
|
11ffc36157 | ||
|
|
ba04902670 | ||
|
|
3944602f51 | ||
|
|
73b642e6f3 | ||
|
|
ad118d8b13 | ||
|
|
f08534137b | ||
|
|
4b4a90f233 | ||
|
|
03274a6b2f | ||
|
|
cc6463ebca | ||
|
|
405d2f628f | ||
|
|
a3f7dd3e98 | ||
|
|
c85c0ebf89 | ||
|
|
10a8e04a8d | ||
|
|
1c6669e64c | ||
|
|
b2b270ad5d | ||
|
|
2bb69b40c7 | ||
|
|
65bff664cb | ||
|
|
c088ac0e79 | ||
|
|
0a066cfd91 | ||
|
|
87b7af6cee | ||
|
|
f2527b08fb | ||
|
|
71a4057fcf | ||
|
|
5ab7422508 | ||
|
|
8bcb3125c1 | ||
|
|
6baf1e31e2 | ||
|
|
ed567ef43b | ||
|
|
a6e64fbdf2 | ||
|
|
60cfa2a203 | ||
|
|
55bbf3b4a1 | ||
|
|
6bda1d2479 | ||
|
|
50f2219dd6 | ||
|
|
9e125d884c | ||
|
|
a6fbfc880c | ||
|
|
502028968d | ||
|
|
5a8eb0e151 | ||
|
|
9f8a18ec05 | ||
|
|
6b04cad7e8 | ||
|
|
45f56355d5 | ||
|
|
0dabb4ef6a | ||
|
|
2e77aa1ae7 | ||
|
|
deaabe292d | ||
|
|
af21a5ac39 | ||
|
|
f63d7f68eb | ||
|
|
82ad1dbc07 | ||
|
|
feeabdadd2 | ||
|
|
fc0309615e | ||
|
|
09d308d6b6 | ||
|
|
a8ed68bd93 | ||
|
|
2ae65ae471 | ||
|
|
a3b6886b7d | ||
|
|
c6a6d7294d | ||
|
|
2cf007c9d1 | ||
|
|
0683efa637 | ||
|
|
0943001193 | ||
|
|
5c42800fca | ||
|
|
65f10c2823 | ||
|
|
aaa7818000 | ||
|
|
f15ffc4320 | ||
|
|
20c5fd39c8 | ||
|
|
d2ee599dcf | ||
|
|
6ed8898590 |
188
.github/workflows/release.yaml
vendored
188
.github/workflows/release.yaml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
|
||||
|
||||
darwin-build:
|
||||
runs-on: macos-13
|
||||
runs-on: macos-13-xlarge
|
||||
environment: release
|
||||
needs: setup-environment
|
||||
strategy:
|
||||
@@ -54,48 +54,6 @@ jobs:
|
||||
name: build-${{ matrix.os }}-${{ matrix.arch }}
|
||||
path: dist/*
|
||||
|
||||
darwin-sign:
|
||||
runs-on: macos-13
|
||||
environment: release
|
||||
needs: darwin-build
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: |
|
||||
echo $MACOS_SIGNING_KEY | base64 --decode > certificate.p12
|
||||
security create-keychain -p password build.keychain
|
||||
security default-keychain -s build.keychain
|
||||
security unlock-keychain -p password build.keychain
|
||||
security import certificate.p12 -k build.keychain -P $MACOS_SIGNING_KEY_PASSWORD -T /usr/bin/codesign
|
||||
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
|
||||
security set-keychain-settings -lut 3600 build.keychain
|
||||
env:
|
||||
MACOS_SIGNING_KEY: ${{ secrets.MACOS_SIGNING_KEY }}
|
||||
MACOS_SIGNING_KEY_PASSWORD: ${{ secrets.MACOS_SIGNING_KEY_PASSWORD }}
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-darwin-amd64
|
||||
path: dist/darwin-amd64
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-darwin-arm64
|
||||
path: dist/darwin-arm64
|
||||
- run: |
|
||||
export VERSION=${GITHUB_REF_NAME#v}
|
||||
./scripts/build_darwin.sh sign macapp
|
||||
env:
|
||||
APPLE_IDENTITY: ${{ secrets.APPLE_IDENTITY }}
|
||||
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
|
||||
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
|
||||
APPLE_ID: ${{ vars.APPLE_ID }}
|
||||
SDKROOT: /Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
|
||||
DEVELOPER_DIR: /Applications/Xcode_14.1.0.app/Contents/Developer
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-darwin
|
||||
path: |
|
||||
dist/Ollama-darwin.zip
|
||||
dist/ollama-darwin.tgz
|
||||
|
||||
windows-depends:
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -103,21 +61,18 @@ jobs:
|
||||
arch: [amd64]
|
||||
preset: ['CPU']
|
||||
include:
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'CUDA 11'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||
cuda-version: '11.3'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'CUDA 12'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
||||
cuda-version: '12.8'
|
||||
flags: ''
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'ROCm 6'
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
rocm-version: '6.2'
|
||||
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
env:
|
||||
@@ -160,6 +115,9 @@ jobs:
|
||||
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: matrix.preset == 'CPU'
|
||||
run: |
|
||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
@@ -178,9 +136,9 @@ jobs:
|
||||
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
|
||||
- name: Build target "${{ matrix.preset }}"
|
||||
run: |
|
||||
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}"
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
|
||||
env:
|
||||
@@ -230,61 +188,11 @@ jobs:
|
||||
go-version-file: go.mod
|
||||
- run: |
|
||||
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
|
||||
- if: matrix.arch == 'arm64'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "dist\windows-arm64\vc_redist.arm64.exe"
|
||||
- run: |
|
||||
$env:VERSION='${{ github.ref_name }}' -Replace "v(.*)", '$1'
|
||||
& .\scripts\build_windows.ps1 buildApp
|
||||
env:
|
||||
VCToolsRedistDir: stub
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: build-${{ matrix.os }}-${{ matrix.arch }}
|
||||
path: |
|
||||
dist\${{ matrix.os }}-${{ matrix.arch }}\*.exe
|
||||
dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
|
||||
|
||||
windows-sign:
|
||||
runs-on: windows-2022
|
||||
environment: release
|
||||
needs: [windows-depends, windows-build]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: google-github-actions/auth@v2
|
||||
with:
|
||||
project_id: ollama
|
||||
credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
|
||||
- run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
|
||||
Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
|
||||
|
||||
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
|
||||
Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
|
||||
& "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
|
||||
|
||||
echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: build-windows-*
|
||||
path: dist\
|
||||
merge-multiple: true
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: depends-windows-amd64-*
|
||||
path: dist\windows-amd64\
|
||||
merge-multiple: true
|
||||
- run: |
|
||||
& .\scripts\build_windows.ps1 gatherDependencies sign buildInstaller distZip
|
||||
env:
|
||||
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
path: |
|
||||
dist\OllamaSetup.exe
|
||||
dist\ollama-windows-*.zip
|
||||
|
||||
linux-build:
|
||||
strategy:
|
||||
@@ -317,21 +225,26 @@ jobs:
|
||||
CGO_CFLAGS=${{ env.CGO_CFLAGS }}
|
||||
CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
|
||||
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
cache-from: type=registry,ref=ollama/ollama:latest
|
||||
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
||||
cache-to: type=inline
|
||||
- run: |
|
||||
for COMPONENT in bin/* lib/ollama/*; do
|
||||
case "$COMPONENT" in
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
esac
|
||||
done
|
||||
working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
- run: |
|
||||
echo "Manifests"
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in ; do
|
||||
echo $ARCHIVE
|
||||
cat $ARCHIVE
|
||||
done
|
||||
- run: |
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
|
||||
@@ -385,8 +298,8 @@ jobs:
|
||||
context: .
|
||||
platforms: ${{ matrix.os }}/${{ matrix.arch }}
|
||||
build-args: ${{ matrix.build-args }}
|
||||
outputs: type=image,name=ollama/ollama,push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=registry,ref=ollama/ollama:latest
|
||||
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
||||
cache-to: type=inline
|
||||
- run: |
|
||||
mkdir -p ${{ matrix.os }}-${{ matrix.arch }}
|
||||
@@ -418,7 +331,7 @@ jobs:
|
||||
latest=false
|
||||
suffix=${{ matrix.suffix }}
|
||||
images: |
|
||||
ollama/ollama
|
||||
${{ vars.DOCKER_REPO }}
|
||||
tags: |
|
||||
type=ref,enable=true,priority=600,prefix=pr-,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
@@ -428,56 +341,24 @@ jobs:
|
||||
path: ${{ runner.temp }}
|
||||
merge-multiple: true
|
||||
- run: |
|
||||
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf 'ollama/ollama@%s ')
|
||||
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
|
||||
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf '${{ vars.DOCKER_REPO }}@%s ')
|
||||
docker buildx imagetools inspect ${{ vars.DOCKER_REPO }}:${{ steps.metadata.outputs.version }}
|
||||
working-directory: ${{ runner.temp }}
|
||||
|
||||
# Trigger downstream release process
|
||||
trigger:
|
||||
runs-on: ubuntu-latest
|
||||
environment: release
|
||||
needs: [darwin-build, windows-build, windows-depends]
|
||||
steps:
|
||||
- name: Trigger downstream release process
|
||||
run: |
|
||||
curl -L \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
|
||||
-H "X-GitHub-Api-Version: 2022-11-28" \
|
||||
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
|
||||
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
|
||||
|
||||
# Aggregate all the assets and ship a release
|
||||
release:
|
||||
needs: [darwin-sign, windows-sign, linux-build]
|
||||
runs-on: linux
|
||||
environment: release
|
||||
needs: [darwin-build, windows-build, windows-depends, linux-build]
|
||||
permissions:
|
||||
contents: write
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist-darwin
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: dist-linux-*
|
||||
path: dist
|
||||
merge-multiple: true
|
||||
- run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
|
||||
working-directory: dist
|
||||
- name: Create or update Release
|
||||
- name: Create or update Release for tag
|
||||
run: |
|
||||
RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"
|
||||
|
||||
echo "Looking for existing release for ${RELEASE_VERSION}"
|
||||
OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName")
|
||||
if [ -n "$OLD_TAG" ]; then
|
||||
@@ -491,5 +372,12 @@ jobs:
|
||||
--generate-notes \
|
||||
--prerelease
|
||||
fi
|
||||
echo "Uploading artifacts for tag ${GITHUB_REF_NAME}"
|
||||
gh release upload ${GITHUB_REF_NAME} dist/* --clobber
|
||||
- name: Trigger downstream release process
|
||||
run: |
|
||||
curl -L \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
|
||||
-H "X-GitHub-Api-Version: 2022-11-28" \
|
||||
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
|
||||
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\", \"origin\": \"${GITHUB_REPOSITORY}\", \"publish\": \"1\"}}"
|
||||
|
||||
17
.github/workflows/test.yaml
vendored
17
.github/workflows/test.yaml
vendored
@@ -36,7 +36,7 @@ jobs:
|
||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||
}
|
||||
|
||||
echo changed=$(changed 'llama/llama.cpp/**' 'ml/backend/ggml/ggml/**') | tee -a $GITHUB_OUTPUT
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
||||
|
||||
linux:
|
||||
needs: [changes]
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
include:
|
||||
- preset: CPU
|
||||
- preset: CUDA
|
||||
container: nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||
- preset: ROCm
|
||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||
@@ -78,11 +78,11 @@ jobs:
|
||||
include:
|
||||
- preset: CPU
|
||||
- preset: CUDA
|
||||
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||
- preset: ROCm
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010'
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
runs-on: windows
|
||||
steps:
|
||||
- run: |
|
||||
@@ -102,7 +102,7 @@ jobs:
|
||||
$ErrorActionPreference = "Stop"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
|
||||
}
|
||||
|
||||
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
|
||||
@@ -120,6 +120,9 @@ jobs:
|
||||
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
@@ -133,8 +136,8 @@ jobs:
|
||||
path: ${{ github.workspace }}\.ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
|
||||
- run: |
|
||||
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||
env:
|
||||
|
||||
@@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.21)
|
||||
project(Ollama C CXX)
|
||||
|
||||
include(CheckLanguage)
|
||||
include(GNUInstallDirs)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
@@ -51,7 +52,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
|
||||
|
||||
add_compile_definitions(NDEBUG)
|
||||
add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
|
||||
|
||||
set(GGML_CPU ON)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
|
||||
@@ -78,14 +79,13 @@ if(CMAKE_CUDA_COMPILER)
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||
set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
|
||||
install(TARGETS ggml-cuda
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
|
||||
LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -101,7 +101,7 @@ if(CMAKE_HIP_COMPILER)
|
||||
|
||||
find_package(hip REQUIRED)
|
||||
if(NOT AMDGPU_TARGETS)
|
||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|1201)$")
|
||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|120[01])$")
|
||||
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
||||
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
|
||||
endif()
|
||||
@@ -117,7 +117,11 @@ if(CMAKE_HIP_COMPILER)
|
||||
|
||||
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
||||
install(TARGETS ggml-hip
|
||||
RUNTIME_DEPENDENCIES
|
||||
RUNTIME_DEPENDENCY_SET rocm
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||
)
|
||||
install(RUNTIME_DEPENDENCY_SET rocm
|
||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
"binaryDir": "${sourceDir}/build",
|
||||
"installDir": "${sourceDir}/dist",
|
||||
"cacheVariables": {
|
||||
"CMAKE_BUILD_TYPE": "Release"
|
||||
"CMAKE_BUILD_TYPE": "Release",
|
||||
"CMAKE_MSVC_RUNTIME_LIBRARY": "MultiThreaded"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -17,20 +18,12 @@
|
||||
"name": "CUDA",
|
||||
"inherits": [ "Default" ]
|
||||
},
|
||||
{
|
||||
"name": "CUDA 11",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "CUDA 12",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -58,7 +51,8 @@
|
||||
"name": "ROCm 6",
|
||||
"inherits": [ "ROCm" ],
|
||||
"cacheVariables": {
|
||||
"AMDGPU_TARGETS": "gfx803;gfx902;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1201;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1011:xnack-;gfx1012:xnack-;"
|
||||
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
||||
"AMDGPU_TARGETS": "gfx803;gfx902;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1011:xnack-;gfx1012:xnack-;"
|
||||
}
|
||||
}
|
||||
],
|
||||
@@ -78,11 +72,6 @@
|
||||
"configurePreset": "CUDA",
|
||||
"targets": [ "ggml-cuda" ]
|
||||
},
|
||||
{
|
||||
"name": "CUDA 11",
|
||||
"inherits": [ "CUDA" ],
|
||||
"configurePreset": "CUDA 11"
|
||||
},
|
||||
{
|
||||
"name": "CUDA 12",
|
||||
"inherits": [ "CUDA" ],
|
||||
|
||||
@@ -65,7 +65,8 @@ continuation of the sentence:
|
||||
Examples:
|
||||
|
||||
llm/backend/mlx: support the llama architecture
|
||||
CONTRIBUTING: provide clairity on good commit messages, and bad
|
||||
CONTRIBUTING: provide clarity on good commit messages, and bad
|
||||
docs: simplify manual installation with shorter curl commands
|
||||
|
||||
Bad Examples:
|
||||
|
||||
|
||||
28
Dockerfile
28
Dockerfile
@@ -7,12 +7,13 @@ ARG JETPACK5VERSION=r35.4.1
|
||||
ARG JETPACK6VERSION=r36.4.0
|
||||
ARG CMAKEVERSION=3.31.2
|
||||
|
||||
# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
||||
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||
RUN yum install -y yum-utils \
|
||||
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
||||
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
||||
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
|
||||
&& dnf install -y ccache \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||
|
||||
@@ -38,15 +39,6 @@ RUN --mount=type=cache,target=/root/.ccache \
|
||||
&& cmake --build --parallel --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel 8
|
||||
|
||||
FROM base AS cuda-11
|
||||
ARG CUDA11VERSION=11.3
|
||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 11' \
|
||||
&& cmake --build --parallel --preset 'CUDA 11' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
|
||||
FROM base AS cuda-12
|
||||
ARG CUDA12VERSION=12.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||
@@ -94,27 +86,27 @@ RUN go mod download
|
||||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
|
||||
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
|
||||
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
|
||||
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6
|
||||
|
||||
FROM scratch AS rocm
|
||||
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
|
||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
||||
|
||||
FROM ${FLAVOR} AS archive
|
||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||
COPY --from=build /bin/ollama /bin/ollama
|
||||
|
||||
FROM ubuntu:20.04
|
||||
FROM ubuntu:24.04
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ca-certificates \
|
||||
&& apt-get clean \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
|
||||
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
|
||||
WORKDIR=llama/vendor
|
||||
FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618
|
||||
FETCH_HEAD=e54d41befcc1575f4c898c5ff4ef43970cead75f
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
@@ -12,7 +12,7 @@ help:
|
||||
@echo " clean Clean local repository"
|
||||
@echo
|
||||
@echo "Example:"
|
||||
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
|
||||
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean apply-patches sync"
|
||||
|
||||
.PHONY: sync
|
||||
sync: llama/build-info.cpp ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
|
||||
@@ -24,12 +24,12 @@ ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal: ml/backend/ggml/ggml
|
||||
go generate ./$(@D)
|
||||
|
||||
.PHONY: llama/llama.cpp
|
||||
llama/llama.cpp: llama/vendor/
|
||||
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
|
||||
llama/llama.cpp: llama/vendor
|
||||
rsync -arvzc --delete -f "include LICENSE" -f "merge $@/.rsync-filter" $(addprefix $<,/LICENSE /) $@
|
||||
|
||||
.PHONY: ml/backend/ggml/ggml
|
||||
ml/backend/ggml/ggml: llama/vendor/ggml/
|
||||
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
|
||||
ml/backend/ggml/ggml: llama/vendor
|
||||
rsync -arvzc --delete -f "include LICENSE" -f "merge $@/.rsync-filter" $(addprefix $<,/LICENSE /ggml/) $@
|
||||
|
||||
PATCHES=$(wildcard llama/patches/*.patch)
|
||||
PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATCHES)))))
|
||||
@@ -39,7 +39,15 @@ PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATC
|
||||
apply-patches: $(PATCHED)
|
||||
|
||||
llama/patches/.%.patched: llama/patches/%.patch
|
||||
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
|
||||
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then \
|
||||
touch $@; \
|
||||
else \
|
||||
echo "Patch failed. Resolve any conflicts then continue."; \
|
||||
echo "1. Run 'git -C $(WORKDIR) am --continue'"; \
|
||||
echo "2. Run 'make -f $(lastword $(MAKEFILE_LIST)) format-patches'"; \
|
||||
echo "3. Run 'make -f $(lastword $(MAKEFILE_LIST)) clean apply-patches'"; \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
.PHONY: checkout
|
||||
checkout: $(WORKDIR)
|
||||
@@ -60,4 +68,5 @@ format-patches: llama/patches
|
||||
|
||||
.PHONE: clean
|
||||
clean: checkout
|
||||
@git -C $(WORKDIR) am --abort || true
|
||||
$(RM) llama/patches/.*.patched
|
||||
|
||||
27
README.md
27
README.md
@@ -1,6 +1,6 @@
|
||||
<div align="center">
|
||||
<a href="https://ollama.com">
|
||||
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||
<img alt="ollama" width="240" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||
</a>
|
||||
</div>
|
||||
|
||||
@@ -10,7 +10,7 @@ Get up and running with large language models.
|
||||
|
||||
### macOS
|
||||
|
||||
[Download](https://ollama.com/download/Ollama-darwin.zip)
|
||||
[Download](https://ollama.com/download/Ollama.dmg)
|
||||
|
||||
### Windows
|
||||
|
||||
@@ -62,10 +62,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
||||
|
||||
## Quickstart
|
||||
|
||||
To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
|
||||
To run and chat with [Gemma 3](https://ollama.com/library/gemma3):
|
||||
|
||||
```shell
|
||||
ollama run llama3.2
|
||||
ollama run gemma3
|
||||
```
|
||||
|
||||
## Model library
|
||||
@@ -382,7 +382,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
|
||||
- [LLMChat](https://github.com/trendy-design/llmchat) (Privacy focused, 100% local, intuitive all-in-one chat interface)
|
||||
- [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM Chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI.)
|
||||
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG on Mac/Windows/Linux)
|
||||
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG and deep research on Mac/Windows/Linux)
|
||||
- [OrionChat](https://github.com/EliasPereirah/OrionChat) - OrionChat is a web interface for chatting with different AI providers
|
||||
- [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains.)
|
||||
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
|
||||
@@ -429,6 +429,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
||||
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
|
||||
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
|
||||
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -473,6 +479,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
|
||||
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
|
||||
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
|
||||
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
|
||||
- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool using models. Add new tools to your shed with ease. Runs on Ollama.
|
||||
|
||||
### Apple Vision Pro
|
||||
|
||||
@@ -553,6 +561,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
|
||||
- [Ollama for D](https://github.com/kassane/ollama-d)
|
||||
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
|
||||
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
|
||||
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
|
||||
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
|
||||
|
||||
### Mobile
|
||||
|
||||
@@ -609,11 +620,15 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
|
||||
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
|
||||
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
|
||||
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
|
||||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
|
||||
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
|
||||
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
|
||||
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
|
||||
|
||||
### Supported backends
|
||||
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
### Observability
|
||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
|
||||
|
||||
@@ -222,10 +222,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
@@ -234,6 +230,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
}
|
||||
}
|
||||
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
}
|
||||
|
||||
if err := fn(bts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -89,6 +89,16 @@ func TestClientStream(t *testing.T) {
|
||||
},
|
||||
wantErr: "mid-stream error",
|
||||
},
|
||||
{
|
||||
name: "http status error takes precedence over general error",
|
||||
responses: []any{
|
||||
testError{
|
||||
message: "custom error message",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
},
|
||||
},
|
||||
wantErr: "500",
|
||||
},
|
||||
{
|
||||
name: "successful stream completion",
|
||||
responses: []any{
|
||||
|
||||
241
api/types.go
241
api/types.go
@@ -85,10 +85,15 @@ type GenerateRequest struct {
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding. Needs to be a pointer so we can distinguish between false
|
||||
// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
|
||||
// for supported models. Needs to be a pointer so we can distinguish between false
|
||||
// (request that thinking _not_ be used) and unset (use the old behavior
|
||||
// before this option was introduced)
|
||||
Think *bool `json:"think,omitempty"`
|
||||
Think *ThinkValue `json:"think,omitempty"`
|
||||
|
||||
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
|
||||
// template instead of calling the model.
|
||||
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -116,8 +121,13 @@ type ChatRequest struct {
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding
|
||||
Think *bool `json:"think,omitempty"`
|
||||
// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
|
||||
// for supported models.
|
||||
Think *ThinkValue `json:"think,omitempty"`
|
||||
|
||||
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
|
||||
// template instead of calling the model.
|
||||
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
@@ -143,6 +153,7 @@ type Message struct {
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||
@@ -222,21 +233,76 @@ func (pt PropertyType) String() string {
|
||||
return fmt.Sprintf("%v", []string(pt))
|
||||
}
|
||||
|
||||
type ToolProperty struct {
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||
func (tp ToolProperty) ToTypeScriptType() string {
|
||||
if len(tp.AnyOf) > 0 {
|
||||
var types []string
|
||||
for _, anyOf := range tp.AnyOf {
|
||||
types = append(types, anyOf.ToTypeScriptType())
|
||||
}
|
||||
return strings.Join(types, " | ")
|
||||
}
|
||||
|
||||
if len(tp.Type) == 0 {
|
||||
return "any"
|
||||
}
|
||||
|
||||
if len(tp.Type) == 1 {
|
||||
return mapToTypeScriptType(tp.Type[0])
|
||||
}
|
||||
|
||||
var types []string
|
||||
for _, t := range tp.Type {
|
||||
types = append(types, mapToTypeScriptType(t))
|
||||
}
|
||||
return strings.Join(types, " | ")
|
||||
}
|
||||
|
||||
// mapToTypeScriptType maps JSON Schema types to TypeScript types
|
||||
func mapToTypeScriptType(jsonType string) string {
|
||||
switch jsonType {
|
||||
case "string":
|
||||
return "string"
|
||||
case "number", "integer":
|
||||
return "number"
|
||||
case "boolean":
|
||||
return "boolean"
|
||||
case "array":
|
||||
return "any[]"
|
||||
case "object":
|
||||
return "Record<string, any>"
|
||||
case "null":
|
||||
return "null"
|
||||
default:
|
||||
return "any"
|
||||
}
|
||||
}
|
||||
|
||||
type ToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
func (t *ToolFunctionParameters) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
type ToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
} `json:"parameters"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolFunctionParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
func (t *ToolFunction) String() string {
|
||||
@@ -257,6 +323,19 @@ type ChatResponse struct {
|
||||
Metrics
|
||||
}
|
||||
|
||||
// DebugInfo contains debug information for template rendering
|
||||
type DebugInfo struct {
|
||||
RenderedTemplate string `json:"rendered_template"`
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
}
|
||||
|
||||
// DebugTemplateResponse is returned when _debug_render_only is set to true
|
||||
type DebugTemplateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DebugInfo DebugInfo `json:"_debug_info"`
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||
@@ -467,13 +546,14 @@ type ListModelResponse struct {
|
||||
|
||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||
type ProcessModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
SizeVRAM int64 `json:"size_vram"`
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
SizeVRAM int64 `json:"size_vram"`
|
||||
ContextLength int `json:"context_length"`
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
@@ -506,6 +586,8 @@ type GenerateResponse struct {
|
||||
Context []int `json:"context,omitempty"`
|
||||
|
||||
Metrics
|
||||
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ModelDetails provides details about a model.
|
||||
@@ -675,6 +757,113 @@ func DefaultOptions() Options {
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkValue represents a value that can be a boolean or a string ("high", "medium", "low")
|
||||
type ThinkValue struct {
|
||||
// Value can be a bool or string
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// IsValid checks if the ThinkValue is valid
|
||||
func (t *ThinkValue) IsValid() bool {
|
||||
if t == nil || t.Value == nil {
|
||||
return true // nil is valid (means not set)
|
||||
}
|
||||
|
||||
switch v := t.Value.(type) {
|
||||
case bool:
|
||||
return true
|
||||
case string:
|
||||
return v == "high" || v == "medium" || v == "low"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsBool returns true if the value is a boolean
|
||||
func (t *ThinkValue) IsBool() bool {
|
||||
if t == nil || t.Value == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := t.Value.(bool)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsString returns true if the value is a string
|
||||
func (t *ThinkValue) IsString() bool {
|
||||
if t == nil || t.Value == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := t.Value.(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Bool returns the value as a bool (true if enabled in any way)
|
||||
func (t *ThinkValue) Bool() bool {
|
||||
if t == nil || t.Value == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch v := t.Value.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
// Any string value ("high", "medium", "low") means thinking is enabled
|
||||
return v == "high" || v == "medium" || v == "low"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the value as a string
|
||||
func (t *ThinkValue) String() string {
|
||||
if t == nil || t.Value == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch v := t.Value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case bool:
|
||||
if v {
|
||||
return "medium" // Default level when just true
|
||||
}
|
||||
return ""
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (t *ThinkValue) UnmarshalJSON(data []byte) error {
|
||||
// Try to unmarshal as bool first
|
||||
var b bool
|
||||
if err := json.Unmarshal(data, &b); err == nil {
|
||||
t.Value = b
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as string
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err == nil {
|
||||
// Validate string values
|
||||
if s != "high" && s != "medium" && s != "low" {
|
||||
return fmt.Errorf("invalid think value: %q (must be \"high\", \"medium\", \"low\", true, or false)", s)
|
||||
}
|
||||
t.Value = s
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\")")
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (t *ThinkValue) MarshalJSON() ([]byte, error) {
|
||||
if t == nil || t.Value == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(t.Value)
|
||||
}
|
||||
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
@@ -699,7 +888,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
if t < 0 {
|
||||
d.Duration = time.Duration(math.MaxInt64)
|
||||
} else {
|
||||
d.Duration = time.Duration(int(t) * int(time.Second))
|
||||
d.Duration = time.Duration(t * float64(time.Second))
|
||||
}
|
||||
case string:
|
||||
d.Duration, err = time.ParseDuration(t)
|
||||
|
||||
@@ -17,6 +17,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
req string
|
||||
exp *Duration
|
||||
}{
|
||||
{
|
||||
name: "Unset",
|
||||
req: `{ }`,
|
||||
exp: nil,
|
||||
},
|
||||
{
|
||||
name: "Positive Integer",
|
||||
req: `{ "keep_alive": 42 }`,
|
||||
@@ -25,7 +30,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
{
|
||||
name: "Positive Float",
|
||||
req: `{ "keep_alive": 42.5 }`,
|
||||
exp: &Duration{42 * time.Second},
|
||||
exp: &Duration{42500 * time.Millisecond},
|
||||
},
|
||||
{
|
||||
name: "Positive Integer String",
|
||||
@@ -374,24 +379,21 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
trueVal := true
|
||||
falseVal := false
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedThinking *bool
|
||||
expectedThinking *ThinkValue
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
input: `{ "think": true }`,
|
||||
expectedThinking: &trueVal,
|
||||
expectedThinking: &ThinkValue{Value: true},
|
||||
},
|
||||
{
|
||||
name: "false",
|
||||
input: `{ "think": false }`,
|
||||
expectedThinking: &falseVal,
|
||||
expectedThinking: &ThinkValue{Value: false},
|
||||
},
|
||||
{
|
||||
name: "unset",
|
||||
@@ -399,8 +401,23 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
expectedThinking: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
input: `{ "think": "true" }`,
|
||||
name: "string_high",
|
||||
input: `{ "think": "high" }`,
|
||||
expectedThinking: &ThinkValue{Value: "high"},
|
||||
},
|
||||
{
|
||||
name: "string_medium",
|
||||
input: `{ "think": "medium" }`,
|
||||
expectedThinking: &ThinkValue{Value: "medium"},
|
||||
},
|
||||
{
|
||||
name: "string_low",
|
||||
input: `{ "think": "low" }`,
|
||||
expectedThinking: &ThinkValue{Value: "low"},
|
||||
},
|
||||
{
|
||||
name: "invalid_string",
|
||||
input: `{ "think": "invalid" }`,
|
||||
expectedThinking: nil,
|
||||
expectedError: true,
|
||||
},
|
||||
@@ -414,8 +431,60 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedThinking, req.Think)
|
||||
if test.expectedThinking == nil {
|
||||
assert.Nil(t, req.Think)
|
||||
} else {
|
||||
require.NotNil(t, req.Think)
|
||||
assert.Equal(t, test.expectedThinking.Value, req.Think.Value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFunctionParameters_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
params ToolFunctionParameters
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple object with string property",
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: map[string]ToolProperty{
|
||||
"name": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "The name of the person",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||
},
|
||||
{
|
||||
name: "marshal failure returns empty string",
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Defs: func() any {
|
||||
// Create a cycle that will cause json.Marshal to fail
|
||||
type selfRef struct {
|
||||
Self *selfRef
|
||||
}
|
||||
s := &selfRef{}
|
||||
s.Self = s
|
||||
return s
|
||||
}(),
|
||||
Properties: map[string]ToolProperty{},
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.params.String()
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
142
api/types_typescript_test.go
Normal file
142
api/types_typescript_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestToolParameterToTypeScriptType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
param ToolProperty
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single string type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"string"},
|
||||
},
|
||||
expected: "string",
|
||||
},
|
||||
{
|
||||
name: "single number type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"number"},
|
||||
},
|
||||
expected: "number",
|
||||
},
|
||||
{
|
||||
name: "integer maps to number",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"integer"},
|
||||
},
|
||||
expected: "number",
|
||||
},
|
||||
{
|
||||
name: "boolean type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"boolean"},
|
||||
},
|
||||
expected: "boolean",
|
||||
},
|
||||
{
|
||||
name: "array type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"array"},
|
||||
},
|
||||
expected: "any[]",
|
||||
},
|
||||
{
|
||||
name: "object type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
},
|
||||
expected: "Record<string, any>",
|
||||
},
|
||||
{
|
||||
name: "null type",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"null"},
|
||||
},
|
||||
expected: "null",
|
||||
},
|
||||
{
|
||||
name: "multiple types as union",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"string", "number"},
|
||||
},
|
||||
expected: "string | number",
|
||||
},
|
||||
{
|
||||
name: "string or null union",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"string", "null"},
|
||||
},
|
||||
expected: "string | null",
|
||||
},
|
||||
{
|
||||
name: "anyOf with single types",
|
||||
param: ToolProperty{
|
||||
AnyOf: []ToolProperty{
|
||||
{Type: PropertyType{"string"}},
|
||||
{Type: PropertyType{"number"}},
|
||||
},
|
||||
},
|
||||
expected: "string | number",
|
||||
},
|
||||
{
|
||||
name: "anyOf with multiple types in each branch",
|
||||
param: ToolProperty{
|
||||
AnyOf: []ToolProperty{
|
||||
{Type: PropertyType{"string", "null"}},
|
||||
{Type: PropertyType{"number"}},
|
||||
},
|
||||
},
|
||||
expected: "string | null | number",
|
||||
},
|
||||
{
|
||||
name: "nested anyOf",
|
||||
param: ToolProperty{
|
||||
AnyOf: []ToolProperty{
|
||||
{Type: PropertyType{"boolean"}},
|
||||
{
|
||||
AnyOf: []ToolProperty{
|
||||
{Type: PropertyType{"string"}},
|
||||
{Type: PropertyType{"number"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "boolean | string | number",
|
||||
},
|
||||
{
|
||||
name: "empty type returns any",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{},
|
||||
},
|
||||
expected: "any",
|
||||
},
|
||||
{
|
||||
name: "unknown type maps to any",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"unknown_type"},
|
||||
},
|
||||
expected: "any",
|
||||
},
|
||||
{
|
||||
name: "multiple types including array",
|
||||
param: ToolProperty{
|
||||
Type: PropertyType{"string", "array", "null"},
|
||||
},
|
||||
expected: "string | any[] | null",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.param.ToTypeScriptType()
|
||||
if result != tt.expected {
|
||||
t.Errorf("ToTypeScriptType() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,178 +0,0 @@
|
||||
package benchmark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Command line flags
|
||||
var modelFlag string
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
|
||||
flag.Lookup("m").DefValue = "model"
|
||||
}
|
||||
|
||||
// modelName returns the model name from flags, failing the test if not set
|
||||
func modelName(b *testing.B) string {
|
||||
if modelFlag == "" {
|
||||
b.Fatal("Error: -m flag is required for benchmark tests")
|
||||
}
|
||||
return modelFlag
|
||||
}
|
||||
|
||||
type TestCase struct {
|
||||
name string
|
||||
prompt string
|
||||
maxTokens int
|
||||
}
|
||||
|
||||
// runGenerateBenchmark contains the common generate and metrics logic
|
||||
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
|
||||
start := time.Now()
|
||||
var ttft time.Duration
|
||||
var metrics api.Metrics
|
||||
|
||||
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||
if ttft == 0 && resp.Response != "" {
|
||||
ttft = time.Since(start)
|
||||
}
|
||||
if resp.Done {
|
||||
metrics = resp.Metrics
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Report custom metrics as part of the benchmark results
|
||||
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
|
||||
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
|
||||
|
||||
// Token throughput metrics
|
||||
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
|
||||
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
|
||||
b.ReportMetric(promptThroughput, "prompt_tok/s")
|
||||
b.ReportMetric(genThroughput, "gen_tok/s")
|
||||
|
||||
// Token counts
|
||||
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
|
||||
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkColdStart runs benchmarks with model loading from cold state
|
||||
func BenchmarkColdStart(b *testing.B) {
|
||||
client := setup(b)
|
||||
tests := []TestCase{
|
||||
{"short_prompt", "Write a long story", 100},
|
||||
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||
}
|
||||
m := modelName(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
|
||||
for b.Loop() {
|
||||
b.StopTimer()
|
||||
// Ensure model is unloaded before each iteration
|
||||
unload(client, m, b)
|
||||
b.StartTimer()
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkWarmStart runs benchmarks with pre-loaded model
|
||||
func BenchmarkWarmStart(b *testing.B) {
|
||||
client := setup(b)
|
||||
tests := []TestCase{
|
||||
{"short_prompt", "Write a long story", 100},
|
||||
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||
}
|
||||
m := modelName(b)
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
|
||||
ctx := b.Context()
|
||||
|
||||
// Pre-warm the model
|
||||
warmup(client, m, tt.prompt, b)
|
||||
|
||||
// Set number of tokens as our throughput metric
|
||||
b.SetBytes(int64(tt.maxTokens))
|
||||
|
||||
for b.Loop() {
|
||||
req := &api.GenerateRequest{
|
||||
Model: m,
|
||||
Prompt: tt.prompt,
|
||||
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||
}
|
||||
|
||||
runGenerateBenchmark(b, ctx, client, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setup verifies server and model availability
|
||||
func setup(b *testing.B) *api.Client {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
|
||||
b.Fatalf("Model unavailable: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// warmup ensures the model is loaded and warmed up
|
||||
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
|
||||
for range 3 {
|
||||
err := client.Generate(
|
||||
context.Background(),
|
||||
&api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
|
||||
},
|
||||
func(api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
b.Logf("Error during model warm-up: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// unload forces model unloading using KeepAlive: 0 parameter
|
||||
func unload(client *api.Client, model string, b *testing.B) {
|
||||
req := &api.GenerateRequest{
|
||||
Model: model,
|
||||
KeepAlive: &api.Duration{Duration: 0},
|
||||
}
|
||||
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
|
||||
b.Logf("Unload error: %v", err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
112
cmd/cmd.go
112
cmd/cmd.go
@@ -322,11 +322,23 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
thinkFlag := cmd.Flags().Lookup("think")
|
||||
if thinkFlag.Changed {
|
||||
think, err := cmd.Flags().GetBool("think")
|
||||
thinkStr, err := cmd.Flags().GetString("think")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Think = &think
|
||||
|
||||
// Handle different values for --think
|
||||
switch thinkStr {
|
||||
case "", "true":
|
||||
// --think or --think=true
|
||||
opts.Think = &api.ThinkValue{Value: true}
|
||||
case "false":
|
||||
opts.Think = &api.ThinkValue{Value: false}
|
||||
case "high", "medium", "low":
|
||||
opts.Think = &api.ThinkValue{Value: thinkStr}
|
||||
default:
|
||||
return fmt.Errorf("invalid value for --think: %q (must be true, false, high, medium, or low)", thinkStr)
|
||||
}
|
||||
} else {
|
||||
opts.Think = nil
|
||||
}
|
||||
@@ -583,12 +595,13 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
|
||||
} else {
|
||||
until = format.HumanTime(m.ExpiresAt, "Never")
|
||||
}
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
|
||||
ctxStr := strconv.Itoa(m.ContextLength)
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, ctxStr, until})
|
||||
}
|
||||
}
|
||||
|
||||
table := tablewriter.NewWriter(os.Stdout)
|
||||
table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
|
||||
table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "CONTEXT", "UNTIL"})
|
||||
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
|
||||
table.SetAlignment(tablewriter.ALIGN_LEFT)
|
||||
table.SetHeaderLine(false)
|
||||
@@ -976,7 +989,7 @@ type runOptions struct {
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Think *bool
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
}
|
||||
|
||||
@@ -1016,10 +1029,11 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case ' ':
|
||||
case ' ', '\t':
|
||||
state.wordBuffer = ""
|
||||
case '\n':
|
||||
case '\n', '\r':
|
||||
state.lineLength = 0
|
||||
state.wordBuffer = ""
|
||||
default:
|
||||
state.wordBuffer += string(ch)
|
||||
}
|
||||
@@ -1077,12 +1091,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var thinkingContent strings.Builder
|
||||
var latest api.ChatResponse
|
||||
var fullResponse strings.Builder
|
||||
var role string
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
role := "assistant"
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if response.Message.Content != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
@@ -1095,14 +1111,21 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(false))
|
||||
thinkTagOpened = true
|
||||
thinkTagClosed = false
|
||||
}
|
||||
thinkingContent.WriteString(response.Message.Thinking)
|
||||
displayResponse(response.Message.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
if thinkTagOpened && !thinkTagClosed && content != "" {
|
||||
if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.Message.ToolCalls) > 0) {
|
||||
if !strings.HasSuffix(thinkingContent.String(), "\n") {
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Print(thinkingOutputClosingText(false))
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = true
|
||||
state = &displayResponseState{}
|
||||
}
|
||||
// purposefully not putting thinking blocks in the response, which would
|
||||
// only be needed if we later added tool calling to the cli (they get
|
||||
@@ -1110,6 +1133,13 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
// about to finish some tool calls)
|
||||
fullResponse.WriteString(content)
|
||||
|
||||
if response.Message.ToolCalls != nil {
|
||||
toolCalls := response.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
fmt.Print(renderToolCalls(toolCalls, false))
|
||||
}
|
||||
}
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
|
||||
return nil
|
||||
@@ -1135,6 +1165,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// this error should ideally be wrapped properly by the client
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1186,6 +1224,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var thinkingContent strings.Builder
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
@@ -1203,17 +1242,31 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(plainText))
|
||||
thinkTagOpened = true
|
||||
thinkTagClosed = false
|
||||
}
|
||||
thinkingContent.WriteString(response.Thinking)
|
||||
displayResponse(response.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
if thinkTagOpened && !thinkTagClosed && content != "" {
|
||||
if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.ToolCalls) > 0) {
|
||||
if !strings.HasSuffix(thinkingContent.String(), "\n") {
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Print(thinkingOutputClosingText(plainText))
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = true
|
||||
state = &displayResponseState{}
|
||||
}
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
|
||||
if response.ToolCalls != nil {
|
||||
toolCalls := response.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
fmt.Print(renderToolCalls(toolCalls, plainText))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1416,13 +1469,13 @@ func NewCLI() *cobra.Command {
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model from a Modelfile",
|
||||
Short: "Create a model",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: CreateHandler,
|
||||
}
|
||||
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
@@ -1453,7 +1506,8 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||
runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
|
||||
runCmd.Flags().String("think", "", "Enable thinking mode: true/false or high/medium/low for supported models")
|
||||
runCmd.Flags().Lookup("think").NoOptDefVal = "true"
|
||||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
@@ -1558,6 +1612,7 @@ func NewCLI() *cobra.Command {
|
||||
appendEnvDocs(cmd, []envconfig.EnvVar{
|
||||
envVars["OLLAMA_DEBUG"],
|
||||
envVars["OLLAMA_HOST"],
|
||||
envVars["OLLAMA_CONTEXT_LENGTH"],
|
||||
envVars["OLLAMA_KEEP_ALIVE"],
|
||||
envVars["OLLAMA_MAX_LOADED_MODELS"],
|
||||
envVars["OLLAMA_MAX_QUEUE"],
|
||||
@@ -1603,7 +1658,7 @@ func NewCLI() *cobra.Command {
|
||||
// to false).
|
||||
//
|
||||
// If capabilities are not provided, we fetch them from the server.
|
||||
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
|
||||
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*api.ThinkValue, error) {
|
||||
if explicitlySetByUser {
|
||||
return runOpts.Think, nil
|
||||
}
|
||||
@@ -1630,9 +1685,34 @@ func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicit
|
||||
}
|
||||
|
||||
if thinkingSupported {
|
||||
thinking := true
|
||||
return &thinking, nil
|
||||
return &api.ThinkValue{Value: true}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
|
||||
out := ""
|
||||
formatExplanation := ""
|
||||
formatValues := ""
|
||||
if !plainText {
|
||||
formatExplanation = readline.ColorGrey + readline.ColorBold
|
||||
formatValues = readline.ColorDefault
|
||||
out += formatExplanation
|
||||
}
|
||||
for i, toolCall := range toolCalls {
|
||||
argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if i > 0 {
|
||||
out += "\n"
|
||||
}
|
||||
// all tool calls are unexpected since we don't currently support registering any in the CLI
|
||||
out += fmt.Sprintf(" Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
|
||||
}
|
||||
if !plainText {
|
||||
out += readline.ColorDefault
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -272,16 +272,29 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "think":
|
||||
think := true
|
||||
opts.Think = &think
|
||||
thinkValue := api.ThinkValue{Value: true}
|
||||
var maybeLevel string
|
||||
if len(args) > 2 {
|
||||
maybeLevel = args[2]
|
||||
}
|
||||
if maybeLevel != "" {
|
||||
// TODO(drifkin): validate the level, could be model dependent
|
||||
// though... It will also be validated on the server once a call is
|
||||
// made.
|
||||
thinkValue.Value = maybeLevel
|
||||
}
|
||||
opts.Think = &thinkValue
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
fmt.Println("Set 'think' mode.")
|
||||
if maybeLevel != "" {
|
||||
fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
|
||||
} else {
|
||||
fmt.Println("Set 'think' mode.")
|
||||
}
|
||||
case "nothink":
|
||||
think := false
|
||||
opts.Think = &think
|
||||
opts.Think = &api.ThinkValue{Value: false}
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
@@ -385,18 +398,21 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
case "modelfile":
|
||||
fmt.Println(resp.Modelfile)
|
||||
case "parameters":
|
||||
fmt.Println("Model defined parameters:")
|
||||
if resp.Parameters == "" {
|
||||
fmt.Println("No parameters were specified for this model.")
|
||||
fmt.Println(" No additional parameters were specified for this model.")
|
||||
} else {
|
||||
if len(opts.Options) > 0 {
|
||||
fmt.Println("User defined parameters:")
|
||||
for k, v := range opts.Options {
|
||||
fmt.Printf("%-*s %v\n", 30, k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
for _, l := range strings.Split(resp.Parameters, "\n") {
|
||||
fmt.Printf(" %s\n", l)
|
||||
}
|
||||
fmt.Println("Model defined parameters:")
|
||||
fmt.Println(resp.Parameters)
|
||||
}
|
||||
fmt.Println()
|
||||
if len(opts.Options) > 0 {
|
||||
fmt.Println("User defined parameters:")
|
||||
for k, v := range opts.Options {
|
||||
fmt.Printf(" %-*s %v\n", 30, k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
case "system":
|
||||
switch {
|
||||
@@ -475,7 +491,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
if strings.Contains(err.Error(), "does not support thinking") ||
|
||||
strings.Contains(err.Error(), "invalid think value") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
sb.Reset()
|
||||
continue
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"regexp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -19,11 +19,12 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !strings.Contains(link, "Ollama.app") {
|
||||
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
|
||||
m := r.FindStringSubmatch(link)
|
||||
if len(m) != 1 {
|
||||
return errors.New("could not find ollama app")
|
||||
}
|
||||
path := strings.Split(link, "Ollama.app")
|
||||
if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
|
||||
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
return waitForServer(ctx, client)
|
||||
|
||||
@@ -45,14 +45,11 @@ func startApp(ctx context.Context, client *api.Client) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
// log.Printf("XXX attempting to start app %s", appExe)
|
||||
|
||||
cmd_path := "c:\\Windows\\system32\\cmd.exe"
|
||||
cmd := exec.Command(cmd_path, "/c", appExe)
|
||||
// TODO - these hide flags aren't working - still pops up a command window for some reason
|
||||
cmd := exec.Command(cmd_path, "/c", appExe, "--hide", "--fast-startup")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
|
||||
|
||||
// TODO this didn't help either...
|
||||
cmd.Stdin = strings.NewReader("")
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
@@ -74,7 +71,16 @@ func isProcRunning(procName string) []uint32 {
|
||||
slog.Debug("failed to check for running installers", "error", err)
|
||||
return nil
|
||||
}
|
||||
pids = pids[:ret]
|
||||
if ret > uint32(len(pids)) {
|
||||
pids = make([]uint32, ret+10)
|
||||
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
|
||||
slog.Debug("failed to check for running installers", "error", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if ret < uint32(len(pids)) {
|
||||
pids = pids[:ret]
|
||||
}
|
||||
var matches []uint32
|
||||
for _, pid := range pids {
|
||||
if pid == 0 {
|
||||
|
||||
@@ -190,6 +190,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &gemma2Model{}
|
||||
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
|
||||
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||
case "Gemma3nForConditionalGeneration":
|
||||
conv = &gemma3nModel{}
|
||||
case "Phi3ForCausalLM":
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
@@ -200,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &bertModel{}
|
||||
case "CohereForCausalLM":
|
||||
conv = &commandrModel{}
|
||||
case "GptOssForCausalLM":
|
||||
conv = &gptossModel{}
|
||||
default:
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
165
convert/convert_gemma3n.go
Normal file
165
convert/convert_gemma3n.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"gonum.org/v1/gonum/stat/distuv"
|
||||
)
|
||||
|
||||
type gemma3nModel struct {
|
||||
ModelParameters
|
||||
|
||||
TextModel struct {
|
||||
ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
|
||||
AltupActiveIdx uint32 `json:"altup_active_idx"`
|
||||
AltupCoefClip float32 `json:"altup_coef_clip"`
|
||||
AltupCorrectScale bool `json:"altup_correct_scale"`
|
||||
AltupLRMultiplier float32 `json:"altup_lr_multiplier"`
|
||||
AltupNumInputs uint32 `json:"altup_num_inputs"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
HiddenSizePerLayerInput uint32 `json:"hidden_size_per_layer_input"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
norm := distuv.Normal{Mu: 0, Sigma: 1}
|
||||
for _, v := range m.TextModel.ActivationSparsityPattern {
|
||||
if !yield(float32(norm.Quantile(float64(v)))) {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
|
||||
kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
|
||||
kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
|
||||
kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
|
||||
kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
|
||||
kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
|
||||
kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
|
||||
kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
|
||||
kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
|
||||
for _, t := range m.TextModel.LayerTypes {
|
||||
if !yield(t == "sliding_attention") {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
|
||||
kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
|
||||
kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
|
||||
kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
|
||||
kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
|
||||
kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
|
||||
kv["gemma3n.head_dim"] = m.TextModel.HeadDim
|
||||
kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
|
||||
kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out, ts := mergeTensors(ts,
|
||||
merge{"altup_proj.*.weight", "altup_proj.weight"},
|
||||
merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
|
||||
)
|
||||
|
||||
for _, t := range ts {
|
||||
switch {
|
||||
case strings.Contains(t.Name(), "audio_tower"),
|
||||
strings.Contains(t.Name(), "embed_audio"),
|
||||
strings.Contains(t.Name(), "vision_tower"),
|
||||
strings.Contains(t.Name(), "embed_vision"):
|
||||
// TODO: handle audio and vision towers
|
||||
continue
|
||||
case strings.Contains(t.Name(), "altup_predict_coef"),
|
||||
strings.Contains(t.Name(), "altup_correct_coef"):
|
||||
if m.TextModel.AltupCoefClip > 0 {
|
||||
t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
|
||||
t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) Replacements() []string {
|
||||
return []string{
|
||||
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
|
||||
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm", "model.language_model.altup_projections", "altup_proj",
|
||||
"model.language_model.altup_unembed_projections", "altup_unembd_proj",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"model.language_model.layers", "blk",
|
||||
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"pre_feedforward_layernorm", "ffn_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"per_layer_input_gate", "inp_gate",
|
||||
"per_layer_projection", "proj",
|
||||
"post_per_layer_input_norm", "post_norm",
|
||||
"altup.", "altup_",
|
||||
"modality_router", "router",
|
||||
"prediction_coefs", "predict_coef",
|
||||
"correction_coefs", "correct_coef",
|
||||
"correct_output_scale", "correct_scale.weight",
|
||||
"laurel.", "laurel_",
|
||||
"linear_left", "l",
|
||||
"linear_right", "r",
|
||||
"post_laurel_norm", "post_norm",
|
||||
}
|
||||
}
|
||||
223
convert/convert_gptoss.go
Normal file
223
convert/convert_gptoss.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
type gptossModel struct {
|
||||
ModelParameters
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
AttentionHeads uint32 `json:"num_attention_heads"`
|
||||
KeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
Experts uint32 `json:"num_experts"`
|
||||
LocalExperts uint32 `json:"num_local_experts"`
|
||||
ExpertsPerToken uint32 `json:"experts_per_token"`
|
||||
RMSNormEpsilon float32 `json:"rms_norm_eps"`
|
||||
InitialContextLength uint32 `json:"initial_context_length"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScalingFactor float32 `json:"rope_scaling_factor"`
|
||||
RopeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
} `json:"rope_scaling"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*gptossModel)(nil)
|
||||
|
||||
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gptoss"
|
||||
kv["general.file_type"] = uint32(4)
|
||||
kv["gptoss.context_length"] = cmp.Or(m.MaxPositionEmbeddings, uint32(m.RopeScalingFactor*float32(m.InitialContextLength)))
|
||||
kv["gptoss.block_count"] = m.HiddenLayers
|
||||
kv["gptoss.embedding_length"] = m.HiddenSize
|
||||
kv["gptoss.feed_forward_length"] = m.IntermediateSize
|
||||
kv["gptoss.expert_count"] = cmp.Or(m.Experts, m.LocalExperts)
|
||||
kv["gptoss.expert_used_count"] = m.ExpertsPerToken
|
||||
kv["gptoss.attention.head_count"] = m.AttentionHeads
|
||||
kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
|
||||
kv["gptoss.attention.key_length"] = m.HeadDim
|
||||
kv["gptoss.attention.value_length"] = m.HeadDim
|
||||
kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
|
||||
kv["gptoss.attention.sliding_window"] = m.SlidingWindow
|
||||
kv["gptoss.rope.freq_base"] = m.RopeTheta
|
||||
kv["gptoss.rope.scaling.factor"] = cmp.Or(m.RopeScalingFactor, m.RopeScaling.Factor)
|
||||
kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
|
||||
kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
|
||||
kv["tokenizer.ggml.add_bos_token"] = false
|
||||
kv["tokenizer.ggml.eos_token_id"] = uint32(199999) // <|endoftext|>
|
||||
kv["tokenizer.ggml.eos_token_ids"] = []int32{
|
||||
199999, /* <|endoftext|> */
|
||||
200002, /* <|return|> */
|
||||
200012, /* <|call|> */
|
||||
}
|
||||
kv["tokenizer.ggml.add_eos_token"] = false
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
mxfp4s := make(map[string]*mxfp4)
|
||||
for _, t := range ts {
|
||||
if strings.HasSuffix(t.Name(), ".blocks") || strings.HasSuffix(t.Name(), ".scales") {
|
||||
dot := strings.LastIndex(t.Name(), ".")
|
||||
name, suffix := t.Name()[:dot], t.Name()[dot+1:]
|
||||
if _, ok := mxfp4s[name]; !ok {
|
||||
mxfp4s[name] = &mxfp4{}
|
||||
}
|
||||
|
||||
switch suffix {
|
||||
case "blocks":
|
||||
mxfp4s[name].blocks = t
|
||||
case "scales":
|
||||
mxfp4s[name].scales = t
|
||||
}
|
||||
} else {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for name, mxfp4 := range mxfp4s {
|
||||
dims := mxfp4.blocks.Shape()
|
||||
|
||||
if !strings.HasSuffix(name, ".weight") {
|
||||
name += ".weight"
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *gptossModel) Replacements() []string {
|
||||
var replacements []string
|
||||
if m.MaxPositionEmbeddings > 0 {
|
||||
// hf flavored model
|
||||
replacements = []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_out",
|
||||
"self_attn.sinks", "attn_sinks",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"mlp.router", "ffn_gate_inp",
|
||||
"mlp.experts.gate_up_proj_", "ffn_gate_up_exps.",
|
||||
"mlp.experts.down_proj_", "ffn_down_exps.",
|
||||
"model.norm", "output_norm",
|
||||
}
|
||||
} else {
|
||||
replacements = []string{
|
||||
// noop replacements so other replacements will not be applied
|
||||
".blocks", ".blocks",
|
||||
".scales", ".scales",
|
||||
// real replacements
|
||||
"block", "blk",
|
||||
"attn.norm", "attn_norm",
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.sinks", "attn_sinks",
|
||||
"attn.out", "attn_out",
|
||||
"mlp.norm", "ffn_norm",
|
||||
"mlp.gate", "ffn_gate_inp",
|
||||
"mlp.mlp1_", "ffn_gate_up_exps.",
|
||||
"mlp.mlp2_", "ffn_down_exps.",
|
||||
"embedding", "token_embd",
|
||||
"norm", "output_norm",
|
||||
"unembedding", "output",
|
||||
"scale", "weight",
|
||||
}
|
||||
}
|
||||
return replacements
|
||||
}
|
||||
|
||||
type mxfp4 struct {
|
||||
blocks, scales Tensor
|
||||
}
|
||||
|
||||
func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
var b bytes.Buffer
|
||||
if _, err := m.blocks.WriteTo(&b); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
blocksDims := make([]int, len(m.blocks.Shape()))
|
||||
for i, d := range m.blocks.Shape() {
|
||||
blocksDims[i] = int(d)
|
||||
}
|
||||
|
||||
bts := b.Bytes()
|
||||
var tmp [16]byte
|
||||
for i := 0; i < b.Len(); i += 16 {
|
||||
for j := range 8 {
|
||||
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||
a, b := bts[i+j], bts[i+j+8]
|
||||
tmp[2*j+0] = (a & 0x0F) | (b << 4)
|
||||
tmp[2*j+1] = (a >> 4) | (b & 0xF0)
|
||||
}
|
||||
|
||||
copy(bts[i:i+16], tmp[:])
|
||||
}
|
||||
|
||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
|
||||
|
||||
var s bytes.Buffer
|
||||
if _, err := m.scales.WriteTo(&s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
scalesDims := slices.Repeat([]int{1}, len(m.blocks.Shape()))
|
||||
for i, d := range m.scales.Shape() {
|
||||
scalesDims[i] = int(d)
|
||||
}
|
||||
|
||||
var scales tensor.Tensor = tensor.New(tensor.WithShape(scalesDims...), tensor.WithBacking(s.Bytes()))
|
||||
|
||||
out, err := tensor.Concat(3, scales, blocks)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
out = tensor.Materialize(out)
|
||||
|
||||
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
u8s, err := native.VectorU8(out.(*tensor.Dense))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, binary.LittleEndian, u8s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int64(len(u8s)), nil
|
||||
}
|
||||
@@ -2,9 +2,6 @@ package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -30,65 +27,38 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
oldnew := []string{
|
||||
"model.layers", "blk",
|
||||
"w1", "ffn_gate_exps",
|
||||
"w2", "ffn_down_exps",
|
||||
"w3", "ffn_up_exps",
|
||||
}
|
||||
|
||||
for i := range p.NumLocalExperts {
|
||||
oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
|
||||
}
|
||||
|
||||
// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
|
||||
namer := strings.NewReplacer(oldnew...)
|
||||
experts := make(map[string]experts)
|
||||
|
||||
// merge experts into a single tensor while removing them from ts
|
||||
ts = slices.DeleteFunc(ts, func(t Tensor) bool {
|
||||
if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
|
||||
return false
|
||||
}
|
||||
|
||||
name := namer.Replace(t.Name())
|
||||
experts[name] = append(experts[name], t)
|
||||
return true
|
||||
})
|
||||
|
||||
var out []*ggml.Tensor
|
||||
for n, e := range experts {
|
||||
// TODO(mxyng): sanity check experts
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: n,
|
||||
Kind: e[0].Kind(),
|
||||
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
|
||||
WriterTo: e,
|
||||
merges := make([]merge, 0, p.NumHiddenLayers*6)
|
||||
for i := range p.NumHiddenLayers {
|
||||
merges = append(merges, merge{
|
||||
fmt.Sprintf("blk.%d.*.w1.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w1.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w2.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w2.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.bias", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w3.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w3.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.bias", i),
|
||||
})
|
||||
}
|
||||
|
||||
out, ts := mergeTensors(ts, merges...)
|
||||
return append(out, p.llamaModel.Tensors(ts)...)
|
||||
}
|
||||
|
||||
func (p *mixtralModel) Replacements() []string {
|
||||
return append(
|
||||
p.llamaModel.Replacements(),
|
||||
"model.layers", "blk",
|
||||
"block_sparse_moe.gate", "ffn_gate_inp",
|
||||
"block_sparse_moe.experts.", ".",
|
||||
)
|
||||
}
|
||||
|
||||
type experts []Tensor
|
||||
|
||||
func (e experts) WriteTo(w io.Writer) (int64, error) {
|
||||
// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
|
||||
for _, t := range e {
|
||||
// the canonical merged experts tensor stacks all experts along a new, 0 axis,
|
||||
// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
|
||||
// this accomplishes the same thing by writing each expert tensor in sequence
|
||||
if _, err := t.WriteTo(w); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -65,17 +65,17 @@ func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
for _, t := range ts {
|
||||
if strings.Contains(t.Name(), "patch_embed.proj") {
|
||||
for t := range splitDim(t, 2,
|
||||
strings.NewReplacer("patch_embed.proj", "patch_embd_0"),
|
||||
strings.NewReplacer("patch_embed.proj", "patch_embd_1"),
|
||||
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_0")},
|
||||
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_1")},
|
||||
) {
|
||||
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
|
||||
out = append(out, t)
|
||||
}
|
||||
} else if strings.Contains(t.Name(), "attn.qkv") {
|
||||
out = append(out, slices.Collect(splitDim(t, 0,
|
||||
strings.NewReplacer("attn.qkv", "attn_q"),
|
||||
strings.NewReplacer("attn.qkv", "attn_k"),
|
||||
strings.NewReplacer("attn.qkv", "attn_v"),
|
||||
split{Replacer: strings.NewReplacer("attn.qkv", "attn_q")},
|
||||
split{Replacer: strings.NewReplacer("attn.qkv", "attn_k")},
|
||||
split{Replacer: strings.NewReplacer("attn.qkv", "attn_v")},
|
||||
))...)
|
||||
} else {
|
||||
out = append(out, &ggml.Tensor{
|
||||
|
||||
@@ -11,14 +11,13 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -137,9 +136,7 @@ func TestConvertModel(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
keys := maps.Keys(expect)
|
||||
slices.Sort(keys)
|
||||
for _, k := range keys {
|
||||
for _, k := range slices.Sorted(maps.Keys(expect)) {
|
||||
if v, ok := actual[k]; !ok {
|
||||
t.Errorf("missing %s", k)
|
||||
} else if v != expect[k] {
|
||||
@@ -343,9 +340,7 @@ func TestConvertAdapter(t *testing.T) {
|
||||
|
||||
actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
|
||||
|
||||
keys := maps.Keys(c.Expected)
|
||||
slices.Sort(keys)
|
||||
for _, k := range keys {
|
||||
for _, k := range slices.Sorted(maps.Keys(c.Expected)) {
|
||||
if v, ok := actual[k]; !ok {
|
||||
t.Errorf("missing %s", k)
|
||||
} else if v != c.Expected[k] {
|
||||
|
||||
@@ -31,28 +31,31 @@ func (t tensorBase) Shape() []uint64 {
|
||||
}
|
||||
|
||||
const (
|
||||
tensorKindF32 uint32 = iota
|
||||
tensorKindF16
|
||||
tensorKindFP32 uint32 = iota
|
||||
tensorKindFP16
|
||||
tensorKindBF16 = 30
|
||||
tensorKindMXFP4 = 39
|
||||
)
|
||||
|
||||
func (t tensorBase) Kind() uint32 {
|
||||
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
|
||||
strings.HasSuffix(t.name, ".bias") ||
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
t.name == "v.pre_tile_position_embd.weight" ||
|
||||
t.name == "v.post_tile_position_embd.weight" {
|
||||
// these tensors are always F32
|
||||
return 0
|
||||
return tensorKindFP32
|
||||
}
|
||||
|
||||
switch len(t.shape) {
|
||||
case 0:
|
||||
panic("invalid tensor shape")
|
||||
case 1:
|
||||
return tensorKindF32
|
||||
return tensorKindFP32
|
||||
default:
|
||||
return tensorKindF16
|
||||
return tensorKindFP16
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
@@ -8,12 +9,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/x448/float16"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
type safetensorMetadata struct {
|
||||
@@ -46,8 +47,7 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := maps.Keys(headers)
|
||||
slices.Sort(keys)
|
||||
keys := slices.Sorted(maps.Keys(headers))
|
||||
|
||||
names := make(map[string]struct{}, len(keys))
|
||||
|
||||
@@ -94,6 +94,15 @@ type safetensor struct {
|
||||
*tensorBase
|
||||
}
|
||||
|
||||
func (st safetensor) Kind() uint32 {
|
||||
kind := st.tensorBase.Kind()
|
||||
if st.dtype == "BF16" && kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
return kind
|
||||
}
|
||||
|
||||
func (st safetensor) Clone() Tensor {
|
||||
return &safetensor{
|
||||
fs: st.fs,
|
||||
@@ -116,26 +125,41 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if seeker, ok := f.(io.Seeker); ok {
|
||||
if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
|
||||
return 0, err
|
||||
r, err := func() (io.Reader, error) {
|
||||
if readerAt, ok := f.(io.ReaderAt); ok {
|
||||
return io.NewSectionReader(readerAt, st.offset, st.size), nil
|
||||
} else if seeker, ok := f.(io.Seeker); ok {
|
||||
_, err := seeker.Seek(st.offset, io.SeekStart)
|
||||
return f, err
|
||||
} else {
|
||||
_, err := io.CopyN(io.Discard, f, st.offset)
|
||||
return f, err
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
br := bufio.NewReaderSize(r, min(32<<10, int(st.size)))
|
||||
// special case when input and output are same type and the
|
||||
// tensor doesn't need repacking
|
||||
if (st.repacker == nil) &&
|
||||
((st.dtype == "F32" && st.Kind() == tensorKindFP32) ||
|
||||
(st.dtype == "F16" && st.Kind() == tensorKindFP16) ||
|
||||
(st.dtype == "U8")) {
|
||||
return io.CopyN(w, br, st.size)
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
switch st.dtype {
|
||||
case "F32":
|
||||
f32s = make([]float32, st.size/4)
|
||||
if err = binary.Read(f, binary.LittleEndian, f32s); err != nil {
|
||||
if err = binary.Read(br, binary.LittleEndian, f32s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case "F16":
|
||||
u16s := make([]uint16, st.size/2)
|
||||
if err = binary.Read(f, binary.LittleEndian, u16s); err != nil {
|
||||
if err = binary.Read(br, binary.LittleEndian, u16s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -146,7 +170,7 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
|
||||
case "BF16":
|
||||
u8s := make([]uint8, st.size)
|
||||
if err = binary.Read(f, binary.LittleEndian, u8s); err != nil {
|
||||
if err = binary.Read(br, binary.LittleEndian, u8s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -163,15 +187,18 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
}
|
||||
|
||||
switch st.Kind() {
|
||||
case tensorKindF32:
|
||||
return 0, binary.Write(w, binary.LittleEndian, f32s)
|
||||
case tensorKindF16:
|
||||
case tensorKindFP32:
|
||||
return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
|
||||
case tensorKindFP16:
|
||||
f16s := make([]uint16, len(f32s))
|
||||
for i := range f32s {
|
||||
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
||||
}
|
||||
|
||||
return 0, binary.Write(w, binary.LittleEndian, f16s)
|
||||
return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
|
||||
case tensorKindBF16:
|
||||
u8s := bfloat16.EncodeFloat32(f32s)
|
||||
return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
||||
}
|
||||
|
||||
232
convert/reader_test.go
Normal file
232
convert/reader_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/x448/float16"
|
||||
)
|
||||
|
||||
func TestSafetensors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root, err := os.OpenRoot(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
cases := []struct {
|
||||
name,
|
||||
dtype string
|
||||
offset,
|
||||
size int64
|
||||
shape []uint64
|
||||
setup func(*testing.T, *os.File)
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
name: "fp32-fp32",
|
||||
dtype: "F32",
|
||||
size: 32 * 4, // 32 floats, each 4 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fp32-fp16",
|
||||
dtype: "F32",
|
||||
size: 32 * 4, // 32 floats, each 4 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
||||
0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
|
||||
0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
|
||||
0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fp16-fp16",
|
||||
dtype: "F16",
|
||||
size: 32 * 2, // 32 floats, each 2 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u16s := make([]uint16, 32)
|
||||
for i := range u16s {
|
||||
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
||||
0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
|
||||
0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
|
||||
0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fp16-fp32",
|
||||
dtype: "F16",
|
||||
size: 32 * 2, // 32 floats, each 2 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u16s := make([]uint16, 32)
|
||||
for i := range u16s {
|
||||
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bf16-bf16",
|
||||
dtype: "BF16",
|
||||
size: 32 * 2, // 32 brain floats, each 2 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x80, 0x3f, 0x00, 0x40, 0x40, 0x40, 0x80, 0x40, 0xa0, 0x40, 0xc0, 0x40, 0xe0, 0x40,
|
||||
0x00, 0x41, 0x10, 0x41, 0x20, 0x41, 0x30, 0x41, 0x40, 0x41, 0x50, 0x41, 0x60, 0x41, 0x70, 0x41,
|
||||
0x80, 0x41, 0x88, 0x41, 0x90, 0x41, 0x98, 0x41, 0xa0, 0x41, 0xa8, 0x41, 0xb0, 0x41, 0xb8, 0x41,
|
||||
0xc0, 0x41, 0xc8, 0x41, 0xd0, 0x41, 0xd8, 0x41, 0xe0, 0x41, 0xe8, 0x41, 0xf0, 0x41, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bf16-fp32",
|
||||
dtype: "BF16",
|
||||
size: 32 * 2, // 32 brain floats, each 2 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "u8-u8",
|
||||
dtype: "U8",
|
||||
size: 32, // 32 brain floats, each 1 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u8s := make([]uint8, 32)
|
||||
for i := range u8s {
|
||||
u8s[i] = uint8(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u8s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
path := filepath.Base(t.Name())
|
||||
st := safetensor{
|
||||
fs: root.FS(),
|
||||
path: path,
|
||||
dtype: tt.dtype,
|
||||
offset: tt.offset,
|
||||
size: tt.size,
|
||||
tensorBase: &tensorBase{
|
||||
name: tt.name,
|
||||
shape: tt.shape,
|
||||
},
|
||||
}
|
||||
|
||||
f, err := root.Create(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
tt.setup(t, f)
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := st.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, b.Bytes()); diff != "" {
|
||||
t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,56 +1,129 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"io"
|
||||
"iter"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type split struct {
|
||||
*strings.Replacer
|
||||
dim int
|
||||
|
||||
// fn is an optional function to apply to the tensor after slicing
|
||||
fn func(tensor.Tensor) (tensor.Tensor, error)
|
||||
}
|
||||
|
||||
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
|
||||
// is split evenly based on the number of replacers provided.
|
||||
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
|
||||
// is split evenly based on the number of replacers provided unless a specific count is given.
|
||||
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
|
||||
return func(yield func(*ggml.Tensor) bool) {
|
||||
for i, replacer := range replacers {
|
||||
var offset int
|
||||
for _, split := range splits {
|
||||
t := t.Clone()
|
||||
shape := slices.Clone(t.Shape())
|
||||
shape[dim] = shape[dim] / uint64(len(replacers))
|
||||
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
|
||||
|
||||
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
|
||||
slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim]))
|
||||
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
|
||||
offset += int(shape[dim])
|
||||
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
t, err := t.Slice(slice...)
|
||||
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
tt, err := tt.Slice(slice...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t = tensor.Materialize(t)
|
||||
tt = tensor.Materialize(tt)
|
||||
|
||||
if split.fn != nil {
|
||||
tt, err = split.fn(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// flatten tensor so it can be written as a vector
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
return native.VectorF32(tt.(*tensor.Dense))
|
||||
})
|
||||
|
||||
if !yield(&ggml.Tensor{
|
||||
Name: replacer.Replace(t.Name()),
|
||||
Name: split.Replace(t.Name()),
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
WriterTo: tt,
|
||||
WriterTo: t,
|
||||
}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
pattern, name string
|
||||
}
|
||||
|
||||
// mergeTensors merges tensors that match a given pattern into a single tensor.
|
||||
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
|
||||
var matched []Tensor
|
||||
for i := range merges {
|
||||
matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
|
||||
matched, _ := path.Match(merges[i].pattern, t.Name())
|
||||
return matched
|
||||
})
|
||||
|
||||
if len(matched) > 0 {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: merges[i].name,
|
||||
Kind: matched[0].Kind(),
|
||||
Shape: append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
|
||||
WriterTo: mergeGroup(matched),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out, unmatched
|
||||
}
|
||||
|
||||
// slicesSplitFunc splits a slice into two slices based on a predicate function.
|
||||
func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
|
||||
for _, e := range s {
|
||||
if fn(e) {
|
||||
matched = append(matched, e)
|
||||
} else {
|
||||
unmatched = append(unmatched, e)
|
||||
}
|
||||
}
|
||||
|
||||
return matched, unmatched
|
||||
}
|
||||
|
||||
type mergeGroup []Tensor
|
||||
|
||||
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
|
||||
for _, t := range g {
|
||||
if _, err := t.WriteTo(w); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
953
convert/tensor_test.go
Normal file
953
convert/tensor_test.go
Normal file
@@ -0,0 +1,953 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"iter"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
)
|
||||
|
||||
type fakeTensor struct {
|
||||
name string
|
||||
shape []uint64
|
||||
data []float32
|
||||
|
||||
repacker Repacker
|
||||
}
|
||||
|
||||
func (f fakeTensor) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
func (f fakeTensor) Shape() []uint64 {
|
||||
return f.shape
|
||||
}
|
||||
|
||||
func (f fakeTensor) Kind() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (f *fakeTensor) SetRepacker(fn Repacker) {
|
||||
f.repacker = fn
|
||||
}
|
||||
|
||||
func (f fakeTensor) Clone() Tensor {
|
||||
return &fakeTensor{
|
||||
name: f.name,
|
||||
shape: slices.Clone(f.shape),
|
||||
data: slices.Clone(f.data),
|
||||
repacker: f.repacker,
|
||||
}
|
||||
}
|
||||
|
||||
func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) {
|
||||
data := f.data
|
||||
if f.repacker != nil {
|
||||
data, err = f.repacker(f.name, data, f.shape)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int64(len(data) * 4), nil
|
||||
}
|
||||
|
||||
func mul(shape []uint64) int {
|
||||
n := 1
|
||||
for _, dim := range shape {
|
||||
n *= int(dim)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func TestSplitDim(t *testing.T) {
|
||||
t.Run("2d", func(t *testing.T) {
|
||||
r := fakeTensor{
|
||||
name: "a.b",
|
||||
shape: []uint64{3, 4},
|
||||
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
||||
}
|
||||
|
||||
t.Run("no split", func(t *testing.T) {
|
||||
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatalf("expected name 'x', got '%s'", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 4}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("even split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 1,
|
||||
split{Replacer: strings.NewReplacer("a", "x")},
|
||||
split{Replacer: strings.NewReplacer("b", "y")},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{2, 3, 6, 7, 10, 11}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uneven split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 0,
|
||||
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{2, 4}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("three way split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 0,
|
||||
split{Replacer: strings.NewReplacer("a", "x"), dim: 1},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{4, 5, 6, 7}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.z" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uneven three way split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 1,
|
||||
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 1}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{2, 6, 10}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.z" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 1}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{3, 7, 11}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("split with transpose", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 1,
|
||||
split{Replacer: strings.NewReplacer("a", "x")},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
|
||||
return tensor.Transpose(tt, 1, 0)
|
||||
}},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{2, 6, 10, 3, 7, 11}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("3d", func(t *testing.T) {
|
||||
r := fakeTensor{
|
||||
name: "a.b",
|
||||
shape: []uint64{3, 4, 2},
|
||||
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||
}
|
||||
|
||||
t.Run("no split", func(t *testing.T) {
|
||||
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatalf("expected name 'x', got '%s'", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 4, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("even split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 1,
|
||||
split{Replacer: strings.NewReplacer("a", "x")},
|
||||
split{Replacer: strings.NewReplacer("b", "y")},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uneven split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 0,
|
||||
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{2, 4, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("three way split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 0,
|
||||
split{Replacer: strings.NewReplacer("a", "x"), dim: 1},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11, 12, 13, 14, 15}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.z" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uneven three way split", func(t *testing.T) {
|
||||
next, stop := iter.Pull(splitDim(&r, 1,
|
||||
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
||||
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
|
||||
))
|
||||
defer stop()
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "x.b" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.y" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 1, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{4, 5, 12, 13, 20, 21}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
tt, ok := next()
|
||||
if !ok {
|
||||
t.Fatal("expected at least one split")
|
||||
}
|
||||
|
||||
if tt.Name != "a.z" {
|
||||
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.Shape, []uint64{3, 1, 2}); diff != "" {
|
||||
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := tt.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, mul(tt.Shape))
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f32s, []float32{6, 7, 14, 15, 22, 23}); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
unmatched := []Tensor{
|
||||
&fakeTensor{
|
||||
name: "a.0.b",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "a.1.b",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "c.0.d",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "c.1.d",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "e.0.f",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
|
||||
},
|
||||
}
|
||||
|
||||
checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
|
||||
for i := range n {
|
||||
got := matched[i]
|
||||
if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
|
||||
t.Errorf("unexpected (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := got.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, 20)
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
offset := 10 + (i * 20)
|
||||
want := make([]float32, 20)
|
||||
for j := range 20 {
|
||||
want[j] = float32(offset + j)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, f32s); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("single merge", func(t *testing.T) {
|
||||
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
|
||||
if len(unmatched) != 3 {
|
||||
t.Error("expected 3 remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 1 {
|
||||
t.Error("expected 1 merged tensor, got", len(matched))
|
||||
}
|
||||
|
||||
checkMatched(t, 1, matched)
|
||||
})
|
||||
|
||||
t.Run("multiple merges", func(t *testing.T) {
|
||||
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
|
||||
if len(unmatched) != 1 {
|
||||
t.Error("expected 1 remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 2 {
|
||||
t.Error("expected 2 merged tensor, got", len(matched))
|
||||
}
|
||||
|
||||
checkMatched(t, 2, matched)
|
||||
})
|
||||
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
|
||||
if len(unmatched) != 5 {
|
||||
t.Error("expected 5 remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 0 {
|
||||
t.Error("expected no merged tensors, got", len(matched))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -8,11 +8,10 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -260,11 +259,8 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
|
||||
tokens[token.ID] = token
|
||||
}
|
||||
|
||||
keys := maps.Keys(tokens)
|
||||
slices.Sort(keys)
|
||||
|
||||
v := Vocabulary{Model: "gpt2"}
|
||||
for _, k := range keys {
|
||||
for _, k := range slices.Sorted(maps.Keys(tokens)) {
|
||||
token := tokens[k]
|
||||
v.Tokens = append(v.Tokens, token.Content)
|
||||
v.Scores = append(v.Scores, float32(token.ID))
|
||||
|
||||
@@ -58,7 +58,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
driverMajor, driverMinor, err := AMDDriverVersion()
|
||||
if err != nil {
|
||||
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
|
||||
slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
|
||||
slog.Warn("ollama recommends running the https://www.amd.com/en/support/download/linux-drivers.html", "error", err)
|
||||
}
|
||||
|
||||
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
|
||||
@@ -97,6 +97,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
return a < b
|
||||
})
|
||||
gpuCount := 0
|
||||
gpuOrdinalID := 0
|
||||
for _, match := range matches {
|
||||
slog.Debug("evaluating amdgpu node " + match)
|
||||
fp, err := os.Open(match)
|
||||
@@ -187,10 +188,6 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Keep track of numeric IDs based on valid GPUs
|
||||
gpuID := gpuCount
|
||||
gpuCount += 1
|
||||
|
||||
// Look up the memory for the current node
|
||||
totalMemory := uint64(0)
|
||||
usedMemory := uint64(0)
|
||||
@@ -269,7 +266,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
if uniqueID != 0 {
|
||||
ID = fmt.Sprintf("GPU-%016x", uniqueID)
|
||||
} else {
|
||||
ID = strconv.Itoa(gpuID)
|
||||
ID = strconv.Itoa(gpuOrdinalID)
|
||||
}
|
||||
|
||||
gpuInfo := RocmGPUInfo{
|
||||
@@ -280,6 +277,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
FreeMemory: (totalMemory - usedMemory),
|
||||
},
|
||||
ID: ID,
|
||||
filterID: gpuOrdinalID,
|
||||
Name: name,
|
||||
Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
@@ -287,13 +285,40 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
DriverMinor: driverMinor,
|
||||
},
|
||||
usedFilepath: usedFile,
|
||||
index: gpuID,
|
||||
index: gpuCount,
|
||||
}
|
||||
|
||||
// Keep track of numeric IDs based on valid GPUs
|
||||
gpuCount += 1
|
||||
|
||||
// If the user wants to filter to a subset of devices, filter out if we aren't a match
|
||||
if len(visibleDevices) > 0 {
|
||||
include := false
|
||||
for _, visible := range visibleDevices {
|
||||
if (uniqueID != 0 && visible == gpuInfo.ID) || visible == strconv.Itoa(gpuInfo.index) {
|
||||
include = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !include {
|
||||
reason := "filtering out device per user request"
|
||||
slog.Info(reason, "id", gpuInfo.ID, "index", gpuInfo.index, "visible_devices", visibleDevices)
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Ordinal IDs are based on the visible GPUs
|
||||
gpuOrdinalID += 1
|
||||
|
||||
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||
if totalMemory < IGPUMemLimit {
|
||||
reason := "unsupported Radeon iGPU detected skipping"
|
||||
slog.Info(reason, "id", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory))
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
@@ -315,29 +340,8 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
// continue
|
||||
//}
|
||||
|
||||
slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
||||
|
||||
// If the user wants to filter to a subset of devices, filter out if we aren't a match
|
||||
if len(visibleDevices) > 0 {
|
||||
include := false
|
||||
for _, visible := range visibleDevices {
|
||||
if visible == gpuInfo.ID || visible == strconv.Itoa(gpuInfo.index) {
|
||||
include = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !include {
|
||||
reason := "filtering out device per user request"
|
||||
slog.Info(reason, "id", gpuInfo.ID, "visible_devices", visibleDevices)
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
||||
|
||||
// Final validation is gfx compatibility - load the library if we haven't already loaded it
|
||||
// even if the user overrides, we still need to validate the library
|
||||
@@ -391,7 +395,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
|
||||
// Check for env var workarounds
|
||||
if name == "1002:687f" { // Vega RX 56
|
||||
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
|
||||
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, "HSA_ENABLE_SDMA=0")
|
||||
}
|
||||
|
||||
// The GPU has passed all the verification steps and is supported
|
||||
@@ -520,19 +524,26 @@ func verifyKFDDriverAccess() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "rocm" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
|
||||
if _, err := strconv.Atoi(info.ID); err == nil {
|
||||
ids = append(ids, fmt.Sprintf("%d", info.filterID))
|
||||
} else {
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// There are 3 potential env vars to use to select GPUs.
|
||||
// ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux
|
||||
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
||||
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
||||
return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
return "ROCR_VISIBLE_DEVICES=" + strings.Join(ids, ",")
|
||||
}
|
||||
|
||||
@@ -111,6 +111,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
UnreliableFreeMemory: true,
|
||||
|
||||
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
||||
filterID: i,
|
||||
DependencyPath: []string{libDir},
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
Name: name,
|
||||
@@ -200,19 +201,26 @@ func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "rocm" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
|
||||
if _, err := strconv.Atoi(info.ID); err == nil {
|
||||
ids = append(ids, fmt.Sprintf("%d", info.filterID))
|
||||
} else {
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// There are 3 potential env vars to use to select GPUs.
|
||||
// ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows
|
||||
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
||||
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
||||
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
return "HIP_VISIBLE_DEVICES=" + strings.Join(ids, ",")
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -15,19 +16,6 @@ import (
|
||||
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||
|
||||
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "cuda" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
}
|
||||
|
||||
func cudaVariant(gpuInfo CudaGPUInfo) string {
|
||||
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
||||
if CudaTegra != "" {
|
||||
@@ -55,10 +43,13 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
|
||||
}
|
||||
}
|
||||
}
|
||||
return "sbsa"
|
||||
}
|
||||
|
||||
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
|
||||
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
|
||||
// The detected driver is older than Feb 2023
|
||||
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
|
||||
return "v11"
|
||||
}
|
||||
return "v12"
|
||||
|
||||
@@ -263,6 +263,8 @@ func GetGPUInfo() GpuInfoList {
|
||||
var driverMinor int
|
||||
if cHandles.cudart != nil {
|
||||
C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
|
||||
driverMajor = int(cHandles.cudart.driver_major)
|
||||
driverMinor = int(cHandles.cudart.driver_minor)
|
||||
} else {
|
||||
C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
|
||||
driverMajor = int(cHandles.nvcuda.driver_major)
|
||||
@@ -369,6 +371,15 @@ func GetGPUInfo() GpuInfoList {
|
||||
}
|
||||
|
||||
rocmGPUs, err = AMDGetGPUInfo()
|
||||
|
||||
// The ID field is used in context of the filtered set of GPUS
|
||||
// so we have to replace any of these numeric IDs with their
|
||||
// placement in this set of GPUs
|
||||
for i := range rocmGPUs {
|
||||
if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil {
|
||||
rocmGPUs[i].ID = strconv.Itoa(i)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
bootstrapErrors = append(bootstrapErrors, err)
|
||||
}
|
||||
@@ -678,23 +689,16 @@ func getVerboseState() C.uint16_t {
|
||||
|
||||
// Given the list of GPUs this instantiation is targeted for,
|
||||
// figure out the visible devices environment variable
|
||||
//
|
||||
// If different libraries are detected, the first one is what we use
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() []string {
|
||||
if len(l) == 0 {
|
||||
return "", ""
|
||||
return nil
|
||||
}
|
||||
switch l[0].Library {
|
||||
case "cuda":
|
||||
return cudaGetVisibleDevicesEnv(l)
|
||||
case "rocm":
|
||||
return rocmGetVisibleDevicesEnv(l)
|
||||
case "oneapi":
|
||||
return oneapiGetVisibleDevicesEnv(l)
|
||||
default:
|
||||
slog.Debug("no filter required for library " + l[0].Library)
|
||||
return "", ""
|
||||
vd := []string{}
|
||||
// Only filter the AMD GPUs at this level, let all NVIDIA devices through
|
||||
if tmp := rocmGetVisibleDevicesEnv(l); tmp != "" {
|
||||
vd = append(vd, tmp)
|
||||
}
|
||||
return vd
|
||||
}
|
||||
|
||||
func GetSystemInfo() SystemInfo {
|
||||
|
||||
@@ -62,9 +62,9 @@ func GetCPUMem() (memInfo, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() []string {
|
||||
// No-op on darwin
|
||||
return "", ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSystemInfo() SystemInfo {
|
||||
|
||||
@@ -69,18 +69,15 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
}
|
||||
|
||||
int version = 0;
|
||||
cudartDriverVersion_t driverVersion;
|
||||
driverVersion.major = 0;
|
||||
driverVersion.minor = 0;
|
||||
|
||||
// Report driver version if we're in verbose mode, ignore errors
|
||||
ret = (*resp->ch.cudaDriverGetVersion)(&version);
|
||||
if (ret != CUDART_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "cudaDriverGetVersion failed: %d\n", ret);
|
||||
} else {
|
||||
driverVersion.major = version / 1000;
|
||||
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
|
||||
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
|
||||
resp->ch.driver_major = version / 1000;
|
||||
resp->ch.driver_minor = (version - (resp->ch.driver_major * 1000)) / 10;
|
||||
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", resp->ch.driver_major, resp->ch.driver_minor);
|
||||
}
|
||||
|
||||
ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
|
||||
|
||||
@@ -29,11 +29,6 @@ typedef struct cudartMemory_st {
|
||||
size_t used;
|
||||
} cudartMemory_t;
|
||||
|
||||
typedef struct cudartDriverVersion {
|
||||
int major;
|
||||
int minor;
|
||||
} cudartDriverVersion_t;
|
||||
|
||||
typedef struct cudaUUID {
|
||||
unsigned char bytes[16];
|
||||
} cudaUUID_t;
|
||||
@@ -123,6 +118,8 @@ typedef struct cudaDeviceProp {
|
||||
typedef struct cudart_handle {
|
||||
void *handle;
|
||||
uint16_t verbose;
|
||||
int driver_major;
|
||||
int driver_minor;
|
||||
cudartReturn_t (*cudaSetDevice)(int device);
|
||||
cudartReturn_t (*cudaDeviceSynchronize)(void);
|
||||
cudartReturn_t (*cudaDeviceReset)(void);
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
//go:build linux || windows
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "oneapi" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
|
||||
}
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
// '../lib/ollama' on Linux and the executable's directory on macOS
|
||||
// note: distribution builds, additional GPU-specific libraries are
|
||||
// found in subdirectories of the returned path, such as
|
||||
// 'cuda_v11', 'cuda_v12', 'rocm', etc.
|
||||
// 'cuda_v12', 'rocm', etc.
|
||||
var LibOllamaPath string = func() string {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
|
||||
@@ -27,8 +27,8 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
||||
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
|
||||
DependencyPath []string `json:"lib_path,omitempty"`
|
||||
|
||||
// Extra environment variables specific to the GPU as list of [key,value]
|
||||
EnvWorkarounds [][2]string `json:"envs,omitempty"`
|
||||
// Extra environment variables specific to the GPU as list of [key=value]
|
||||
EnvWorkarounds []string `json:"envs,omitempty"`
|
||||
|
||||
// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
|
||||
// the FreeMemory is best effort, and may over or under report actual memory usage
|
||||
@@ -36,9 +36,10 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
||||
UnreliableFreeMemory bool
|
||||
|
||||
// GPU information
|
||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
Compute string `json:"compute"` // Compute Capability or gfx
|
||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||
filterID int //nolint:unused,nolintlint // AMD Workaround: The numeric ID of the device used to filter out other devices
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
Compute string `json:"compute"` // Compute Capability or gfx
|
||||
|
||||
// Driver Information - TODO no need to put this on each GPU
|
||||
DriverMajor int `json:"driver_major,omitempty"`
|
||||
@@ -171,7 +172,8 @@ func (si SystemInfo) GetOptimalThreadCount() int {
|
||||
// For each GPU, check if it does NOT support flash attention
|
||||
func (l GpuInfoList) FlashAttentionSupported() bool {
|
||||
for _, gpu := range l {
|
||||
supportsFA := gpu.Library == "metal" ||
|
||||
supportsFA := gpu.Library == "cpu" ||
|
||||
gpu.Library == "metal" ||
|
||||
(gpu.Library == "cuda" && gpu.DriverMajor >= 7) ||
|
||||
gpu.Library == "rocm"
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* [Quickstart](../README.md#quickstart)
|
||||
* [Examples](./examples.md)
|
||||
* [Importing models](./import.md)
|
||||
* [MacOS Documentation](./macos.md)
|
||||
* [Linux Documentation](./linux.md)
|
||||
* [Windows Documentation](./windows.md)
|
||||
* [Docker Documentation](./docker.md)
|
||||
|
||||
246
docs/api.md
246
docs/api.md
@@ -500,21 +500,30 @@ The `message` object has the following fields:
|
||||
- `thinking`: (for thinking models) the model's thinking process
|
||||
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
|
||||
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
|
||||
- `tool_name` (optional): add the name of the tool that was executed to inform the model of the result
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
|
||||
### Tool calling
|
||||
|
||||
Tool calling is supported by providing a list of tools in the `tools` parameter. The model will generate a response that includes a list of tool calls. See the [Chat request (Streaming with tools)](#chat-request-streaming-with-tools) example below.
|
||||
|
||||
Models can also explain the result of the tool call in the response. See the [Chat request (With history, with tools)](#chat-request-with-history-with-tools) example below.
|
||||
|
||||
[See models with tool calling capabilities](https://ollama.com/search?c=tool).
|
||||
|
||||
### Structured outputs
|
||||
|
||||
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Chat Request (Streaming)
|
||||
#### Chat request (Streaming)
|
||||
|
||||
##### Request
|
||||
|
||||
@@ -569,6 +578,88 @@ Final response:
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (Streaming with tools)
|
||||
|
||||
##### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what is the weather in tokyo?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to get the weather for"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
##### Response
|
||||
|
||||
A stream of JSON objects is returned:
|
||||
```json
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"created_at": "2025-07-07T20:22:19.184789Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"city": "Tokyo"
|
||||
}
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
"done": false
|
||||
}
|
||||
```
|
||||
|
||||
Final response:
|
||||
|
||||
```json
|
||||
{
|
||||
"model":"llama3.2",
|
||||
"created_at":"2025-07-07T20:22:19.19314Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": ""
|
||||
},
|
||||
"done_reason": "stop",
|
||||
"done": true,
|
||||
"total_duration": 182242375,
|
||||
"load_duration": 41295167,
|
||||
"prompt_eval_count": 169,
|
||||
"prompt_eval_duration": 24573166,
|
||||
"eval_count": 15,
|
||||
"eval_duration": 115959084
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (No streaming)
|
||||
|
||||
##### Request
|
||||
@@ -606,6 +697,74 @@ curl http://localhost:11434/api/chat -d '{
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (No streaming, with tools)
|
||||
|
||||
##### Request
|
||||
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what is the weather in tokyo?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to get the weather for"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
##### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"created_at": "2025-07-07T20:32:53.844124Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"city": "Tokyo"
|
||||
}
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
"done_reason": "stop",
|
||||
"done": true,
|
||||
"total_duration": 3244883583,
|
||||
"load_duration": 2969184542,
|
||||
"prompt_eval_count": 169,
|
||||
"prompt_eval_duration": 141656333,
|
||||
"eval_count": 18,
|
||||
"eval_duration": 133293625
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (Structured outputs)
|
||||
|
||||
##### Request
|
||||
@@ -712,6 +871,87 @@ Final response:
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
#### Chat request (With history, with tools)
|
||||
|
||||
##### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what is the weather in Toronto?"
|
||||
},
|
||||
// the message from the model appended to history
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_temperature",
|
||||
"arguments": {
|
||||
"city": "Toronto"
|
||||
}
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
// the tool call result appended to history
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "11 degrees celsius",
|
||||
"tool_name": "get_temperature",
|
||||
}
|
||||
],
|
||||
"stream": false,
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to get the weather for"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
##### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"created_at": "2025-07-07T20:43:37.688511Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "The current temperature in Toronto is 11°C."
|
||||
},
|
||||
"done_reason": "stop",
|
||||
"done": true,
|
||||
"total_duration": 890771750,
|
||||
"load_duration": 707634750,
|
||||
"prompt_eval_count": 94,
|
||||
"prompt_eval_duration": 91703208,
|
||||
"eval_count": 11,
|
||||
"eval_duration": 90282125
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
|
||||
#### Chat request (with images)
|
||||
|
||||
##### Request
|
||||
@@ -1353,7 +1593,7 @@ Then there is a series of downloading responses. Until any of the download is co
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "downloading digestname",
|
||||
"status": "pulling digestname",
|
||||
"digest": "digestname",
|
||||
"total": 2142590208,
|
||||
"completed": 241970
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
# Benchmark
|
||||
|
||||
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
|
||||
|
||||
## When to use
|
||||
|
||||
Run these benchmarks when:
|
||||
- Making changes to the model inference engine
|
||||
- Modifying model loading/unloading logic
|
||||
- Changing prompt processing or token generation code
|
||||
- Implementing a new model architecture
|
||||
- Testing performance across different hardware setups
|
||||
|
||||
## Prerequisites
|
||||
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
|
||||
## Usage and Examples
|
||||
|
||||
>[!NOTE]
|
||||
>All commands must be run from the root directory of the Ollama project.
|
||||
|
||||
Basic syntax:
|
||||
```bash
|
||||
go test -bench=. ./benchmark/... -m $MODEL_NAME
|
||||
```
|
||||
|
||||
Required flags:
|
||||
- `-bench=.`: Run all benchmarks
|
||||
- `-m`: Model name to benchmark
|
||||
|
||||
Optional flags:
|
||||
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
|
||||
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
|
||||
|
||||
Common usage patterns:
|
||||
|
||||
Single benchmark run with a model specified:
|
||||
```bash
|
||||
go test -bench=. ./benchmark/... -m llama3.3
|
||||
```
|
||||
|
||||
## Output metrics
|
||||
|
||||
The benchmark reports several key metrics:
|
||||
|
||||
- `gen_tok/s`: Generated tokens per second
|
||||
- `prompt_tok/s`: Prompt processing tokens per second
|
||||
- `ttft_ms`: Time to first token in milliseconds
|
||||
- `load_ms`: Model load time in milliseconds
|
||||
- `gen_tokens`: Total tokens generated
|
||||
- `prompt_tokens`: Total prompt tokens processed
|
||||
|
||||
Each benchmark runs two scenarios:
|
||||
- Cold start: Model is loaded from disk for each test
|
||||
- Warm start: Model is pre-loaded in memory
|
||||
|
||||
Three prompt lengths are tested for each scenario:
|
||||
- Short prompt (100 tokens)
|
||||
- Medium prompt (500 tokens)
|
||||
- Long prompt (1000 tokens)
|
||||
@@ -118,7 +118,7 @@ To run tests, use `go test`:
|
||||
go test ./...
|
||||
```
|
||||
|
||||
> NOTE: In rare cirumstances, you may nedd to change a package using the new
|
||||
> NOTE: In rare circumstances, you may need to change a package using the new
|
||||
> "synctest" package in go1.24.
|
||||
>
|
||||
> If you do not have the "synctest" package enabled, you will not see build or
|
||||
|
||||
33
docs/faq.md
33
docs/faq.md
@@ -20,9 +20,9 @@ Please refer to the [GPU docs](./gpu.md).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 4096 tokens.
|
||||
By default, Ollama uses a context window size of 4096 tokens for most models. The `gpt-oss` model has a default context window size of 8192 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
This can be overridden in Settings in the Windows and macOS App, or with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
```shell
|
||||
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
|
||||
@@ -46,6 +46,8 @@ curl http://localhost:11434/api/generate -d '{
|
||||
}'
|
||||
```
|
||||
|
||||
Setting the context length higher may cause the model to not be able to fit onto the GPU which make the model run more slowly.
|
||||
|
||||
## How can I tell if my model was loaded onto the GPU?
|
||||
|
||||
Use the `ollama ps` command to see what models are currently loaded into memory.
|
||||
@@ -57,8 +59,8 @@ ollama ps
|
||||
> **Output**:
|
||||
>
|
||||
> ```
|
||||
> NAME ID SIZE PROCESSOR UNTIL
|
||||
> llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
||||
> NAME ID SIZE PROCESSOR CONTEXT UNTIL
|
||||
> gpt-oss:20b 05afbac4bad6 16 GB 100% GPU 8192 4 minutes from now
|
||||
> ```
|
||||
|
||||
The `Processor` column will show which memory the model was loaded in to:
|
||||
@@ -148,9 +150,11 @@ docker build -t ollama-with-ca .
|
||||
docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-with-ca
|
||||
```
|
||||
|
||||
## Does Ollama send my prompts and answers back to ollama.com?
|
||||
## Does Ollama send my prompts and responses back to ollama.com?
|
||||
|
||||
No. Ollama runs locally, and conversation data does not leave your machine.
|
||||
If you're running a model locally, your prompts and responses will always stay on your machine. Ollama Turbo in the App allows you to run your queries on Ollama's servers if you don't have a powerful enough GPU. Web search lets a model query the web, giving you more accurate and up-to-date information. Both Turbo and web search require sending your prompts and responses to Ollama.com. This data is neither logged nor stored.
|
||||
|
||||
If you don't want to see the Turbo and web search options in the app, you can disable them in Settings by turning on Airplane mode. In Airplane mode, all models will run locally, and your prompts and responses will stay on your machine.
|
||||
|
||||
## How can I expose Ollama on my network?
|
||||
|
||||
@@ -292,7 +296,7 @@ If too many requests are sent to the server, it will respond with a 503 error in
|
||||
|
||||
## How does Ollama handle concurrent requests?
|
||||
|
||||
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing.
|
||||
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it can be configured to allow parallel request processing.
|
||||
|
||||
If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded. As prior models become idle, one or more will be unloaded to make room for the new model. Queued requests will be processed in order. When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads.
|
||||
|
||||
@@ -301,7 +305,7 @@ Parallel request processing for a given model results in increasing the context
|
||||
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
|
||||
|
||||
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default is 1, and will handle 1 request per model at a time.
|
||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
||||
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
||||
@@ -333,3 +337,16 @@ The currently available K/V cache quantization types are:
|
||||
How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.
|
||||
|
||||
You may need to experiment with different quantization types to find the best balance between memory usage and quality.
|
||||
|
||||
## How can I stop Ollama from starting when I login to my computer
|
||||
|
||||
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
|
||||
|
||||
**Windows**
|
||||
- Remove `%APPDATA%\Microsoft\Windows\Start Menu\Programs\Startup\Ollama.lnk`
|
||||
|
||||
**MacOS Monterey (v12)**
|
||||
- Open `Settings` -> `Users & Groups` -> `Login Items` and find the `Ollama` entry, then click the `-` (minus) to remove
|
||||
|
||||
**MacOS Ventura (v13) and later**
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
# GPU
|
||||
## Nvidia
|
||||
Ollama supports Nvidia GPUs with compute capability 5.0+.
|
||||
Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer.
|
||||
|
||||
Check your compute compatibility to see if your card is supported:
|
||||
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)
|
||||
|
||||
| Compute Capability | Family | Cards |
|
||||
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
|
||||
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
|
||||
| | NVIDIA Professioal | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
|
||||
| 9.0 | NVIDIA | `H200` `H100` |
|
||||
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
|
||||
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
|
||||
|
||||
@@ -53,6 +53,8 @@ FROM /path/to/safetensors/directory
|
||||
|
||||
If you create the Modelfile in the same directory as the weights, you can use the command `FROM .`.
|
||||
|
||||
If you do not create the Modelfile, ollama will act as if there was a Modelfile with the command `FROM .`.
|
||||
|
||||
Now run the `ollama create` command from the directory where you created the `Modelfile`:
|
||||
|
||||
```shell
|
||||
|
||||
@@ -16,7 +16,7 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -L https://ollama.com/download/ollama-linux-amd64.tgz -o ollama-linux-amd64.tgz
|
||||
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
|
||||
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
|
||||
```
|
||||
|
||||
@@ -34,7 +34,11 @@ ollama -v
|
||||
|
||||
### AMD GPU install
|
||||
|
||||
If you have an AMD GPU, also download and extract the additional ROCm package:
|
||||
If you have an AMD GPU, **also** download and extract the additional ROCm package:
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The ROCm tgz contains only AMD dependent libraries. You must extract **both** `ollama-linux-amd64.tgz` and `ollama-linux-amd64-rocm.tgz` into the same location.
|
||||
|
||||
|
||||
```shell
|
||||
curl -L https://ollama.com/download/ollama-linux-amd64-rocm.tgz -o ollama-linux-amd64-rocm.tgz
|
||||
@@ -112,8 +116,8 @@ sudo systemctl status ollama
|
||||
> While AMD has contributed the `amdgpu` driver upstream to the official linux
|
||||
> kernel source, the version is older and may not support all ROCm features. We
|
||||
> recommend you install the latest driver from
|
||||
> https://www.amd.com/en/support/linux-drivers for best support of your Radeon
|
||||
> GPU.
|
||||
> [AMD](https://www.amd.com/en/support/download/linux-drivers.html) for best support
|
||||
> of your Radeon GPU.
|
||||
|
||||
## Customizing
|
||||
|
||||
|
||||
42
docs/macos.md
Normal file
42
docs/macos.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Ollama for macOS
|
||||
|
||||
## System Requirements
|
||||
|
||||
* MacOS Monterey (v12) or newer
|
||||
* Apple M series (CPU and GPU support) or x86 (CPU only)
|
||||
|
||||
|
||||
## Filesystem Requirements
|
||||
|
||||
The preferred method of installation is to mount the `ollama.dmg` and drag-and-drop the Ollama application to the system-wide `Applications` folder. Upon startup, the Ollama app will verify the `ollama` CLI is present in your PATH, and if not detected, will prompt for permission to create a link in `/usr/local/bin`
|
||||
|
||||
Once you've installed Ollama, you'll need additional space for storing the Large Language models, which can be tens to hundreds of GB in size. If your home directory doesn't have enough space, you can change where the binaries are installed, and where the models are stored.
|
||||
|
||||
### Changing Install Location
|
||||
|
||||
To install the Ollama application somewhere other than `Applications`, place the Ollama application in the desired location, and ensure the CLI `Ollama.app/Contents/Resources/ollama` or a sym-link to the CLI can be found in your path. Upon first start decline the "Move to Applications?" request.
|
||||
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
Ollama on MacOS stores files in a few different locations.
|
||||
- `~/.ollama` contains models and configuration
|
||||
- `~/.ollama/logs` contains logs
|
||||
- *app.log* contains most recent logs from the GUI application
|
||||
- *server.log* contains the most recent server logs
|
||||
- `<install location>/Ollama.app/Contents/Resources/ollama` the CLI binary
|
||||
|
||||
## Uninstall
|
||||
|
||||
To fully remove Ollama from your system, remove the following files and folders:
|
||||
|
||||
```
|
||||
sudo rm -rf /Applications/Ollama.app
|
||||
sudo rm /usr/local/bin/ollama
|
||||
rm -rf "~/Library/Application Support/Ollama"
|
||||
rm -rf "~/Library/Saved Application State/com.electron.ollama.savedState"
|
||||
rm -rf ~/Library/Caches/com.electron.ollama/
|
||||
rm -rf ~/Library/Caches/ollama
|
||||
rm -rf ~/Library/WebKit/com.electron.ollama
|
||||
rm -rf ~/.ollama
|
||||
```
|
||||
@@ -150,7 +150,7 @@ PARAMETER <parameter> <parametervalue>
|
||||
|
||||
| Parameter | Description | Value Type | Example Usage |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 4096) | int | num_ctx 4096 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||
|
||||
@@ -72,7 +72,7 @@ client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
|
||||
# Define the schema for the response
|
||||
class FriendInfo(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
age: int
|
||||
is_available: bool
|
||||
|
||||
class FriendList(BaseModel):
|
||||
|
||||
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
|
||||
On **Linux** systems with systemd, the logs can be found with this command:
|
||||
|
||||
```shell
|
||||
journalctl -u ollama --no-pager --follow --pager-end
|
||||
journalctl -u ollama --no-pager --follow --pager-end
|
||||
```
|
||||
|
||||
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
||||
@@ -23,7 +23,7 @@ docker logs <container-name>
|
||||
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
|
||||
|
||||
When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `<cmd>+R` and type in:
|
||||
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
|
||||
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
|
||||
|
||||
@@ -38,12 +38,12 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
|
||||
|
||||
## LLM libraries
|
||||
|
||||
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
|
||||
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
|
||||
|
||||
In the server log, you will see a message that looks something like this (varies from release to release):
|
||||
|
||||
```
|
||||
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
|
||||
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
|
||||
```
|
||||
|
||||
**Experimental LLM Library Override**
|
||||
@@ -97,7 +97,7 @@ If none of those resolve the problem, gather additional information and file an
|
||||
|
||||
On linux, AMD GPU access typically requires `video` and/or `render` group membership to access the `/dev/kfd` device. If permissions are not set up correctly, Ollama will detect this and report an error in the server log.
|
||||
|
||||
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
|
||||
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
|
||||
|
||||
If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
|
||||
- `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems
|
||||
|
||||
107
docs/turbo.md
Normal file
107
docs/turbo.md
Normal file
@@ -0,0 +1,107 @@
|
||||
# Turbo
|
||||
|
||||
> ⚠️ Turbo is preview
|
||||
|
||||
Ollama’s [Turbo](https://ollama.com/turbo) is a new way to run open-source models with acceleration from datacenter-grade hardware.
|
||||
|
||||
Currently, the following models are available in Turbo:
|
||||
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
## Get started
|
||||
|
||||
### Ollama for macOS & Windows
|
||||
|
||||
Download Ollama
|
||||
|
||||
- Select a model such as `gpt-oss:20b` or `gpt-oss:120b`
|
||||
- Click on **Turbo**. You’ll be prompted to create an account or sign in
|
||||
|
||||
### Ollama’s CLI
|
||||
|
||||
- [Sign up](https://ollama.com/signup) for an Ollama account
|
||||
- Add your Ollama key [to ollama.com](https://ollama.com/settings/keys).
|
||||
|
||||
On macOS and Linux:
|
||||
|
||||
```shell
|
||||
cat ~/.ollama/id_ed25519.pub
|
||||
```
|
||||
|
||||
On Windows:
|
||||
|
||||
```
|
||||
type "%USERPROFILE%\.ollama\id_ed25519.pub"
|
||||
```
|
||||
|
||||
- Then run a model setting `OLLAMA_HOST` to `ollama.com`:
|
||||
```shell
|
||||
OLLAMA_HOST=ollama.com ollama run gpt-oss:120b
|
||||
```
|
||||
|
||||
### Ollama’s Python library
|
||||
|
||||
- Download Ollama's [Python library](https://github.com/ollama/ollama-python)
|
||||
- [Sign up](https://ollama.com/signup) for an Ollama account
|
||||
- Create an API key by visiting https://ollama.com/settings/keys
|
||||
|
||||
```python
|
||||
from ollama import Client
|
||||
|
||||
client = Client(
|
||||
host="https://ollama.com",
|
||||
headers={'Authorization': '<api key>'}
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
},
|
||||
]
|
||||
|
||||
for part in client.chat('gpt-oss:120b', messages=messages, stream=True):
|
||||
print(part['message']['content'], end='', flush=True)
|
||||
```
|
||||
|
||||
### Ollama’s JavaScript library
|
||||
|
||||
- Download Ollama's [JavaScript library](https://github.com/ollama/ollama-js)
|
||||
- [Sign up](https://ollama.com/signup) for an Ollama account
|
||||
- Create an API key by visiting https://ollama.com/settings/keys
|
||||
|
||||
```typescript
|
||||
import { Ollama } from 'ollama';
|
||||
|
||||
const ollama = new Ollama({
|
||||
host: 'https://ollama.com',
|
||||
headers: {
|
||||
Authorization: "Bearer <api key>"
|
||||
}
|
||||
});
|
||||
|
||||
const response = await ollama.chat({
|
||||
model: 'gpt-oss:120b',
|
||||
messages: [{ role: 'user', content: 'Explain quantum computing' }],
|
||||
stream: true
|
||||
});
|
||||
|
||||
for await (const part of response) {
|
||||
process.stdout.write(part.message.content)
|
||||
}
|
||||
```
|
||||
|
||||
### Community integrations
|
||||
|
||||
Turbo mode is also compatible with several community integrations.
|
||||
|
||||
#### Open WebUI
|
||||
|
||||
- Go to **settings** → **Admin settings** → **Connections**
|
||||
- Under **Ollama API,** click **+**
|
||||
- For the **URL** put `https://ollama.com`
|
||||
- For the **API key,** create an API key on https://ollama.com/settings/keys and add it.
|
||||
- Click **Save**
|
||||
|
||||
Now, if you navigate to the model selector, Turbo models should be available under **External**.
|
||||
@@ -30,20 +30,6 @@ To install the Ollama application in a location different than your home directo
|
||||
OllamaSetup.exe /DIR="d:\some\location"
|
||||
```
|
||||
|
||||
### Changing Model Location
|
||||
|
||||
To change where Ollama stores the downloaded models instead of using your home directory, set the environment variable `OLLAMA_MODELS` in your user account.
|
||||
|
||||
1. Start the Settings (Windows 11) or Control Panel (Windows 10) application and search for _environment variables_.
|
||||
|
||||
2. Click on _Edit environment variables for your account_.
|
||||
|
||||
3. Edit or create a new variable for your user account for `OLLAMA_MODELS` where you want the models stored
|
||||
|
||||
4. Click OK/Apply to save.
|
||||
|
||||
If Ollama is already running, Quit the tray application and relaunch it from the Start menu, or a new terminal started after you saved the environment variables.
|
||||
|
||||
## API Access
|
||||
|
||||
Here's a quick example showing API access from `powershell`
|
||||
@@ -82,9 +68,9 @@ If you'd like to install or integrate Ollama as a service, a standalone
|
||||
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
|
||||
and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
|
||||
and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
|
||||
same directory. This allows for embedding Ollama in existing applications, or
|
||||
running it as a system service via `ollama serve` with tools such as
|
||||
[NSSM](https://nssm.cc/).
|
||||
same directory. Both zip files are necessary for a complete AMD installation.
|
||||
This allows for embedding Ollama in existing applications, or running it as a
|
||||
system service via `ollama serve` with tools such as [NSSM](https://nssm.cc/).
|
||||
|
||||
> [!NOTE]
|
||||
> If you are upgrading from a prior version, you should remove the old directories first.
|
||||
|
||||
@@ -185,6 +185,8 @@ var (
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
// Auth enables authentication between the Ollama client and server
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
// Enable the new memory estimation logic
|
||||
NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
@@ -219,7 +221,7 @@ func Uint(key string, defaultValue uint) func() uint {
|
||||
|
||||
var (
|
||||
// NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
|
||||
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
|
||||
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 1)
|
||||
// MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
|
||||
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
|
||||
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
|
||||
@@ -270,6 +272,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_NEW_ESTIMATES": {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
@@ -10,4 +10,5 @@ type Config interface {
|
||||
Strings(string, ...[]string) []string
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
}
|
||||
|
||||
191
fs/ggml/ggml.go
191
fs/ggml/ggml.go
@@ -1,14 +1,17 @@
|
||||
package ggml
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||
)
|
||||
|
||||
@@ -34,7 +37,8 @@ func (kv KV) Kind() string {
|
||||
}
|
||||
|
||||
func (kv KV) ParameterCount() uint64 {
|
||||
return keyValue(kv, "general.parameter_count", uint64(0))
|
||||
val, _ := keyValue(kv, "general.parameter_count", uint64(0))
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) FileType() FileType {
|
||||
@@ -53,16 +57,27 @@ func (kv KV) EmbeddingLength() uint64 {
|
||||
return uint64(kv.Uint("embedding_length"))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCount() uint64 {
|
||||
return uint64(kv.Uint("attention.head_count"))
|
||||
func (kv KV) HeadCountMax() uint64 {
|
||||
// TODO(drifkin): using the max value can cause an overestimation. In the
|
||||
// future if array values become more popular, we can adapt the more invasive
|
||||
// <https://github.com/ollama/ollama/pull/10225>
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKV() uint64 {
|
||||
return uint64(kv.Uint("attention.head_count_kv", 1))
|
||||
func (kv KV) HeadCountMin() uint64 {
|
||||
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCount() uint64 {
|
||||
if heads := kv.HeadCount(); heads > 0 {
|
||||
func (kv KV) HeadCountKVMax() uint64 {
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKVMin() uint64 {
|
||||
return uint64(kv.UintOrMinArrayValue("attention.head_count_kv", 1))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountMax() uint64 {
|
||||
if heads := kv.HeadCountMin(); heads > 0 {
|
||||
return kv.EmbeddingLength() / heads
|
||||
}
|
||||
|
||||
@@ -70,15 +85,11 @@ func (kv KV) EmbeddingHeadCount() uint64 {
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountK() uint64 {
|
||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
|
||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax())))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountV() uint64 {
|
||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
|
||||
}
|
||||
|
||||
func (kv KV) GQA() uint64 {
|
||||
return kv.HeadCount() / kv.HeadCountKV()
|
||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax())))
|
||||
}
|
||||
|
||||
func (kv KV) ContextLength() uint64 {
|
||||
@@ -90,44 +101,88 @@ func (kv KV) ChatTemplate() string {
|
||||
}
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
return keyValue(kv, key, append(defaultValue, "")...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
return keyValue(kv, key, append(defaultValue, false)...)
|
||||
val, _ := keyValue(kv, key, append(defaultValue, false)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 {
|
||||
_, max := kv.UintOrArrayValue(key, defaultValue)
|
||||
return max
|
||||
}
|
||||
|
||||
func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
|
||||
min, _ := kv.UintOrArrayValue(key, defaultValue)
|
||||
return min
|
||||
}
|
||||
|
||||
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
|
||||
if u32, ok := keyValue(kv, key, uint32(0)); ok {
|
||||
return u32, u32
|
||||
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
|
||||
min := slices.Min(u32s.values)
|
||||
max := slices.Max(u32s.values)
|
||||
return min, max
|
||||
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
|
||||
min := slices.Min(i32s.values)
|
||||
max := slices.Max(i32s.values)
|
||||
if min < 0 || max < 0 {
|
||||
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
|
||||
}
|
||||
return uint32(min), uint32(max)
|
||||
}
|
||||
|
||||
return defaultValue, defaultValue
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
|
||||
val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"mistral3",
|
||||
"llama4",
|
||||
"mllama",
|
||||
"qwen25vl",
|
||||
"gptoss", "gpt-oss",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
@@ -143,17 +198,17 @@ type arrayValueTypes interface {
|
||||
*array[string] | *array[float32] | *array[float64] | *array[bool]
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key]; ok {
|
||||
return val.(T)
|
||||
if val, ok := kv[key].(T); ok {
|
||||
return val, true
|
||||
}
|
||||
|
||||
slog.Debug("key not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0]
|
||||
slog.Debug("key with type not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0], false
|
||||
}
|
||||
|
||||
type Tensors struct {
|
||||
@@ -222,36 +277,37 @@ type Tensor struct {
|
||||
|
||||
func (t Tensor) block() (n int) {
|
||||
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
|
||||
return -1
|
||||
return math.MaxInt
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (t Tensor) blockSize() uint64 {
|
||||
return (TensorType)(t.Kind).BlockSize()
|
||||
return TensorType(t.Kind).BlockSize()
|
||||
}
|
||||
|
||||
func (t TensorType) BlockSize() uint64 {
|
||||
switch t {
|
||||
case
|
||||
0, // F32
|
||||
1, // F16
|
||||
24, // I8
|
||||
25, // I16
|
||||
26, // I32
|
||||
27, // I64
|
||||
28, // F64
|
||||
30: // BF16
|
||||
TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
return 1
|
||||
case
|
||||
2, // Q4_0
|
||||
3, // Q4_1
|
||||
6, // Q5_0
|
||||
7, // Q5_1
|
||||
8, // Q8_0
|
||||
9, // Q8_1
|
||||
20: // IQ4_NL
|
||||
TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL,
|
||||
4, TensorTypeMXFP4:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
@@ -324,6 +380,8 @@ func (t TensorType) TypeSize() uint64 {
|
||||
return blockSize/8 + blockSize/16 + blockSize/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
case 4, TensorTypeMXFP4:
|
||||
return 1 + blockSize/2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@@ -423,22 +481,26 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
context *= uint64(numParallel)
|
||||
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
heads := f.KV().HeadCountMax()
|
||||
headsKV := f.KV().HeadCountKVMax()
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
|
||||
|
||||
embeddingHeads := f.KV().EmbeddingHeadCount()
|
||||
embeddingHeads := f.KV().EmbeddingHeadCountMax()
|
||||
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
||||
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
|
||||
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
var kvTotal uint64
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
kvTotal += kv[i]
|
||||
}
|
||||
|
||||
switch f.KV().Architecture() {
|
||||
@@ -504,7 +566,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
// vocab graph
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
case "gemma", "gemma2", "gemma3":
|
||||
case "gemma", "gemma2", "gemma3", "gemma3n":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
||||
@@ -517,6 +579,11 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
embedding*embeddingHeadsK*heads*9/16,
|
||||
)
|
||||
|
||||
if f.KV().Architecture() == "gemma3n" {
|
||||
fullOffload *= 4
|
||||
partialOffload *= 4
|
||||
}
|
||||
|
||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||
// engine. Gemma3 always uses the Ollama engine.
|
||||
if f.KV().Architecture() == "gemma3" {
|
||||
@@ -602,6 +669,22 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
4*qkvBias.Shape[0],
|
||||
)
|
||||
}
|
||||
case "gptoss", "gpt-oss":
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64((embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
if i%2 == 0 {
|
||||
kv[i] *= (uint64(numParallel)*4096 + batch)
|
||||
} else {
|
||||
kv[i] *= context
|
||||
}
|
||||
}
|
||||
|
||||
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
|
||||
if useFlashAttention {
|
||||
// rough estimate of graph size with flash attention on
|
||||
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
@@ -676,6 +759,11 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||
|
||||
// SupportsKVCacheType checks if the requested cache type is supported
|
||||
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
|
||||
// gpt-oss uses attention with sinks which does not support quantized cache types
|
||||
slog.Warn("model only supports non-quantized cache types ", "mode", arch)
|
||||
return cacheType == "f16"
|
||||
}
|
||||
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
|
||||
}
|
||||
|
||||
@@ -692,6 +780,13 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||
}
|
||||
|
||||
// FlashAttention checks if the model should enable flash attention
|
||||
func (f GGML) FlashAttention() bool {
|
||||
return slices.Contains([]string{
|
||||
"gptoss", "gpt-oss",
|
||||
}, f.KV().String("general.architecture"))
|
||||
}
|
||||
|
||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||
func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
switch cacheType {
|
||||
|
||||
@@ -269,3 +269,33 @@ func TestKeyValue(t *testing.T) {
|
||||
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadCount(t *testing.T) {
|
||||
valuesArray := []int32{1, 5, 3, 4}
|
||||
cases := []struct {
|
||||
kv KV
|
||||
want uint64
|
||||
}{
|
||||
{
|
||||
kv: KV{
|
||||
"general.architecture": "abc",
|
||||
"abc.attention.head_count": &array[int32]{values: valuesArray, size: len(valuesArray)},
|
||||
},
|
||||
want: uint64(5),
|
||||
},
|
||||
{
|
||||
kv: KV{
|
||||
"general.architecture": "abc",
|
||||
"abc.attention.head_count": uint32(3),
|
||||
},
|
||||
want: uint64(3),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got := tt.kv.HeadCountMax()
|
||||
if got != tt.want {
|
||||
t.Errorf("unexpected max value: got=%d want=%d", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -527,24 +527,21 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
keys := slices.Collect(maps.Keys(kv))
|
||||
slices.Sort(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
for _, key := range slices.Sorted(maps.Keys(kv)) {
|
||||
if err := ggufWriteKV(f, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
||||
if i, j := a.block(), b.block(); i < 0 && j > 0 {
|
||||
return 1
|
||||
} else if i > 0 && j < 0 {
|
||||
return -1
|
||||
} else {
|
||||
return cmp.Compare(i, j)
|
||||
}
|
||||
})
|
||||
slices.SortStableFunc(
|
||||
ts,
|
||||
func(a, b *Tensor) int {
|
||||
return cmp.Or(
|
||||
cmp.Compare(a.block(), b.block()),
|
||||
cmp.Compare(a.Name, b.Name),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
var s uint64
|
||||
for i := range ts {
|
||||
@@ -615,6 +612,10 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
err = writeGGUFArray(ws, ggufTypeString, v)
|
||||
case *array[string]:
|
||||
err = writeGGUFArray(ws, ggufTypeString, v.values)
|
||||
case []bool:
|
||||
err = writeGGUFArray(ws, ggufTypeBool, v)
|
||||
case *array[bool]:
|
||||
err = writeGGUFArray(ws, ggufTypeBool, v.values)
|
||||
default:
|
||||
return fmt.Errorf("improper type for '%s'", k)
|
||||
}
|
||||
|
||||
@@ -2,62 +2,82 @@ package ggml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
w, err := os.CreateTemp(t.TempDir(), "*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
b := bytes.NewBuffer(make([]byte, 2*3))
|
||||
for range 8 {
|
||||
t.Run("shuffle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
}, []*Tensor{
|
||||
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ts := []*Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
}
|
||||
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
rand.Shuffle(len(ts), func(i, j int) {
|
||||
ts[i], ts[j] = ts[j], ts[i]
|
||||
})
|
||||
|
||||
ff, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
if diff := cmp.Diff(ff.KV(), KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(36),
|
||||
}); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
}, ts); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(ff.Tensors(), Tensors{
|
||||
Offset: 336,
|
||||
items: []*Tensor{
|
||||
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
|
||||
},
|
||||
}, cmp.AllowUnexported(Tensors{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
ff, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(54),
|
||||
}, ff.KV()); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(Tensors{
|
||||
Offset: 592,
|
||||
items: []*Tensor{
|
||||
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
||||
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
||||
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
||||
},
|
||||
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,9 +14,9 @@ const (
|
||||
FileTypeF16
|
||||
fileTypeQ4_0
|
||||
fileTypeQ4_1
|
||||
fileTypeQ4_1_F16 // unused by GGML
|
||||
fileTypeQ4_2 // unused by GGML
|
||||
fileTypeQ4_3 // unused by GGML
|
||||
fileTypeMXFP4 // originally fileTypeQ4_1_F16 // unused by GGML
|
||||
fileTypeQ4_2 // unused by GGML
|
||||
fileTypeQ4_3 // unused by GGML
|
||||
FileTypeQ8_0
|
||||
fileTypeQ5_0
|
||||
fileTypeQ5_1
|
||||
@@ -97,6 +97,8 @@ func (t FileType) String() string {
|
||||
return "Q4_0"
|
||||
case fileTypeQ4_1:
|
||||
return "Q4_1"
|
||||
case fileTypeMXFP4:
|
||||
return "MXFP4"
|
||||
case FileTypeQ8_0:
|
||||
return "Q8_0"
|
||||
case fileTypeQ5_0:
|
||||
@@ -172,6 +174,8 @@ func (ftype FileType) ToTensorType() TensorType {
|
||||
return TensorTypeQ2_K
|
||||
case FileTypeBF16:
|
||||
return TensorTypeBF16
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4
|
||||
default:
|
||||
slog.Warn("unsupported file type", "type", ftype)
|
||||
return 0 // F32
|
||||
@@ -187,7 +191,7 @@ const (
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
tensorTypeQ4_2 // unused by GGML
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
@@ -222,6 +226,7 @@ const (
|
||||
tensorTypeIQ4_NL_4_4 // unused by GGML
|
||||
tensorTypeIQ4_NL_4_8 // unused by GGML
|
||||
tensorTypeIQ4_NL_8_8 // unused by GGML
|
||||
TensorTypeMXFP4
|
||||
)
|
||||
|
||||
// ParseFileType parses the provided GGUF file type
|
||||
@@ -260,6 +265,8 @@ func ParseTensorType(s string) (TensorType, error) {
|
||||
return TensorTypeF64, nil
|
||||
case "BF16":
|
||||
return TensorTypeBF16, nil
|
||||
case "MXFP4":
|
||||
return TensorTypeMXFP4, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported quantization type %s", s)
|
||||
}
|
||||
@@ -312,6 +319,8 @@ func (t TensorType) String() string {
|
||||
return "F64"
|
||||
case TensorTypeBF16:
|
||||
return "BF16"
|
||||
case 4, TensorTypeMXFP4:
|
||||
return "MXFP4"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
347
fs/gguf/gguf.go
Normal file
347
fs/gguf/gguf.go
Normal file
@@ -0,0 +1,347 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
typeUint8 uint32 = iota
|
||||
typeInt8
|
||||
typeUint16
|
||||
typeInt16
|
||||
typeUint32
|
||||
typeInt32
|
||||
typeFloat32
|
||||
typeBool
|
||||
typeString
|
||||
typeArray
|
||||
typeUint64
|
||||
typeInt64
|
||||
typeFloat64
|
||||
)
|
||||
|
||||
var ErrUnsupported = errors.New("unsupported")
|
||||
|
||||
type File struct {
|
||||
Magic [4]byte
|
||||
Version uint32
|
||||
|
||||
keyValues *lazy[KeyValue]
|
||||
tensors *lazy[TensorInfo]
|
||||
offset int64
|
||||
|
||||
file *os.File
|
||||
reader *bufferedReader
|
||||
bts []byte
|
||||
}
|
||||
|
||||
func Open(path string) (f *File, err error) {
|
||||
f = &File{bts: make([]byte, 4096)}
|
||||
f.file, err = os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.reader = newBufferedReader(f.file, 32<<10)
|
||||
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if bytes.Equal(f.Magic[:], []byte("gguf")) {
|
||||
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
|
||||
}
|
||||
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.Version < 2 {
|
||||
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
|
||||
}
|
||||
|
||||
f.tensors, err = newLazy(f, f.readTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.tensors.successFunc = func() error {
|
||||
offset := f.reader.offset
|
||||
|
||||
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
|
||||
f.offset = offset + (alignment-offset%alignment)%alignment
|
||||
return nil
|
||||
}
|
||||
|
||||
f.keyValues, err = newLazy(f, f.readKeyValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *File) readTensor() (TensorInfo, error) {
|
||||
name, err := readString(f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
dims, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
shape := make([]uint64, dims)
|
||||
for i := range dims {
|
||||
shape[i], err = read[uint64](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
type_, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
offset, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
return TensorInfo{
|
||||
Name: name,
|
||||
Offset: offset,
|
||||
Shape: shape,
|
||||
Type: TensorType(type_),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *File) readKeyValue() (KeyValue, error) {
|
||||
key, err := readString(f)
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
t, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
value, err := func() (any, error) {
|
||||
switch t {
|
||||
case typeUint8:
|
||||
return read[uint8](f)
|
||||
case typeInt8:
|
||||
return read[int8](f)
|
||||
case typeUint16:
|
||||
return read[uint16](f)
|
||||
case typeInt16:
|
||||
return read[int16](f)
|
||||
case typeUint32:
|
||||
return read[uint32](f)
|
||||
case typeInt32:
|
||||
return read[int32](f)
|
||||
case typeUint64:
|
||||
return read[uint64](f)
|
||||
case typeInt64:
|
||||
return read[int64](f)
|
||||
case typeFloat32:
|
||||
return read[float32](f)
|
||||
case typeFloat64:
|
||||
return read[float64](f)
|
||||
case typeBool:
|
||||
return read[bool](f)
|
||||
case typeString:
|
||||
return readString(f)
|
||||
case typeArray:
|
||||
return readArray(f)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
return KeyValue{
|
||||
Key: key,
|
||||
Value: Value{value},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func read[T any](f *File) (t T, err error) {
|
||||
err = binary.Read(f.reader, binary.LittleEndian, &t)
|
||||
return t, err
|
||||
}
|
||||
|
||||
func readString(f *File) (string, error) {
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if int(n) > len(f.bts) {
|
||||
f.bts = make([]byte, n)
|
||||
}
|
||||
|
||||
bts := f.bts[:n]
|
||||
if _, err := io.ReadFull(f.reader, bts); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer clear(bts)
|
||||
|
||||
return string(bts), nil
|
||||
}
|
||||
|
||||
func readArray(f *File) (any, error) {
|
||||
t, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case typeUint8:
|
||||
return readArrayData[uint8](f, n)
|
||||
case typeInt8:
|
||||
return readArrayData[int8](f, n)
|
||||
case typeUint16:
|
||||
return readArrayData[uint16](f, n)
|
||||
case typeInt16:
|
||||
return readArrayData[int16](f, n)
|
||||
case typeUint32:
|
||||
return readArrayData[uint32](f, n)
|
||||
case typeInt32:
|
||||
return readArrayData[int32](f, n)
|
||||
case typeUint64:
|
||||
return readArrayData[uint64](f, n)
|
||||
case typeInt64:
|
||||
return readArrayData[int64](f, n)
|
||||
case typeFloat32:
|
||||
return readArrayData[float32](f, n)
|
||||
case typeFloat64:
|
||||
return readArrayData[float64](f, n)
|
||||
case typeBool:
|
||||
return readArrayData[bool](f, n)
|
||||
case typeString:
|
||||
return readArrayString(f, n)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||
}
|
||||
}
|
||||
|
||||
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
|
||||
s = make([]T, n)
|
||||
for i := range n {
|
||||
e, err := read[T](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func readArrayString(f *File, n uint64) (s []string, err error) {
|
||||
s = make([]string, n)
|
||||
for i := range n {
|
||||
e, err := readString(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (f *File) Close() error {
|
||||
f.keyValues.stop()
|
||||
f.tensors.stop()
|
||||
return f.file.Close()
|
||||
}
|
||||
|
||||
func (f *File) KeyValue(key string) KeyValue {
|
||||
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
|
||||
key = f.KeyValue("general.architecture").String() + "." + key
|
||||
}
|
||||
|
||||
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
|
||||
return kv.Key == key
|
||||
}); index >= 0 {
|
||||
return f.keyValues.values[index]
|
||||
}
|
||||
|
||||
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
|
||||
if keyValue.Key == key {
|
||||
return keyValue
|
||||
}
|
||||
}
|
||||
|
||||
return KeyValue{}
|
||||
}
|
||||
|
||||
func (f *File) NumKeyValues() int {
|
||||
return int(f.keyValues.count)
|
||||
}
|
||||
|
||||
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
|
||||
return f.keyValues.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorInfo(name string) TensorInfo {
|
||||
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
|
||||
return t.Name == name
|
||||
}); index >= 0 {
|
||||
return f.tensors.values[index]
|
||||
}
|
||||
|
||||
// fast-forward through key values if we haven't already
|
||||
_ = f.keyValues.rest()
|
||||
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
|
||||
if tensor.Name == name {
|
||||
return tensor
|
||||
}
|
||||
}
|
||||
|
||||
return TensorInfo{}
|
||||
}
|
||||
|
||||
func (f *File) NumTensors() int {
|
||||
return int(f.tensors.count)
|
||||
}
|
||||
|
||||
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
|
||||
// fast forward through key values if we haven't already
|
||||
f.keyValues.rest()
|
||||
return f.tensors.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
|
||||
t := f.TensorInfo(name)
|
||||
if t.NumBytes() == 0 {
|
||||
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
|
||||
}
|
||||
|
||||
// fast forward through tensor info if we haven't already
|
||||
_ = f.tensors.rest()
|
||||
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
|
||||
}
|
||||
249
fs/gguf/gguf_test.go
Normal file
249
fs/gguf/gguf_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package gguf_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
)
|
||||
|
||||
func createBinFile(tb testing.TB) string {
|
||||
tb.Helper()
|
||||
f, err := os.CreateTemp(tb.TempDir(), "")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
kv := ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(8),
|
||||
"llama.embedding_length": uint32(3),
|
||||
"llama.attention.head_count": uint32(2),
|
||||
"llama.attention.head_count_kv": uint32(2),
|
||||
"llama.attention.key_length": uint32(3),
|
||||
"llama.rope.dimension_count": uint32(4),
|
||||
"llama.rope.freq_base": float32(10000.0),
|
||||
"llama.rope.freq_scale": float32(1.0),
|
||||
"llama.attention.layer_norm_rms_epsilon": float32(1e-6),
|
||||
"tokenizer.ggml.eos_token_id": uint32(0),
|
||||
"tokenizer.ggml.eos_token_ids": []int32{1, 2, 3},
|
||||
"tokenizer.ggml.tokens": []string{"hello", "world"},
|
||||
"tokenizer.ggml.scores": []float32{0, 1},
|
||||
}
|
||||
|
||||
tensors := []*ggml.Tensor{
|
||||
{
|
||||
Name: "token_embd.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{2, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)),
|
||||
},
|
||||
{
|
||||
Name: "output.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 2},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)),
|
||||
},
|
||||
}
|
||||
|
||||
for i := range 8 {
|
||||
tensors = append(tensors, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_q.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
}, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_k.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
}, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_v.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
}, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_output.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
})
|
||||
}
|
||||
|
||||
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
f, err := gguf.Open(createBinFile(t))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if got := f.KeyValue("does.not.exist").Valid(); got {
|
||||
t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got)
|
||||
}
|
||||
|
||||
if got := f.KeyValue("general.architecture").String(); got != "llama" {
|
||||
t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama")
|
||||
}
|
||||
|
||||
if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" {
|
||||
t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight")
|
||||
} else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" {
|
||||
t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff)
|
||||
} else if got.Type != gguf.TensorTypeF32 {
|
||||
t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32)
|
||||
}
|
||||
|
||||
if got := f.KeyValue("block_count").Uint(); got != 8 {
|
||||
t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" {
|
||||
t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" {
|
||||
t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Ints() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
var kvs []string
|
||||
for _, kv := range f.KeyValues() {
|
||||
if !kv.Valid() {
|
||||
t.Error("found invalid key-value pair:", kv)
|
||||
}
|
||||
|
||||
kvs = append(kvs, kv.Key)
|
||||
}
|
||||
|
||||
if len(kvs) != f.NumKeyValues() {
|
||||
t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kvs, []string{
|
||||
"general.architecture",
|
||||
"llama.block_count",
|
||||
"llama.embedding_length",
|
||||
"llama.attention.head_count",
|
||||
"llama.attention.head_count_kv",
|
||||
"llama.attention.key_length",
|
||||
"llama.rope.dimension_count",
|
||||
"llama.rope.freq_base",
|
||||
"llama.rope.freq_scale",
|
||||
"llama.attention.layer_norm_rms_epsilon",
|
||||
"tokenizer.ggml.eos_token_id",
|
||||
"tokenizer.ggml.eos_token_ids",
|
||||
"tokenizer.ggml.tokens",
|
||||
"tokenizer.ggml.scores",
|
||||
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
|
||||
t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
var tis []string
|
||||
for _, ti := range f.TensorInfos() {
|
||||
if !ti.Valid() {
|
||||
t.Error("found invalid tensor info:", ti)
|
||||
}
|
||||
|
||||
tis = append(tis, ti.Name)
|
||||
}
|
||||
|
||||
if len(tis) != f.NumTensors() {
|
||||
t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tis, []string{
|
||||
"token_embd.weight",
|
||||
"output.weight",
|
||||
"blk.0.attn_q.weight",
|
||||
"blk.0.attn_k.weight",
|
||||
"blk.0.attn_v.weight",
|
||||
"blk.0.attn_output.weight",
|
||||
"blk.1.attn_q.weight",
|
||||
"blk.1.attn_k.weight",
|
||||
"blk.1.attn_v.weight",
|
||||
"blk.1.attn_output.weight",
|
||||
"blk.2.attn_q.weight",
|
||||
"blk.2.attn_k.weight",
|
||||
"blk.2.attn_v.weight",
|
||||
"blk.2.attn_output.weight",
|
||||
"blk.3.attn_q.weight",
|
||||
"blk.3.attn_k.weight",
|
||||
"blk.3.attn_v.weight",
|
||||
"blk.3.attn_output.weight",
|
||||
"blk.4.attn_q.weight",
|
||||
"blk.4.attn_k.weight",
|
||||
"blk.4.attn_v.weight",
|
||||
"blk.4.attn_output.weight",
|
||||
"blk.5.attn_q.weight",
|
||||
"blk.5.attn_k.weight",
|
||||
"blk.5.attn_v.weight",
|
||||
"blk.5.attn_output.weight",
|
||||
"blk.6.attn_q.weight",
|
||||
"blk.6.attn_k.weight",
|
||||
"blk.6.attn_v.weight",
|
||||
"blk.6.attn_output.weight",
|
||||
"blk.7.attn_q.weight",
|
||||
"blk.7.attn_k.weight",
|
||||
"blk.7.attn_v.weight",
|
||||
"blk.7.attn_output.weight",
|
||||
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
|
||||
t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
ti, r, err := f.TensorReader("output.weight")
|
||||
if err != nil {
|
||||
t.Fatalf(`TensorReader("output.weight") error: %v`, err)
|
||||
}
|
||||
|
||||
if ti.Name != "output.weight" {
|
||||
t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight")
|
||||
} else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" {
|
||||
t.Errorf(`TensorReader("output.weight").Shape mismatch (-got +want):\n%s`, diff)
|
||||
} else if ti.Type != gguf.TensorTypeF32 {
|
||||
t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := b.ReadFrom(r); err != nil {
|
||||
t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err)
|
||||
}
|
||||
|
||||
if b.Len() != int(ti.NumBytes()) {
|
||||
t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes())
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRead(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
p := createBinFile(b)
|
||||
for b.Loop() {
|
||||
f, err := gguf.Open(p)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
if got := f.KeyValue("general.architecture").String(); got != "llama" {
|
||||
b.Errorf("got = %q, want %q", got, "llama")
|
||||
}
|
||||
|
||||
// Iterate through some tensors
|
||||
for range f.TensorInfos() {
|
||||
}
|
||||
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
90
fs/gguf/keyvalue.go
Normal file
90
fs/gguf/keyvalue.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"slices"
|
||||
)
|
||||
|
||||
type KeyValue struct {
|
||||
Key string
|
||||
Value
|
||||
}
|
||||
|
||||
func (kv KeyValue) Valid() bool {
|
||||
return kv.Key != "" && kv.Value.value != nil
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
value any
|
||||
}
|
||||
|
||||
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
|
||||
vv := reflect.ValueOf(v.value)
|
||||
if slices.Contains(kinds, vv.Kind()) {
|
||||
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
|
||||
switch vv := reflect.ValueOf(v.value); vv.Kind() {
|
||||
case reflect.Slice:
|
||||
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
|
||||
ts = make([]T, vv.Len())
|
||||
for i := range vv.Len() {
|
||||
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
|
||||
func (v Value) Int() int64 {
|
||||
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||
}
|
||||
|
||||
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
|
||||
func (v Value) Ints() (i64s []int64) {
|
||||
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||
}
|
||||
|
||||
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
|
||||
func (v Value) Uint() uint64 {
|
||||
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||
}
|
||||
|
||||
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
|
||||
func (v Value) Uints() (u64s []uint64) {
|
||||
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||
}
|
||||
|
||||
// Float returns Value as a float. If it is not a float, it returns 0.
|
||||
func (v Value) Float() float64 {
|
||||
return value[float64](v, reflect.Float32, reflect.Float64)
|
||||
}
|
||||
|
||||
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
|
||||
func (v Value) Floats() (f64s []float64) {
|
||||
return values[float64](v, reflect.Float32, reflect.Float64)
|
||||
}
|
||||
|
||||
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
|
||||
func (v Value) Bool() bool {
|
||||
return value[bool](v, reflect.Bool)
|
||||
}
|
||||
|
||||
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
|
||||
func (v Value) Bools() (bools []bool) {
|
||||
return values[bool](v, reflect.Bool)
|
||||
}
|
||||
|
||||
// String returns Value as a string. If it is not a string, it returns an empty string.
|
||||
func (v Value) String() string {
|
||||
return value[string](v, reflect.String)
|
||||
}
|
||||
|
||||
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
|
||||
func (v Value) Strings() (strings []string) {
|
||||
return values[string](v, reflect.String)
|
||||
}
|
||||
208
fs/gguf/keyvalue_test.go
Normal file
208
fs/gguf/keyvalue_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
|
||||
for key, value := range values {
|
||||
if key == name {
|
||||
matched = value
|
||||
} else {
|
||||
unmatched = append(unmatched, value...)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestValue(t *testing.T) {
|
||||
values := map[string][]any{
|
||||
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
|
||||
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
|
||||
"float64": {float32(42), float64(42)},
|
||||
"string": {"42", "hello"},
|
||||
"bool": {true, false},
|
||||
}
|
||||
|
||||
t.Run("int64", func(t *testing.T) {
|
||||
matched, unmatched := split("int64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64 := kv.Int(); i64 != 42 {
|
||||
t.Errorf("expected 42, got %d", i64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64 := kv.Int(); i64 != 0 {
|
||||
t.Errorf("expected 42, got %d", i64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uint64", func(t *testing.T) {
|
||||
matched, unmatched := split("uint64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64 := kv.Uint(); u64 != 42 {
|
||||
t.Errorf("expected 42, got %d", u64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64 := kv.Uint(); u64 != 0 {
|
||||
t.Errorf("expected 42, got %d", u64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64", func(t *testing.T) {
|
||||
matched, unmatched := split("float64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64 := kv.Float(); f64 != 42 {
|
||||
t.Errorf("expected 42, got %f", f64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64 := kv.Float(); f64 != 0 {
|
||||
t.Errorf("expected 42, got %f", f64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("string", func(t *testing.T) {
|
||||
matched, unmatched := split("string", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.String(); s != v {
|
||||
t.Errorf("expected 42, got %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.String(); s != "" {
|
||||
t.Errorf("expected 42, got %s", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bool", func(t *testing.T) {
|
||||
matched, unmatched := split("bool", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bool(); b != v {
|
||||
t.Errorf("expected true, got %v", b)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bool(); b != false {
|
||||
t.Errorf("expected false, got %v", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValues(t *testing.T) {
|
||||
values := map[string][]any{
|
||||
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
|
||||
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
|
||||
"float64s": {[]float32{42}, []float64{42}},
|
||||
"strings": {[]string{"42"}, []string{"hello"}},
|
||||
"bools": {[]bool{true}, []bool{false}},
|
||||
}
|
||||
|
||||
t.Run("int64s", func(t *testing.T) {
|
||||
matched, unmatched := split("int64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64s := kv.Ints(); i64s != nil {
|
||||
t.Errorf("expected nil, got %v", i64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uint64s", func(t *testing.T) {
|
||||
matched, unmatched := split("uint64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64s := kv.Uints(); u64s != nil {
|
||||
t.Errorf("expected nil, got %v", u64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64s", func(t *testing.T) {
|
||||
matched, unmatched := split("float64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64s := kv.Floats(); f64s != nil {
|
||||
t.Errorf("expected nil, got %v", f64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("strings", func(t *testing.T) {
|
||||
matched, unmatched := split("strings", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.Strings(); s != nil {
|
||||
t.Errorf("expected nil, got %v", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bools", func(t *testing.T) {
|
||||
matched, unmatched := split("bools", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bools(); b != nil {
|
||||
t.Errorf("expected nil, got %v", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
89
fs/gguf/lazy.go
Normal file
89
fs/gguf/lazy.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"iter"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type lazy[T any] struct {
|
||||
count uint64
|
||||
next func() (T, bool)
|
||||
stop func()
|
||||
values []T
|
||||
|
||||
// successFunc is called when all values have been successfully read.
|
||||
successFunc func() error
|
||||
}
|
||||
|
||||
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
|
||||
it := lazy[T]{}
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
it.values = make([]T, 0)
|
||||
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
|
||||
for i := range it.count {
|
||||
t, err := fn()
|
||||
if err != nil {
|
||||
slog.Error("error reading tensor", "index", i, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
it.values = append(it.values, t)
|
||||
if !yield(t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if it.successFunc != nil {
|
||||
it.successFunc()
|
||||
}
|
||||
})
|
||||
|
||||
return &it, nil
|
||||
}
|
||||
|
||||
func (g *lazy[T]) Values() iter.Seq[T] {
|
||||
return func(yield func(T) bool) {
|
||||
for _, v := range g.All() {
|
||||
if !yield(v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *lazy[T]) All() iter.Seq2[int, T] {
|
||||
return func(yield func(int, T) bool) {
|
||||
for i := range int(g.count) {
|
||||
if i < len(g.values) {
|
||||
if !yield(i, g.values[i]) {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
t, ok := g.next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if !yield(i, t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *lazy[T]) rest() (collected bool) {
|
||||
for {
|
||||
_, ok := g.next()
|
||||
collected = collected || ok
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return collected
|
||||
}
|
||||
23
fs/gguf/reader.go
Normal file
23
fs/gguf/reader.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
type bufferedReader struct {
|
||||
offset int64
|
||||
*bufio.Reader
|
||||
}
|
||||
|
||||
func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader {
|
||||
return &bufferedReader{
|
||||
Reader: bufio.NewReaderSize(rs, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *bufferedReader) Read(p []byte) (n int, err error) {
|
||||
n, err = rs.Reader.Read(p)
|
||||
rs.offset += int64(n)
|
||||
return n, err
|
||||
}
|
||||
288
fs/gguf/tensor.go
Normal file
288
fs/gguf/tensor.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TensorInfo struct {
|
||||
Name string
|
||||
Offset uint64
|
||||
Shape []uint64
|
||||
Type TensorType
|
||||
}
|
||||
|
||||
func (ti TensorInfo) Valid() bool {
|
||||
return ti.Name != "" && ti.NumBytes() > 0
|
||||
}
|
||||
|
||||
func (ti TensorInfo) NumValues() int64 {
|
||||
var numItems int64 = 1
|
||||
for _, dim := range ti.Shape {
|
||||
numItems *= int64(dim)
|
||||
}
|
||||
return numItems
|
||||
}
|
||||
|
||||
// NumBytes returns the number of bytes in the tensor.
|
||||
func (ti TensorInfo) NumBytes() int64 {
|
||||
return int64(float64(ti.NumValues()) * ti.Type.NumBytes())
|
||||
}
|
||||
|
||||
func (ti TensorInfo) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.String("name", ti.Name),
|
||||
slog.Int64("offset", int64(ti.Offset)),
|
||||
slog.Any("shape", ti.Shape),
|
||||
slog.Int64("num_values", ti.NumValues()),
|
||||
slog.Int64("num_bytes", ti.NumBytes()),
|
||||
slog.Any("type", ti.Type),
|
||||
)
|
||||
}
|
||||
|
||||
type TensorType uint32
|
||||
|
||||
const (
|
||||
TensorTypeF32 TensorType = iota
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3
|
||||
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
TensorTypeQ8_1
|
||||
TensorTypeQ2_K
|
||||
TensorTypeQ3_K
|
||||
TensorTypeQ4_K
|
||||
TensorTypeQ5_K
|
||||
TensorTypeQ6_K
|
||||
TensorTypeQ8_K
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeIQ2_XXS
|
||||
tensorTypeIQ2_XS
|
||||
tensorTypeIQ3_XXS
|
||||
tensorTypeIQ1_S
|
||||
tensorTypeIQ4_NL
|
||||
tensorTypeIQ3_S
|
||||
tensorTypeIQ2_S
|
||||
tensorTypeIQ4_XS
|
||||
|
||||
TensorTypeI8
|
||||
TensorTypeI16
|
||||
TensorTypeI32
|
||||
TensorTypeI64
|
||||
TensorTypeF64
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeIQ1_M
|
||||
|
||||
TensorTypeBF16
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeQ4_0_4_4
|
||||
tensorTypeQ4_0_4_8
|
||||
tensorTypeQ4_0_8_8
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeTQ1_0
|
||||
tensorTypeTQ2_0
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeIQ4_NL_4_4
|
||||
tensorTypeIQ4_NL_4_8
|
||||
tensorTypeIQ4_NL_8_8
|
||||
)
|
||||
|
||||
func (tt TensorType) NumBytes() float64 {
|
||||
return float64(tt.typeSize()) / float64(tt.blockSize())
|
||||
}
|
||||
|
||||
func (tt TensorType) typeSize() int64 {
|
||||
switch tt {
|
||||
case TensorTypeF32:
|
||||
return 4
|
||||
case TensorTypeF16:
|
||||
return 2
|
||||
case TensorTypeQ4_0:
|
||||
return 2 + tt.blockSize()/2
|
||||
case TensorTypeQ4_1:
|
||||
return 2 + 2 + tt.blockSize()/2
|
||||
case TensorTypeQ5_0:
|
||||
return 2 + 4 + tt.blockSize()/2
|
||||
case TensorTypeQ5_1:
|
||||
return 2 + 2 + 4 + tt.blockSize()/2
|
||||
case TensorTypeQ8_0:
|
||||
return 2 + tt.blockSize()
|
||||
case TensorTypeQ8_1:
|
||||
return 2 + 2 + tt.blockSize()
|
||||
case TensorTypeQ2_K:
|
||||
return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2
|
||||
case TensorTypeQ3_K:
|
||||
return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2
|
||||
case TensorTypeQ4_K:
|
||||
return 2 + 2 + 12 + tt.blockSize()/2
|
||||
case TensorTypeQ5_K:
|
||||
return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2
|
||||
case TensorTypeQ6_K:
|
||||
return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2
|
||||
case TensorTypeQ8_K:
|
||||
return 4 + tt.blockSize() + 2*tt.blockSize()/16
|
||||
case tensorTypeIQ2_XXS:
|
||||
return 2 + 2*tt.blockSize()/8
|
||||
case tensorTypeIQ2_XS:
|
||||
return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32
|
||||
case tensorTypeIQ3_XXS:
|
||||
return 2 + tt.blockSize()/4 + tt.blockSize()/8
|
||||
case tensorTypeIQ1_S:
|
||||
return 2 + tt.blockSize()/8 + tt.blockSize()/16
|
||||
case tensorTypeIQ4_NL:
|
||||
return 2 + tt.blockSize()/2
|
||||
case tensorTypeIQ3_S:
|
||||
return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4
|
||||
case tensorTypeIQ2_S:
|
||||
return 2 + tt.blockSize()/4 + tt.blockSize()/16
|
||||
case tensorTypeIQ4_XS:
|
||||
return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64
|
||||
case TensorTypeI8:
|
||||
return 1
|
||||
case TensorTypeI16:
|
||||
return 2
|
||||
case TensorTypeI32:
|
||||
return 4
|
||||
case TensorTypeI64:
|
||||
return 8
|
||||
case TensorTypeF64:
|
||||
return 8
|
||||
case tensorTypeIQ1_M:
|
||||
return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (tt TensorType) blockSize() int64 {
|
||||
switch tt {
|
||||
case TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
return 1
|
||||
case TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
func (tt TensorType) String() string {
|
||||
switch tt {
|
||||
case TensorTypeF32:
|
||||
return "f32"
|
||||
case TensorTypeF16:
|
||||
return "f16"
|
||||
case TensorTypeQ4_0:
|
||||
return "q4_0"
|
||||
case TensorTypeQ4_1:
|
||||
return "q4_1"
|
||||
case tensorTypeQ4_2:
|
||||
return "q4_2"
|
||||
case tensorTypeQ4_3:
|
||||
return "q4_3"
|
||||
case TensorTypeQ5_0:
|
||||
return "q5_0"
|
||||
case TensorTypeQ5_1:
|
||||
return "q5_1"
|
||||
case TensorTypeQ8_0:
|
||||
return "q8_0"
|
||||
case TensorTypeQ8_1:
|
||||
return "q8_1"
|
||||
case TensorTypeQ2_K:
|
||||
return "q2_k"
|
||||
case TensorTypeQ3_K:
|
||||
return "q3_k"
|
||||
case TensorTypeQ4_K:
|
||||
return "q4_k"
|
||||
case TensorTypeQ5_K:
|
||||
return "q5_k"
|
||||
case TensorTypeQ6_K:
|
||||
return "q6_k"
|
||||
case TensorTypeQ8_K:
|
||||
return "q8_k"
|
||||
case tensorTypeIQ2_XXS:
|
||||
return "iq2_xxs"
|
||||
case tensorTypeIQ2_XS:
|
||||
return "iq2_xs"
|
||||
case tensorTypeIQ3_XXS:
|
||||
return "iq3_xxs"
|
||||
case tensorTypeIQ1_S:
|
||||
return "iq1_s"
|
||||
case tensorTypeIQ4_NL:
|
||||
return "iq4_nl"
|
||||
case tensorTypeIQ3_S:
|
||||
return "iq3_s"
|
||||
case tensorTypeIQ2_S:
|
||||
return "iq2_s"
|
||||
case tensorTypeIQ4_XS:
|
||||
return "iq4_xs"
|
||||
case TensorTypeI8:
|
||||
return "i8"
|
||||
case TensorTypeI16:
|
||||
return "i16"
|
||||
case TensorTypeI32:
|
||||
return "i32"
|
||||
case TensorTypeI64:
|
||||
return "i64"
|
||||
case TensorTypeF64:
|
||||
return "f64"
|
||||
case tensorTypeIQ1_M:
|
||||
return "iq1_m"
|
||||
case TensorTypeBF16:
|
||||
return "bf16"
|
||||
case tensorTypeQ4_0_4_4:
|
||||
return "q4_0_4_4"
|
||||
case tensorTypeQ4_0_4_8:
|
||||
return "q4_0_4_8"
|
||||
case tensorTypeQ4_0_8_8:
|
||||
return "q4_0_8_8"
|
||||
case tensorTypeTQ1_0:
|
||||
return "tq1_0"
|
||||
case tensorTypeTQ2_0:
|
||||
return "tq2_0"
|
||||
case tensorTypeIQ4_NL_4_4:
|
||||
return "iq4_nl_4_4"
|
||||
case tensorTypeIQ4_NL_4_8:
|
||||
return "iq4_nl_4_8"
|
||||
case tensorTypeIQ4_NL_8_8:
|
||||
return "iq4_nl_8_8"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (tt TensorType) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Uint64("value", uint64(tt)),
|
||||
slog.String("name", strings.ToUpper(tt.String())),
|
||||
slog.Int64("size", tt.typeSize()),
|
||||
slog.Int64("block_size", tt.blockSize()),
|
||||
slog.Float64("num_bytes", tt.NumBytes()),
|
||||
)
|
||||
}
|
||||
6
go.mod
6
go.mod
@@ -19,12 +19,13 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/tools v0.30.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -44,7 +45,6 @@ require (
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
gonum.org/v1/gonum v0.15.0 // indirect
|
||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||
)
|
||||
@@ -71,7 +71,7 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.31.0
|
||||
golang.org/x/term v0.30.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -112,8 +112,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
|
||||
463
harmony/harmonyparser.go
Normal file
463
harmony/harmonyparser.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package harmony
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type harmonyParserState int
|
||||
|
||||
const (
|
||||
harmonyParserState_LookingForMessageStart harmonyParserState = iota
|
||||
harmonyParserState_ParsingHeader
|
||||
harmonyParserState_ParsingContent
|
||||
)
|
||||
|
||||
func (s harmonyParserState) String() string {
|
||||
switch s {
|
||||
// we're looking for the message start tag
|
||||
case harmonyParserState_LookingForMessageStart:
|
||||
return "LookingForMessageStart"
|
||||
case harmonyParserState_ParsingHeader:
|
||||
return "ParsingHeader"
|
||||
case harmonyParserState_ParsingContent:
|
||||
return "ParsingContent"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type HarmonyParser struct {
|
||||
state harmonyParserState
|
||||
MessageStartTag string
|
||||
MessageEndTag string
|
||||
HeaderEndTag string
|
||||
acc strings.Builder
|
||||
lifetimeAcc strings.Builder
|
||||
}
|
||||
|
||||
type HarmonyEvent interface {
|
||||
isHarmonyEvent()
|
||||
}
|
||||
|
||||
type HarmonyEventMessageStart struct{}
|
||||
|
||||
func (HarmonyEventMessageStart) isHarmonyEvent() {}
|
||||
|
||||
type HarmonyEventHeaderComplete struct {
|
||||
Header HarmonyHeader
|
||||
}
|
||||
|
||||
func (HarmonyEventHeaderComplete) isHarmonyEvent() {}
|
||||
|
||||
type HarmonyEventContentEmitted struct {
|
||||
Content string
|
||||
}
|
||||
|
||||
func (HarmonyEventContentEmitted) isHarmonyEvent() {}
|
||||
|
||||
type HarmonyEventMessageEnd struct{}
|
||||
|
||||
func (HarmonyEventMessageEnd) isHarmonyEvent() {}
|
||||
|
||||
type HarmonyHeader struct {
|
||||
Role string
|
||||
Channel string
|
||||
Recipient string
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddImplicitStart() {
|
||||
s.acc.WriteString("<|start|>assistant")
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
|
||||
if lastMessage != nil && lastMessage.Role == "assistant" {
|
||||
// handle prefilling conditions
|
||||
if lastMessage.Content != "" {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
|
||||
return
|
||||
} else if lastMessage.Thinking != "" {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
|
||||
return
|
||||
}
|
||||
}
|
||||
s.AddImplicitStart()
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
|
||||
s.lifetimeAcc.WriteString(content)
|
||||
s.acc.WriteString(content)
|
||||
|
||||
var events []HarmonyEvent
|
||||
|
||||
keepLooping := true
|
||||
// we loop because we might pass through multiple parsing states in a single
|
||||
// call to addContent, and we want to make sure callers don't have to wait for
|
||||
// data that's already unambiguous
|
||||
for keepLooping {
|
||||
var newEvents []HarmonyEvent
|
||||
newEvents, keepLooping = eat(s)
|
||||
events = append(events, newEvents...)
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// the additional bool return is true iff we should continue eating
|
||||
func eat(s *HarmonyParser) ([]HarmonyEvent, bool) {
|
||||
switch s.state {
|
||||
case harmonyParserState_LookingForMessageStart:
|
||||
// does the acc contain the message start tag?
|
||||
if strings.Contains(s.acc.String(), s.MessageStartTag) {
|
||||
// split the acc into the message start tag and the rest
|
||||
split := strings.SplitN(s.acc.String(), s.MessageStartTag, 2)
|
||||
before := split[0]
|
||||
if before != "" {
|
||||
slog.Warn("harmony parser: found message start tag in the middle of the content", "content", s.acc.String())
|
||||
}
|
||||
after := split[1]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
s.state = harmonyParserState_ParsingHeader
|
||||
return []HarmonyEvent{HarmonyEventMessageStart{}}, true
|
||||
}
|
||||
|
||||
// no match, so we keep accumulating
|
||||
return nil, false
|
||||
case harmonyParserState_ParsingHeader:
|
||||
if strings.Contains(s.acc.String(), s.HeaderEndTag) {
|
||||
split := strings.SplitN(s.acc.String(), s.HeaderEndTag, 2)
|
||||
header := split[0]
|
||||
after := split[1]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
s.state = harmonyParserState_ParsingContent
|
||||
return []HarmonyEvent{HarmonyEventHeaderComplete{Header: s.parseHeader(header)}}, true
|
||||
}
|
||||
return nil, false
|
||||
case harmonyParserState_ParsingContent:
|
||||
if strings.Contains(s.acc.String(), s.MessageEndTag) {
|
||||
// if we already have the message end tag, we can emit the content up to it
|
||||
split := strings.SplitN(s.acc.String(), s.MessageEndTag, 2)
|
||||
content := split[0]
|
||||
after := split[1]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
s.state = harmonyParserState_LookingForMessageStart
|
||||
events := []HarmonyEvent{}
|
||||
if content != "" {
|
||||
events = append(events, HarmonyEventContentEmitted{Content: content})
|
||||
}
|
||||
events = append(events, HarmonyEventMessageEnd{})
|
||||
return events, true
|
||||
} else if overlapLen := overlap(s.acc.String(), s.MessageEndTag); overlapLen > 0 {
|
||||
// if our suffix contains the start of the message end tag, we can emit
|
||||
// the content up to the start of the message end tag
|
||||
content := s.acc.String()[:len(s.acc.String())-overlapLen]
|
||||
remaining := s.acc.String()[len(s.acc.String())-overlapLen:]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(remaining)
|
||||
// emit the content we know isn't part of the message end tag, and keep
|
||||
// accumulating to disambiguate the rest
|
||||
if content == "" {
|
||||
return nil, false
|
||||
}
|
||||
return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false
|
||||
} else {
|
||||
// no end tag, so it's still normal content that we can immediately emit
|
||||
content := s.acc.String()
|
||||
if content == "" {
|
||||
return nil, false
|
||||
}
|
||||
s.acc.Reset()
|
||||
return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) parseHeader(raw string) HarmonyHeader {
|
||||
harmonyHeader := HarmonyHeader{}
|
||||
|
||||
// if `<|constrain|>` is present, ensure it has a space before it so it gets
|
||||
// parsed as a separate token, even if the model didn't include the space
|
||||
if strings.Contains(raw, "<|constrain|>") {
|
||||
raw = strings.Replace(raw, "<|constrain|>", " <|constrain|>", 1)
|
||||
raw = strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
// look for the optional channel tag, which is `<|channel|>` followed by the
|
||||
// channel name, all without any whitespace
|
||||
channelIndex := strings.Index(raw, "<|channel|>")
|
||||
if channelIndex != -1 {
|
||||
before := raw[:channelIndex]
|
||||
after := raw[channelIndex+len("<|channel|>"):]
|
||||
// the channel name is `after` all the way up to the first (if any) whitespace character
|
||||
idx := strings.IndexFunc(after, func(r rune) bool {
|
||||
return unicode.IsSpace(r)
|
||||
})
|
||||
if idx == -1 {
|
||||
idx = len(after)
|
||||
}
|
||||
harmonyHeader.Channel = after[:idx]
|
||||
after = after[idx:]
|
||||
// now we remove the channel tag from the raw string to further process
|
||||
raw = before + after
|
||||
raw = strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
// split the header into whitespace-separated tokens
|
||||
tokens := strings.Fields(raw)
|
||||
|
||||
// the first token is treated as the role
|
||||
if len(tokens) == 0 {
|
||||
slog.Error("harmony parser: missing role in header", "header", raw)
|
||||
return harmonyHeader
|
||||
}
|
||||
role := tokens[0]
|
||||
tokens = tokens[1:]
|
||||
// special case: if role starts with to= then it's a tool call
|
||||
if strings.HasPrefix(role, "to=") {
|
||||
harmonyHeader.Recipient = role[3:]
|
||||
harmonyHeader.Role = "tool"
|
||||
} else {
|
||||
harmonyHeader.Role = role
|
||||
}
|
||||
|
||||
// the recipient (if any) can be specified before or after the channel tag, so
|
||||
// we check it at the end once we've already parsed the channel and role
|
||||
if harmonyHeader.Recipient == "" && len(tokens) > 0 && strings.HasPrefix(tokens[0], "to=") {
|
||||
harmonyHeader.Recipient = tokens[0][3:]
|
||||
}
|
||||
|
||||
return harmonyHeader
|
||||
}
|
||||
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// harmonyMessageState represents the current state of message processing
|
||||
type harmonyMessageState int
|
||||
|
||||
const (
|
||||
harmonyMessageState_Normal harmonyMessageState = iota
|
||||
harmonyMessageState_Thinking
|
||||
harmonyMessageState_ToolCalling
|
||||
)
|
||||
|
||||
// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
|
||||
// This is a higher level interface that maps harmony concepts into ollama concepts
|
||||
type HarmonyMessageHandler struct {
|
||||
state harmonyMessageState
|
||||
HarmonyParser *HarmonyParser
|
||||
FunctionNameMap *FunctionNameMap
|
||||
}
|
||||
|
||||
// NewHarmonyMessageHandler creates a new message handler
|
||||
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
||||
return &HarmonyMessageHandler{
|
||||
state: harmonyMessageState_Normal,
|
||||
HarmonyParser: &HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
},
|
||||
FunctionNameMap: NewFunctionNameMap(),
|
||||
}
|
||||
}
|
||||
|
||||
// AddContent processes the content and returns the content, thinking, and tool content.
|
||||
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
|
||||
func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
|
||||
contentSb := strings.Builder{}
|
||||
thinkingSb := strings.Builder{}
|
||||
toolContentSb := strings.Builder{}
|
||||
|
||||
events := h.HarmonyParser.AddContent(content)
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case HarmonyEventHeaderComplete:
|
||||
logutil.Trace("harmony event header complete", "header", event.Header)
|
||||
switch event.Header.Channel {
|
||||
case "analysis":
|
||||
if event.Header.Recipient != "" {
|
||||
h.state = harmonyMessageState_ToolCalling
|
||||
// event.Header.Recipient is the tool name, something like
|
||||
// "browser.search" for a built-in, or "functions.calc" for a
|
||||
// custom one
|
||||
toolParser.SetToolName(event.Header.Recipient)
|
||||
} else {
|
||||
h.state = harmonyMessageState_Thinking
|
||||
}
|
||||
case "commentary":
|
||||
if event.Header.Recipient != "" {
|
||||
h.state = harmonyMessageState_ToolCalling
|
||||
toolParser.SetToolName(event.Header.Recipient)
|
||||
} else {
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
case "final":
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
case HarmonyEventContentEmitted:
|
||||
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
||||
if h.state == harmonyMessageState_Normal {
|
||||
contentSb.WriteString(event.Content)
|
||||
} else if h.state == harmonyMessageState_Thinking {
|
||||
thinkingSb.WriteString(event.Content)
|
||||
} else if h.state == harmonyMessageState_ToolCalling {
|
||||
toolContentSb.WriteString(event.Content)
|
||||
}
|
||||
case HarmonyEventMessageEnd:
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
}
|
||||
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
|
||||
}
|
||||
|
||||
func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
|
||||
return &HarmonyToolCallAccumulator{
|
||||
state: harmonyToolCallState_Normal,
|
||||
currentToolName: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type harmonyToolCallState int
|
||||
|
||||
const (
|
||||
harmonyToolCallState_Normal harmonyToolCallState = iota
|
||||
harmonyToolCallState_ToolCalling
|
||||
)
|
||||
|
||||
type HarmonyToolCallAccumulator struct {
|
||||
state harmonyToolCallState
|
||||
acc strings.Builder
|
||||
currentToolName *string
|
||||
}
|
||||
|
||||
func (a *HarmonyToolCallAccumulator) SetToolName(toolName string) {
|
||||
a.currentToolName = &toolName
|
||||
}
|
||||
|
||||
func (a *HarmonyToolCallAccumulator) Add(content string) {
|
||||
a.acc.WriteString(content)
|
||||
}
|
||||
|
||||
func (a *HarmonyToolCallAccumulator) Drain() (*string, string) {
|
||||
str := a.acc.String()
|
||||
a.state = harmonyToolCallState_Normal
|
||||
a.acc.Reset()
|
||||
return a.currentToolName, str
|
||||
}
|
||||
|
||||
func (a *HarmonyToolCallAccumulator) Content() string {
|
||||
return a.acc.String()
|
||||
}
|
||||
|
||||
// FunctionNameMap maps a user-specified function name to a valid function
|
||||
// name for harmony (which look like TypeScript identifiers). This is needed to
|
||||
// transform user-specified function names, which might contain characters that
|
||||
// are not allowed in TypeScript identifiers
|
||||
type FunctionNameMap struct {
|
||||
userToHarmony map[string]string
|
||||
harmonyToUser map[string]string
|
||||
}
|
||||
|
||||
func NewFunctionNameMap() *FunctionNameMap {
|
||||
return &FunctionNameMap{
|
||||
userToHarmony: make(map[string]string),
|
||||
harmonyToUser: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
|
||||
harmonyFunctionName := m.deriveName(userFunctionName)
|
||||
m.userToHarmony[userFunctionName] = harmonyFunctionName
|
||||
m.harmonyToUser[harmonyFunctionName] = userFunctionName
|
||||
return harmonyFunctionName
|
||||
}
|
||||
|
||||
// OriginalFromConverted looks up the reverse-mapping of a previously-converted
|
||||
// user->harmony function name. To unmap reliably, the mapping must exist, as
|
||||
// the conversion process is not reversible without the appropriate state
|
||||
func (m *FunctionNameMap) OriginalFromConverted(harmonyFunctionName string) string {
|
||||
if userFunctionName, ok := m.harmonyToUser[harmonyFunctionName]; ok {
|
||||
return userFunctionName
|
||||
}
|
||||
slog.Warn("harmony parser: no reverse mapping found for function name", "harmonyFunctionName", harmonyFunctionName)
|
||||
// fallback to the original function name if we can't find a mapping
|
||||
return harmonyFunctionName
|
||||
}
|
||||
|
||||
// convertToValidChars converts a user-specified function name to a valid
|
||||
// TypeScript identifier.
|
||||
//
|
||||
// Limitations:
|
||||
//
|
||||
// - This doesn't restrict reserved TypeScript keywords.
|
||||
// - We don't perform a real ID_Start/ID_Continue check, and instead use the more
|
||||
// restrictive unicode.IsLetter/unicode.IsDigit check. Unclear what kind of
|
||||
// identifiers these models were trained on, so in the end we might want to
|
||||
// convert unicode-heavy identifiers to their closest ASCII equivalents.
|
||||
func (m *FunctionNameMap) convertToValidChars(userFunctionName string) string {
|
||||
mapper := func(r rune) rune {
|
||||
// first, replace certain characters with underscores
|
||||
if r == ' ' || r == '-' || r == '.' {
|
||||
return '_'
|
||||
}
|
||||
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
|
||||
return r
|
||||
}
|
||||
|
||||
// finally, remove any other characters
|
||||
return -1
|
||||
}
|
||||
candidate := strings.Map(mapper, userFunctionName)
|
||||
|
||||
// set a default name if we end up with nothing left
|
||||
if candidate == "" {
|
||||
return "unnamed"
|
||||
}
|
||||
|
||||
// if the candidate starts with a number, prepend an underscore to make it a
|
||||
// valid identifier
|
||||
if unicode.IsDigit(rune(candidate[0])) {
|
||||
candidate = "_" + candidate
|
||||
}
|
||||
|
||||
return candidate
|
||||
}
|
||||
|
||||
func (m *FunctionNameMap) deriveName(userFunctionName string) string {
|
||||
originalCandidate := m.convertToValidChars(userFunctionName)
|
||||
candidate := originalCandidate
|
||||
|
||||
// Check for dupes, and if so, add a number to the end.
|
||||
// We start at 2 because if we have dupes and the first is never renamed, it
|
||||
// makes sense for them to be named, say, `f`, `f_2`, `f_3`
|
||||
count := 2
|
||||
for {
|
||||
if _, exists := m.harmonyToUser[candidate]; !exists {
|
||||
break
|
||||
}
|
||||
candidate = fmt.Sprintf("%s_%d", originalCandidate, count)
|
||||
count++
|
||||
}
|
||||
|
||||
return candidate
|
||||
}
|
||||
537
harmony/harmonyparser_test.go
Normal file
537
harmony/harmonyparser_test.go
Normal file
@@ -0,0 +1,537 @@
|
||||
package harmony
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHeaderParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, wantRole, wantChannel, wantRecipient string
|
||||
}{
|
||||
{
|
||||
in: "assistant<|channel|>analysis",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "",
|
||||
},
|
||||
{
|
||||
in: "assistant<|channel|>analysis to=functions.get_weather",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
{
|
||||
in: "assistant to=functions.get_weather<|channel|>analysis",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// special case where the role is replaced by the recipient (matches reference code)
|
||||
{
|
||||
in: "to=functions.get_weather<|channel|>analysis",
|
||||
wantRole: "tool",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// extra token after the recipient is ignored
|
||||
{
|
||||
in: "assistant to=functions.get_weather abc<|channel|>analysis",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// with constrain tag, recipient after channel tag
|
||||
{
|
||||
in: "assistant<|channel|>commentary to=functions.get_weather <|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// with constrain tag, recipient before channel tag
|
||||
{
|
||||
in: "assistant to=functions.get_weather<|channel|>commentary <|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// constrain tag without space
|
||||
{
|
||||
in: "assistant<|channel|>commentary to=functions.get_weather<|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// constrain tag without space, different order
|
||||
{
|
||||
in: "assistant to=functions.get_weather<|channel|>commentary<|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
header := parser.parseHeader(tt.in)
|
||||
|
||||
if header.Role != tt.wantRole {
|
||||
t.Errorf("case %d: got role \"%s\", want \"%s\"", i, header.Role, tt.wantRole)
|
||||
}
|
||||
if header.Channel != tt.wantChannel {
|
||||
t.Errorf("case %d: got channel \"%s\", want \"%s\"", i, header.Channel, tt.wantChannel)
|
||||
}
|
||||
if header.Recipient != tt.wantRecipient {
|
||||
t.Errorf("case %d: got recipient \"%s\", want \"%s\"", i, header.Recipient, tt.wantRecipient)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyParserHeaderEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, wantRole, wantChannel, wantRecipient string
|
||||
implicitStart bool
|
||||
}{
|
||||
{
|
||||
in: "<|start|>user<|message|>What is 2 + 2?<|end|>",
|
||||
wantRole: "user",
|
||||
wantChannel: "",
|
||||
wantRecipient: "",
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>analysis<|message|>What is 2 + 2?<|end|>",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "",
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{\"location\":\"San Francisco\"}<|call|><|start|>functions.get_weather to=assistant<|message|>{\"sunny\": true, \"temperature\": 20}<|end|>",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
{
|
||||
in: "<|channel|>analysis<|message|>User asks weather in SF. We need location. Use get_current_weather with location \"San Francisco, CA\".<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\"location\":\"San Francisco, CA\"}<|call|>",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "",
|
||||
implicitStart: true,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
if tt.implicitStart {
|
||||
parser.AddImplicitStart()
|
||||
}
|
||||
gotEvents := parser.AddContent(tt.in)
|
||||
if len(gotEvents) == 0 {
|
||||
t.Errorf("case %d: got no events, want at least one", i)
|
||||
}
|
||||
|
||||
var firstHeaderEvent *HarmonyEventHeaderComplete
|
||||
// print events
|
||||
for _, event := range gotEvents {
|
||||
fmt.Printf("event: %+v\n", event)
|
||||
}
|
||||
for _, event := range gotEvents {
|
||||
if event, ok := event.(HarmonyEventHeaderComplete); ok {
|
||||
firstHeaderEvent = &event
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if firstHeaderEvent == nil {
|
||||
t.Errorf("case %d: got no header complete event, want one", i)
|
||||
continue
|
||||
}
|
||||
gotHeader := firstHeaderEvent.Header
|
||||
if gotHeader.Role != tt.wantRole || gotHeader.Channel != tt.wantChannel || gotHeader.Recipient != tt.wantRecipient {
|
||||
t.Errorf("case %d: got header %+v, want role=%s channel=%s recipient=%s", i, gotHeader, tt.wantRole, tt.wantChannel, tt.wantRecipient)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyParserNonStreaming(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
implicitStart bool
|
||||
wantEvents []HarmonyEvent
|
||||
}{
|
||||
{
|
||||
in: "<|start|>user<|message|>What is 2 + 2?<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "What is 2 + 2?"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>analysis<|message|>The answer is 4<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "The answer is 4"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>Computing...<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||
HarmonyEventContentEmitted{Content: "Computing..."},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>user<|message|><|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>user<|message|>Hello<|end|><|start|>assistant<|message|>Hi!<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hello"},
|
||||
HarmonyEventMessageEnd{},
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hi!"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|channel|>analysis<|message|>Thinking about the request<|end|>",
|
||||
implicitStart: true,
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}, HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}, HarmonyEventContentEmitted{Content: "Thinking about the request"}, HarmonyEventMessageEnd{}},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
if tt.implicitStart {
|
||||
parser.AddImplicitStart()
|
||||
}
|
||||
gotEvents := parser.AddContent(tt.in)
|
||||
if !reflect.DeepEqual(gotEvents, tt.wantEvents) {
|
||||
t.Errorf("case %d: got events %#v, want %#v", i, gotEvents, tt.wantEvents)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyParserStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantEvents []HarmonyEvent
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
implicitStart bool
|
||||
steps []step
|
||||
}{
|
||||
{
|
||||
desc: "simple message streamed character by character",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "|",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "start|>u",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "ser<|mess",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "age|>Hi",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hi"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: " there",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: " there"}},
|
||||
},
|
||||
{
|
||||
input: "<|e",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "nd|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "message with channel streamed",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>assistant",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "<|chan",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "nel|>analysis",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "<|message|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}},
|
||||
},
|
||||
{
|
||||
input: "Thinking",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Thinking"}},
|
||||
},
|
||||
{
|
||||
input: "...",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "..."}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "message with channel and recipient",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "{\"x\": 5}",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "message with channel and recipient (receipient before channel)",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>assistant to=functions.calc<|channel|>commentary<|message|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "{\"x\": 5}",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "implicit start with channel",
|
||||
implicitStart: true,
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|channel|>thinking",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "<|message|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "thinking", Recipient: ""}}},
|
||||
},
|
||||
{
|
||||
input: "Processing request",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Processing request"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple messages streamed",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>user<|message|>Hello<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hello"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "<|start|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "assistant<|message|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}}},
|
||||
},
|
||||
{
|
||||
input: "Hi!",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Hi!"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty message",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>system<|message|><|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "system", Channel: "", Recipient: ""}},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial tag that looks like end but isn't",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>user<|message|>test<|e",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "test"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "xample|>more",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "<|example|>more"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
if tc.implicitStart {
|
||||
parser.AddImplicitStart()
|
||||
}
|
||||
|
||||
for i, step := range tc.steps {
|
||||
gotEvents := parser.AddContent(step.input)
|
||||
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFunctionConvertToValidChars tests only FunctionNameMap.convert(), which doesn't
|
||||
// handle any saving (and therefore no dupe handling)
|
||||
func TestFunctionConvertToValidChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "replace spaces with underscores", in: "get weather", want: "get_weather"},
|
||||
{name: "replace hyphens with underscores", in: "get-weather", want: "get_weather"},
|
||||
{name: "replace periods with underscores", in: "get.weather", want: "get_weather"},
|
||||
{name: "disallow non-word characters", in: "get weather!", want: "get_weather"},
|
||||
{name: "strip out invalid non-alphanumeric unicode characters", in: "a🫠bc", want: "abc"},
|
||||
{name: "names that only contain invalid characters", in: "🫠", want: "unnamed"},
|
||||
{name: "leading number", in: "123", want: "_123"},
|
||||
{name: "$ allowed", in: "$", want: "$"},
|
||||
// show that we allow weird unicode letter characters, though we might want
|
||||
// to convert them to their closest ASCII equivalents in the future
|
||||
{name: "allow weird unicode letter characters", in: "𝓸𝓵𝓵𝓪𝓶𝓪", want: "𝓸𝓵𝓵𝓪𝓶𝓪"},
|
||||
// names that look like words but are invalid (i.e., not ID_Start/ID_Continue)
|
||||
{name: "disallow non-word characters that look like words", in: "ⓞⓛⓛⓐⓜⓐ123", want: "_123"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := NewFunctionNameMap()
|
||||
got := parser.convertToValidChars(tt.in)
|
||||
if got != tt.want {
|
||||
t.Errorf("case %d: got %q, want %q", i, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionConvertAndAdd(t *testing.T) {
|
||||
// make a fresh map for each test, but within a test use the same map so we can test for dupe handling
|
||||
tests := []struct {
|
||||
name string
|
||||
in []string
|
||||
want []string
|
||||
}{
|
||||
{name: "basic dupe handling", in: []string{"get weather", "get weather"}, want: []string{"get_weather", "get_weather_2"}},
|
||||
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
|
||||
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
|
||||
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
parser := NewFunctionNameMap()
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for j, in := range tt.in {
|
||||
got := parser.ConvertAndAdd(in)
|
||||
want := tt.want[j]
|
||||
if got != want {
|
||||
t.Errorf("case %d: got %q, want %q", i, got, want)
|
||||
}
|
||||
// check that the maps are correct
|
||||
if parser.userToHarmony[in] != want {
|
||||
t.Errorf("case %d: userToHarmony[%q] = %q, want %q", i, in, parser.userToHarmony[in], want)
|
||||
}
|
||||
if parser.harmonyToUser[want] != in {
|
||||
t.Errorf("case %d: harmonyToUser[%q] = %q, want %q", i, want, parser.harmonyToUser[want], in)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,13 @@
|
||||
|
||||
This directory contains integration tests to exercise Ollama end-to-end to verify behavior
|
||||
|
||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
|
||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
|
||||
|
||||
|
||||
The integration tests have 2 modes of operating.
|
||||
|
||||
1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
||||
|
||||
@@ -390,7 +390,7 @@ func TestAPIEmbeddings(t *testing.T) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "orca-mini",
|
||||
Model: libraryEmbedModels[0],
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBlueSky(t *testing.T) {
|
||||
@@ -37,8 +36,8 @@ func TestUnicode(t *testing.T) {
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
|
||||
Prompt: "天空为什么是蓝色的?",
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
|
||||
Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -50,8 +49,20 @@ func TestUnicode(t *testing.T) {
|
||||
}
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second)
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
slog.Info("loading", "model", req.Model)
|
||||
err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
|
||||
|
||||
DoGenerate(ctx, t, client, req, []string{
|
||||
"散射", // scattering
|
||||
"频率", // frequency
|
||||
}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
@@ -69,7 +80,9 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
}
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
@@ -84,7 +97,9 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
}
|
||||
|
||||
modelDir, err := os.MkdirTemp("", "ollama_埃")
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(modelDir)
|
||||
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
|
||||
|
||||
|
||||
@@ -4,257 +4,176 @@ package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
func TestMultiModelConcurrency(t *testing.T) {
|
||||
var (
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "llama3.2:1b",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "tinydolphin",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
{"sunlight"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "british", "festival"},
|
||||
}
|
||||
)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for i := 0; i < len(req); i++ {
|
||||
require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
|
||||
}
|
||||
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
// Note: CPU based inference can crawl so don't give up too quickly
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 30*time.Second)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestIntegrationConcurrentPredict(t *testing.T) {
|
||||
// Send multiple requests in parallel (concurrently) to a single model and ensure responses are expected
|
||||
func TestConcurrentGenerate(t *testing.T) {
|
||||
// Assumes all requests have the same model
|
||||
req, resp := GenerateRequests()
|
||||
reqLimit := len(req)
|
||||
iterLimit := 5
|
||||
numParallel := int(envconfig.NumParallel() + 1)
|
||||
iterLimit := 3
|
||||
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
require.NoError(t, err)
|
||||
// Don't hammer on small VRAM cards...
|
||||
if maxVram < 4*format.GibiByte {
|
||||
reqLimit = min(reqLimit, 2)
|
||||
iterLimit = 2
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial request
|
||||
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 10*time.Second)
|
||||
slog.Info("loading", "model", req[0].Model)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: req[0].Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req[0].Model, err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(reqLimit)
|
||||
for i := 0; i < reqLimit; i++ {
|
||||
r := rand.New(rand.NewSource(0))
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j)
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
k := r.Int() % len(req)
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 20*time.Second)
|
||||
DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
|
||||
// Stress the scheduler and attempt to load more models than will fit to cause thrashing
|
||||
// This test will always load at least 2 models even on CPU based systems
|
||||
func TestMultiModelStress(t *testing.T) {
|
||||
s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
|
||||
s := os.Getenv("OLLAMA_MAX_VRAM")
|
||||
if s == "" {
|
||||
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
|
||||
s = "0"
|
||||
}
|
||||
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if maxVram < 2*format.GibiByte {
|
||||
t.Skip("VRAM less than 2G, skipping model stress tests")
|
||||
|
||||
// All models compatible with ollama-engine
|
||||
smallModels := []string{
|
||||
"llama3.2:1b",
|
||||
"qwen3:0.6b",
|
||||
"gemma2:2b",
|
||||
"deepseek-r1:1.5b", // qwen2 arch
|
||||
"gemma3:270m",
|
||||
}
|
||||
mediumModels := []string{
|
||||
"llama3.2:3b", // ~3.4G
|
||||
"qwen3:8b", // ~6.6G
|
||||
"gpt-oss:20b", // ~15G
|
||||
"deepseek-r1:7b", // ~5.6G
|
||||
"gemma3:4b", // ~5.8G
|
||||
"gemma2:9b", // ~8.1G
|
||||
}
|
||||
|
||||
type model struct {
|
||||
name string
|
||||
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
|
||||
}
|
||||
|
||||
smallModels := []model{
|
||||
{
|
||||
name: "llama3.2:1b",
|
||||
size: 2876 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "phi",
|
||||
size: 2616 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "gemma:2b",
|
||||
size: 2364 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "stable-code:3b",
|
||||
size: 2608 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "starcoder2:3b",
|
||||
size: 2166 * format.MebiByte,
|
||||
},
|
||||
}
|
||||
mediumModels := []model{
|
||||
{
|
||||
name: "llama2",
|
||||
size: 5118 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "mistral",
|
||||
size: 4620 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "orca-mini:7b",
|
||||
size: 5118 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "dolphin-mistral",
|
||||
size: 4620 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "gemma:7b",
|
||||
size: 5000 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "codellama:7b",
|
||||
size: 5118 * format.MebiByte,
|
||||
},
|
||||
}
|
||||
|
||||
// These seem to be too slow to be useful...
|
||||
// largeModels := []model{
|
||||
// {
|
||||
// name: "llama2:13b",
|
||||
// size: 7400 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "codellama:13b",
|
||||
// size: 7400 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "orca-mini:13b",
|
||||
// size: 7400 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "gemma:7b",
|
||||
// size: 5000 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "starcoder2:15b",
|
||||
// size: 9100 * format.MebiByte,
|
||||
// },
|
||||
// }
|
||||
|
||||
var chosenModels []model
|
||||
var chosenModels []string
|
||||
switch {
|
||||
case maxVram < 10000*format.MebiByte:
|
||||
slog.Info("selecting small models")
|
||||
chosenModels = smallModels
|
||||
// case maxVram < 30000*format.MebiByte:
|
||||
default:
|
||||
slog.Info("selecting medium models")
|
||||
chosenModels = mediumModels
|
||||
// default:
|
||||
// slog.Info("selecting large models")
|
||||
// chosenModels = largeModels
|
||||
}
|
||||
|
||||
req, resp := GenerateRequests()
|
||||
|
||||
for i := range req {
|
||||
if i > len(chosenModels) {
|
||||
break
|
||||
}
|
||||
req[i].Model = chosenModels[i].name
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Make sure all the models are pulled before we get started
|
||||
for _, r := range req {
|
||||
require.NoError(t, PullIfMissing(ctx, client, r.Model))
|
||||
for _, model := range chosenModels {
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
|
||||
for i := 0; i < len(req); i++ {
|
||||
// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
|
||||
if i > 1 && consumed > maxVram {
|
||||
slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
|
||||
break
|
||||
// Determine how many models we can load in parallel before we exceed VRAM
|
||||
// The intent is to go 1 over what can fit so we force the scheduler to thrash
|
||||
targetLoadCount := 0
|
||||
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
|
||||
for i, model := range chosenModels {
|
||||
req := &api.GenerateRequest{Model: model}
|
||||
slog.Info("loading", "model", model)
|
||||
err = client.Generate(ctx, req, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", model, err)
|
||||
}
|
||||
consumed += chosenModels[i].size
|
||||
slog.Info("target vram", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
|
||||
targetLoadCount++
|
||||
if i > 0 {
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to list running models: %s", err)
|
||||
}
|
||||
if len(models.Models) < targetLoadCount {
|
||||
loaded := []string{}
|
||||
for _, m := range models.Models {
|
||||
loaded = append(loaded, m.Name)
|
||||
}
|
||||
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetLoadCount == len(chosenModels) {
|
||||
// TODO consider retrying the medium models
|
||||
slog.Warn("all models being used without exceeding VRAM, set OLLAMA_MAX_VRAM so test can pick larger models")
|
||||
}
|
||||
|
||||
r := rand.New(rand.NewSource(0))
|
||||
var wg sync.WaitGroup
|
||||
for i := range targetLoadCount {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
reqs, resps := GenerateRequests()
|
||||
for j := 0; j < 3; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 5*time.Second)
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
k := r.Int() % len(reqs)
|
||||
reqs[k].Model = chosenModels[i]
|
||||
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Prompt)
|
||||
DoGenerate(ctx, t, client, reqs[k], resps[k],
|
||||
120*time.Second, // Be extra patient for the model to load initially
|
||||
10*time.Second, // Once results start streaming, fail if they stall
|
||||
)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(2 * time.Second)
|
||||
time.Sleep(10 * time.Second)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -265,7 +184,21 @@ func TestMultiModelStress(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
slog.Info("loaded model snapshot", "model", m)
|
||||
var procStr string
|
||||
switch {
|
||||
case m.SizeVRAM == 0:
|
||||
procStr = "100% CPU"
|
||||
case m.SizeVRAM == m.Size:
|
||||
procStr = "100% GPU"
|
||||
case m.SizeVRAM > m.Size || m.Size == 0:
|
||||
procStr = "Unknown"
|
||||
default:
|
||||
sizeCPU := m.Size - m.SizeVRAM
|
||||
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
|
||||
procStr = fmt.Sprintf("%d%%/%d%%", int(cpuPercent), int(100-cpuPercent))
|
||||
}
|
||||
|
||||
slog.Info("loaded model snapshot", "model", m.Name, "CPU/GPU", procStr, "expires", format.HumanTime(m.ExpiresAt, "Never"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -20,7 +22,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "llama2",
|
||||
Model: smol,
|
||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
@@ -34,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func TestContextExhaustion(t *testing.T) {
|
||||
@@ -47,7 +49,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "llama2",
|
||||
Model: smol,
|
||||
Prompt: "Write me a story with a ton of emojis?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
@@ -61,5 +63,104 @@ func TestContextExhaustion(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||
func TestGenerateWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := GenerateRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial request
|
||||
slog.Info("loading", "model", modelOverride)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", modelOverride, err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
k := i % len(req)
|
||||
req[k].Model = modelOverride
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
c := DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
req[k].Context = c
|
||||
req[k].Prompt = "tell me more!"
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||
func TestChatWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := ChatRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial empty request
|
||||
slog.Info("loading", "model", modelOverride)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", modelOverride, err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
k := i % len(req)
|
||||
req[k].Model = modelOverride
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
if assistant == nil {
|
||||
t.Fatalf("didn't get an assistant response for context")
|
||||
}
|
||||
req[k].Messages = append(req[k].Messages,
|
||||
*assistant,
|
||||
api.Message{Role: "user", Content: "tell me more!"},
|
||||
)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
57
integration/library_models_test.go
Normal file
57
integration/library_models_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
//go:build integration && library
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// First run of this scenario on a target system will take a long time to download
|
||||
// ~1.5TB of models. Set a sufficiently large -timeout for your network speed
|
||||
func TestLibraryModelsGenerate(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
chatModels := libraryChatModels
|
||||
for _, model := range chatModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0.1,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"}
|
||||
// Special cases
|
||||
if model == "duckdb-nsql" {
|
||||
anyResp = []string{"select", "from"}
|
||||
} else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" {
|
||||
anyResp = []string{"yes", "no", "safe", "unsafe"}
|
||||
} else if model == "openthinker" || model == "nexusraven" {
|
||||
anyResp = []string{"plugin", "im_sep", "components", "function call"}
|
||||
} else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" {
|
||||
req.Prompt = "def fibonacci():"
|
||||
anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"}
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVisionModels(t *testing.T) {
|
||||
@@ -32,7 +31,9 @@ func TestVisionModels(t *testing.T) {
|
||||
for _, v := range testCases {
|
||||
t.Run(v.model, func(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: v.model,
|
||||
Prompt: "what does the text in this image say?",
|
||||
@@ -52,7 +53,9 @@ func TestVisionModels(t *testing.T) {
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// llava models on CPU can be quite slow to start
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
})
|
||||
@@ -62,7 +65,9 @@ func TestVisionModels(t *testing.T) {
|
||||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: "gemma3:4b",
|
||||
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
||||
@@ -84,7 +89,9 @@ func TestIntegrationSplitBatch(t *testing.T) {
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// llava models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||
}
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
|
||||
// package to avoid circular dependencies
|
||||
|
||||
var (
|
||||
stream = false
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
{"sunlight", "scattering", "interact"},
|
||||
{"england", "english", "massachusetts", "pilgrims"},
|
||||
}
|
||||
)
|
||||
|
||||
func TestIntegrationSimple(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||
}
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestMaxQueue(t *testing.T) {
|
||||
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
|
||||
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
|
||||
return
|
||||
@@ -45,7 +45,9 @@ func TestMaxQueue(t *testing.T) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Context for the worker threads so we can shut them down
|
||||
// embedCtx, embedCancel := context.WithCancel(ctx)
|
||||
@@ -89,7 +91,9 @@ func TestMaxQueue(t *testing.T) {
|
||||
switch {
|
||||
case genErr == nil:
|
||||
successCount++
|
||||
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
|
||||
if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
|
||||
t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
|
||||
}
|
||||
case errors.Is(genErr, context.Canceled):
|
||||
canceledCount++
|
||||
case strings.Contains(genErr.Error(), "busy"):
|
||||
@@ -97,7 +101,9 @@ func TestMaxQueue(t *testing.T) {
|
||||
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
||||
resetByPeerCount++
|
||||
default:
|
||||
require.NoError(t, genErr, "%d request failed", i)
|
||||
if genErr != nil {
|
||||
t.Fatalf("%d request failed", i)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("embed finished", "id", i)
|
||||
@@ -108,8 +114,13 @@ func TestMaxQueue(t *testing.T) {
|
||||
embedwg.Wait()
|
||||
|
||||
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
||||
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
|
||||
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
|
||||
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
|
||||
|
||||
if resetByPeerCount != 0 {
|
||||
t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
|
||||
}
|
||||
if busyCount == 0 {
|
||||
t.Fatalf("no requests hit busy error but some should have")
|
||||
}
|
||||
if canceledCount > 0 {
|
||||
t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,35 +19,6 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
var (
|
||||
started = time.Now()
|
||||
chatModels = []string{
|
||||
"granite3-moe:latest",
|
||||
"granite-code:latest",
|
||||
"nemotron-mini:latest",
|
||||
"command-r:latest",
|
||||
"gemma2:latest",
|
||||
"gemma:latest",
|
||||
"internlm2:latest",
|
||||
"phi3.5:latest",
|
||||
"phi3:latest",
|
||||
// "phi:latest", // flaky, sometimes generates no response on first query
|
||||
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
|
||||
"falcon:latest",
|
||||
"falcon2:latest",
|
||||
"minicpm-v:latest",
|
||||
"mistral:latest",
|
||||
"orca-mini:latest",
|
||||
"llama2:latest",
|
||||
"llama3.1:latest",
|
||||
"llama3.2:latest",
|
||||
"llama3.2-vision:latest",
|
||||
"qwen2.5-coder:latest",
|
||||
"qwen:latest",
|
||||
"solar-pro:latest",
|
||||
}
|
||||
)
|
||||
|
||||
func TestModelsGenerate(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
@@ -68,6 +39,13 @@ func TestModelsGenerate(t *testing.T) {
|
||||
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
|
||||
}
|
||||
|
||||
var chatModels []string
|
||||
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
|
||||
chatModels = ollamaEngineChatModels
|
||||
} else {
|
||||
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
|
||||
}
|
||||
|
||||
for _, model := range chatModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
|
||||
266
integration/model_perf_test.go
Normal file
266
integration/model_perf_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
//go:build integration && perf
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
var (
|
||||
// Models that don't work reliably with the large context prompt in this test case
|
||||
longContextFlakes = []string{
|
||||
"granite-code:latest",
|
||||
"nemotron-mini:latest",
|
||||
"falcon:latest", // 2k model
|
||||
"falcon2:latest", // 2k model
|
||||
"minicpm-v:latest",
|
||||
"qwen:latest",
|
||||
"solar-pro:latest",
|
||||
}
|
||||
)
|
||||
|
||||
// Note: this test case can take a long time to run, particularly on models with
|
||||
// large contexts. Run with -timeout set to a large value to get reasonable coverage
|
||||
// Example usage:
|
||||
//
|
||||
// go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
|
||||
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
|
||||
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
|
||||
func TestModelsPerf(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// TODO use info API eventually
|
||||
var maxVram uint64
|
||||
var err error
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err = strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
|
||||
}
|
||||
} else {
|
||||
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open test data file: %s", err)
|
||||
}
|
||||
longPrompt := "summarize the following: " + string(data)
|
||||
|
||||
var chatModels []string
|
||||
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
|
||||
chatModels = ollamaEngineChatModels
|
||||
} else {
|
||||
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
|
||||
}
|
||||
|
||||
for _, model := range chatModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
var maxContext int
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||
if err != nil {
|
||||
t.Fatalf("show failed: %s", err)
|
||||
}
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
|
||||
|
||||
if maxVram > 0 {
|
||||
resp, err := client.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list models failed %v", err)
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
// For these tests we want to exercise a some amount of overflow on the CPU
|
||||
if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
|
||||
t.Skipf("model %s is too large %s for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Info("scneario", "model", model, "max_context", maxContext)
|
||||
loaded := false
|
||||
defer func() {
|
||||
// best effort unload once we're done with the model
|
||||
if loaded {
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
}
|
||||
}()
|
||||
|
||||
// Some models don't handle the long context data well so skip them to avoid flaky test results
|
||||
longContextFlake := false
|
||||
for _, flake := range longContextFlakes {
|
||||
if model == flake {
|
||||
longContextFlake = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// iterate through a few context sizes for coverage without excessive runtime
|
||||
var contexts []int
|
||||
keepGoing := true
|
||||
if maxContext > 16384 {
|
||||
contexts = []int{4096, 8192, 16384, maxContext}
|
||||
} else if maxContext > 8192 {
|
||||
contexts = []int{4096, 8192, maxContext}
|
||||
} else if maxContext > 4096 {
|
||||
contexts = []int{4096, maxContext}
|
||||
} else if maxContext > 0 {
|
||||
contexts = []int{maxContext}
|
||||
} else {
|
||||
t.Fatal("unknown max context size")
|
||||
}
|
||||
for _, numCtx := range contexts {
|
||||
if !keepGoing && numCtx > 8192 { // Always try up to 8k before bailing out
|
||||
break
|
||||
}
|
||||
skipLongPrompt := false
|
||||
|
||||
// Workaround bug 11172 temporarily...
|
||||
maxPrompt := longPrompt
|
||||
// If we fill the context too full with the prompt, many models
|
||||
// quickly hit context shifting and go bad.
|
||||
if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
|
||||
maxPrompt = maxPrompt[:numCtx*2]
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
prompt string
|
||||
anyResp []string
|
||||
}{
|
||||
{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
|
||||
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
|
||||
}
|
||||
var gpuPercent int
|
||||
for _, tc := range testCases {
|
||||
if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
|
||||
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
|
||||
continue
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: tc.prompt,
|
||||
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_ctx": numCtx,
|
||||
},
|
||||
}
|
||||
atLeastOne := false
|
||||
var resp api.GenerateResponse
|
||||
|
||||
stream := false
|
||||
req.Stream = &stream
|
||||
|
||||
// Avoid potentially getting stuck indefinitely
|
||||
limit := 5 * time.Minute
|
||||
genCtx, cancel := context.WithDeadlineCause(
|
||||
ctx,
|
||||
time.Now().Add(limit),
|
||||
fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
|
||||
resp = rsp
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// Avoid excessive test runs, but don't consider a failure with massive context
|
||||
if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
|
||||
slog.Warn("max context was taking too long, skipping", "error", err)
|
||||
keepGoing = false
|
||||
skipLongPrompt = true
|
||||
continue
|
||||
}
|
||||
t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
|
||||
}
|
||||
loaded = true
|
||||
for _, expResp := range tc.anyResp {
|
||||
if strings.Contains(strings.ToLower(resp.Response), expResp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
|
||||
}
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("failed to list running models", "error", err)
|
||||
continue
|
||||
}
|
||||
if len(models.Models) > 1 {
|
||||
slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
if m.Name == model {
|
||||
if m.SizeVRAM == 0 {
|
||||
slog.Info("Model fully loaded into CPU")
|
||||
gpuPercent = 0
|
||||
keepGoing = false
|
||||
skipLongPrompt = true
|
||||
} else if m.SizeVRAM == m.Size {
|
||||
slog.Info("Model fully loaded into GPU")
|
||||
gpuPercent = 100
|
||||
} else {
|
||||
sizeCPU := m.Size - m.SizeVRAM
|
||||
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
|
||||
gpuPercent = int(100 - cpuPercent)
|
||||
slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
|
||||
keepGoing = false
|
||||
|
||||
// Heuristic to avoid excessive test run time
|
||||
if gpuPercent < 90 {
|
||||
skipLongPrompt = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
|
||||
"MODEL",
|
||||
"CONTEXT",
|
||||
"GPU PERCENT",
|
||||
"PROMPT COUNT",
|
||||
"LOAD TIME",
|
||||
"PROMPT EVAL TPS",
|
||||
"EVAL TPS",
|
||||
)
|
||||
fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
|
||||
model,
|
||||
numCtx,
|
||||
gpuPercent,
|
||||
resp.PromptEvalCount,
|
||||
float64(resp.LoadDuration)/1000000000.0,
|
||||
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
|
||||
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
23
integration/testdata/embed.json
vendored
23
integration/testdata/embed.json
vendored
File diff suppressed because one or more lines are too long
124456
integration/testdata/shakespeare.txt
vendored
Normal file
124456
integration/testdata/shakespeare.txt
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -25,15 +26,245 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
smol = "llama3.2:1b"
|
||||
var (
|
||||
smol = "llama3.2:1b"
|
||||
stream = false
|
||||
)
|
||||
|
||||
func Init() {
|
||||
var (
|
||||
started = time.Now()
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"gpt-oss:20b",
|
||||
"gemma3n:e2b",
|
||||
"mistral-small3.2:latest",
|
||||
"deepseek-r1:1.5b",
|
||||
"llama3.2-vision:latest",
|
||||
"qwen2.5-coder:latest",
|
||||
"qwen2.5vl:3b",
|
||||
"qwen3:0.6b", // dense
|
||||
"qwen3:30b", // MOE
|
||||
"gemma3:1b",
|
||||
"llama3.1:latest",
|
||||
"llama3.2:latest",
|
||||
"gemma2:latest",
|
||||
"minicpm-v:latest", // arch=qwen2
|
||||
"granite-code:latest", // arch=llama
|
||||
}
|
||||
llamaRunnerChatModels = []string{
|
||||
"mistral:latest",
|
||||
"falcon3:latest",
|
||||
"granite3-moe:latest",
|
||||
"command-r:latest",
|
||||
"nemotron-mini:latest",
|
||||
"phi3.5:latest",
|
||||
"solar-pro:latest",
|
||||
"internlm2:latest",
|
||||
"codellama:latest", // arch=llama
|
||||
"phi3:latest",
|
||||
"falcon2:latest",
|
||||
"gemma:latest",
|
||||
"llama2:latest",
|
||||
"nous-hermes:latest",
|
||||
"orca-mini:latest",
|
||||
"qwen:latest",
|
||||
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
|
||||
"falcon:latest",
|
||||
}
|
||||
|
||||
// Some library models are quite large - ensure large VRAM and sufficient disk space
|
||||
// before running scenarios based on this set
|
||||
libraryChatModels = []string{
|
||||
"alfred",
|
||||
"athene-v2",
|
||||
"aya-expanse",
|
||||
"aya",
|
||||
"bakllava",
|
||||
"bespoke-minicheck",
|
||||
"codebooga",
|
||||
"codegeex4",
|
||||
"codegemma",
|
||||
"codellama",
|
||||
"codeqwen",
|
||||
"codestral",
|
||||
"codeup",
|
||||
"cogito",
|
||||
"command-a",
|
||||
"command-r-plus",
|
||||
"command-r",
|
||||
"command-r7b-arabic",
|
||||
"command-r7b",
|
||||
"dbrx",
|
||||
"deepcoder",
|
||||
"deepscaler",
|
||||
"deepseek-coder-v2",
|
||||
"deepseek-coder",
|
||||
"deepseek-llm",
|
||||
"deepseek-r1",
|
||||
// "deepseek-v2.5", // requires 155 GB VRAM
|
||||
"deepseek-v2",
|
||||
// "deepseek-v3", // requires 482 GB VRAM
|
||||
"devstral",
|
||||
"dolphin-llama3",
|
||||
"dolphin-mistral",
|
||||
"dolphin-mixtral",
|
||||
"dolphin-phi",
|
||||
"dolphin3",
|
||||
"dolphincoder",
|
||||
"duckdb-nsql",
|
||||
"everythinglm",
|
||||
"exaone-deep",
|
||||
"exaone3.5",
|
||||
"falcon",
|
||||
"falcon2",
|
||||
"falcon3",
|
||||
"firefunction-v2",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"glm4",
|
||||
"goliath",
|
||||
"gpt-oss:20b",
|
||||
"granite-code",
|
||||
"granite3-dense",
|
||||
"granite3-guardian",
|
||||
"granite3-moe",
|
||||
"granite3.1-dense",
|
||||
"granite3.1-moe",
|
||||
"granite3.2-vision",
|
||||
"granite3.2",
|
||||
"granite3.3",
|
||||
"hermes3",
|
||||
"internlm2",
|
||||
"llama-guard3",
|
||||
"llama-pro",
|
||||
"llama2-chinese",
|
||||
"llama2-uncensored",
|
||||
"llama2",
|
||||
"llama3-chatqa",
|
||||
"llama3-gradient",
|
||||
"llama3-groq-tool-use",
|
||||
"llama3.1",
|
||||
"llama3.2-vision",
|
||||
"llama3.2",
|
||||
"llama3.3",
|
||||
"llama3",
|
||||
"llama4",
|
||||
"llava-llama3",
|
||||
"llava-phi3",
|
||||
"llava",
|
||||
"magicoder",
|
||||
"magistral",
|
||||
"marco-o1",
|
||||
"mathstral",
|
||||
"meditron",
|
||||
"medllama2",
|
||||
"megadolphin",
|
||||
"minicpm-v",
|
||||
"mistral-large",
|
||||
"mistral-nemo",
|
||||
"mistral-openorca",
|
||||
"mistral-small",
|
||||
"mistral-small3.1",
|
||||
"mistral-small3.2",
|
||||
"mistral",
|
||||
"mistrallite",
|
||||
"mixtral",
|
||||
"moondream",
|
||||
"nemotron-mini",
|
||||
"nemotron",
|
||||
"neural-chat",
|
||||
"nexusraven",
|
||||
"notus",
|
||||
"nous-hermes",
|
||||
"nous-hermes2-mixtral",
|
||||
"nous-hermes2",
|
||||
"nuextract",
|
||||
"olmo2",
|
||||
"open-orca-platypus2",
|
||||
"openchat",
|
||||
"opencoder",
|
||||
"openhermes",
|
||||
"openthinker",
|
||||
"orca-mini",
|
||||
"orca2",
|
||||
// "phi", // unreliable
|
||||
"phi3.5",
|
||||
"phi3",
|
||||
"phi4-mini-reasoning",
|
||||
"phi4-mini",
|
||||
"phi4-reasoning",
|
||||
"phi4",
|
||||
"phind-codellama",
|
||||
"qwen",
|
||||
"qwen2-math",
|
||||
"qwen2.5-coder",
|
||||
"qwen2.5",
|
||||
"qwen2.5vl",
|
||||
"qwen2",
|
||||
"qwen3:0.6b", // dense
|
||||
"qwen3:30b", // MOE
|
||||
"qwq",
|
||||
"r1-1776",
|
||||
"reader-lm",
|
||||
"reflection",
|
||||
"sailor2",
|
||||
"samantha-mistral",
|
||||
"shieldgemma",
|
||||
"smallthinker",
|
||||
"smollm",
|
||||
"smollm2",
|
||||
"solar-pro",
|
||||
"solar",
|
||||
"sqlcoder",
|
||||
"stable-beluga",
|
||||
"stable-code",
|
||||
"stablelm-zephyr",
|
||||
"stablelm2",
|
||||
"starcoder",
|
||||
"starcoder2",
|
||||
"starling-lm",
|
||||
"tinydolphin",
|
||||
"tinyllama",
|
||||
"tulu3",
|
||||
"vicuna",
|
||||
"wizard-math",
|
||||
"wizard-vicuna-uncensored",
|
||||
"wizard-vicuna",
|
||||
"wizardcoder",
|
||||
"wizardlm-uncensored",
|
||||
"wizardlm2",
|
||||
"xwinlm",
|
||||
"yarn-llama2",
|
||||
"yarn-mistral",
|
||||
"yi-coder",
|
||||
"yi",
|
||||
"zephyr",
|
||||
}
|
||||
libraryEmbedModels = []string{
|
||||
"all-minilm",
|
||||
"bge-large",
|
||||
"bge-m3",
|
||||
"granite-embedding",
|
||||
"mxbai-embed-large",
|
||||
"nomic-embed-text",
|
||||
"paraphrase-multilingual",
|
||||
"snowflake-arctic-embed",
|
||||
"snowflake-arctic-embed2",
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
lifecycle.InitLogging()
|
||||
custom := os.Getenv("OLLAMA_TEST_SMOL_MODEL")
|
||||
if custom != "" {
|
||||
slog.Info("setting smol test model to " + custom)
|
||||
smol = custom
|
||||
}
|
||||
}
|
||||
|
||||
func FindPort() string {
|
||||
@@ -205,7 +436,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
}
|
||||
lifecycle.ServerLogFile = fp.Name()
|
||||
fp.Close()
|
||||
require.NoError(t, startServer(t, ctx, testEndpoint))
|
||||
if err := startServer(t, ctx, testEndpoint); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
return client, testEndpoint, func() {
|
||||
@@ -238,19 +471,25 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
|
||||
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
|
||||
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) []int {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
var context []int
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
// fmt.Print(".")
|
||||
buf.Write([]byte(response.Response))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return errors.New("stall was detected while streaming response, aborting")
|
||||
}
|
||||
if len(response.Context) > 0 {
|
||||
context = response.Context
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -271,7 +510,13 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
|
||||
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
|
||||
return context
|
||||
}
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
|
||||
}
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
@@ -281,11 +526,14 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
}
|
||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
return context
|
||||
}
|
||||
|
||||
// Generate a set of requests
|
||||
@@ -294,65 +542,125 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
return []api.GenerateRequest{
|
||||
{
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Prompt: "why is the ocean blue? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "why is the color of dirt brown?",
|
||||
Prompt: "why is the color of dirt brown? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of independence day?",
|
||||
Prompt: "what is the origin of independence day? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the composition of air?",
|
||||
Prompt: "what is the composition of air? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
[][]string{
|
||||
{"sunlight"},
|
||||
{"soil", "organic", "earth", "black", "tan"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "british"},
|
||||
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
|
||||
{"fourth", "july", "declaration", "independence"},
|
||||
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
}
|
||||
}
|
||||
|
||||
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
role := "assistant"
|
||||
fn := func(response api.ChatResponse) error {
|
||||
// fmt.Print(".")
|
||||
role = response.Message.Role
|
||||
buf.Write([]byte(response.Message.Content))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return errors.New("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
|
||||
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
|
||||
return nil
|
||||
}
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
|
||||
}
|
||||
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||
}
|
||||
|
||||
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
return &api.Message{Role: role, Content: buf.String()}
|
||||
}
|
||||
|
||||
func ChatRequests() ([]api.ChatRequest, [][]string) {
|
||||
genReqs, results := GenerateRequests()
|
||||
reqs := make([]api.ChatRequest, len(genReqs))
|
||||
// think := api.ThinkValue{Value: "low"}
|
||||
for i := range reqs {
|
||||
reqs[i].Model = genReqs[i].Model
|
||||
reqs[i].Stream = genReqs[i].Stream
|
||||
reqs[i].KeepAlive = genReqs[i].KeepAlive
|
||||
// reqs[i].Think = &think
|
||||
reqs[i].Messages = []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: genReqs[i].Prompt,
|
||||
},
|
||||
}
|
||||
}
|
||||
return reqs, results
|
||||
}
|
||||
|
||||
func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||
// TODO use info API in the future
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Don't hammer on small VRAM cards...
|
||||
if maxVram < gb*format.GibiByte {
|
||||
t.Skip("skipping with small VRAM to avoid timeouts")
|
||||
@@ -360,6 +668,39 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||
}
|
||||
}
|
||||
|
||||
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
|
||||
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to list running models: %s", err)
|
||||
}
|
||||
loaded := []string{}
|
||||
for _, m := range models.Models {
|
||||
loaded = append(loaded, m.Name)
|
||||
if m.Name != model {
|
||||
continue
|
||||
}
|
||||
gpuPercent := 0
|
||||
switch {
|
||||
case m.SizeVRAM == 0:
|
||||
gpuPercent = 0
|
||||
case m.SizeVRAM == m.Size:
|
||||
gpuPercent = 100
|
||||
case m.SizeVRAM > m.Size || m.Size == 0:
|
||||
t.Logf("unexpected size detected: %d", m.SizeVRAM)
|
||||
default:
|
||||
sizeCPU := m.Size - m.SizeVRAM
|
||||
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110)
|
||||
gpuPercent = int(100 - cpuPercent)
|
||||
}
|
||||
if gpuPercent < minPercent {
|
||||
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
|
||||
}
|
||||
|
||||
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
||||
deadline, hasDeadline := t.Deadline()
|
||||
if !hasDeadline {
|
||||
|
||||
@@ -19,12 +19,22 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
||||
// The tensors are of shape embed dim, kv heads, batch size
|
||||
// The mask is of shape history size, batch size
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
windowSize int32
|
||||
chunkSize int32
|
||||
DType ml.DType
|
||||
|
||||
// swaWindowSize is the number of tokens that will be included in the mask
|
||||
// during attention operations. swaMemorySize is the number of tokens that
|
||||
// will be retained in memory for partial prefix caching. Set to math.MaxInt32
|
||||
// for unlimited or if sliding window attention is not being used.
|
||||
swaWindowSize int32
|
||||
swaMemorySize int32
|
||||
|
||||
chunkSize int32
|
||||
|
||||
opts CausalOptions
|
||||
|
||||
// maxBatch is the largest batch that we might receive
|
||||
maxBatch int
|
||||
|
||||
// config controls mostly backend-specific optimizations
|
||||
config *ml.CacheConfig
|
||||
|
||||
@@ -85,32 +95,41 @@ type cellRange struct {
|
||||
|
||||
func NewCausalCache(shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
windowSize: math.MaxInt32,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
windowSize: windowSize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
swaWindowSize: windowSize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
swaWindowSize: windowSize,
|
||||
swaMemorySize: memorySize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
windowSize: math.MaxInt32,
|
||||
chunkSize: chunkSize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
chunkSize: chunkSize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,11 +154,25 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
c.config.MaskDType = ml.DTypeF32
|
||||
}
|
||||
|
||||
if c.swaWindowSize == 0 {
|
||||
c.swaWindowSize = math.MaxInt32
|
||||
}
|
||||
if c.swaMemorySize == 0 {
|
||||
c.swaMemorySize = c.swaWindowSize
|
||||
}
|
||||
if int(c.swaMemorySize) > capacity {
|
||||
c.swaMemorySize = math.MaxInt32
|
||||
}
|
||||
|
||||
if c.swaMemorySize < c.swaWindowSize {
|
||||
panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
|
||||
}
|
||||
|
||||
var cacheSize int
|
||||
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
|
||||
if c.swaMemorySize == math.MaxInt32 {
|
||||
cacheSize = maxSequences * capacity
|
||||
} else {
|
||||
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
|
||||
cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
|
||||
}
|
||||
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
c.cells = make([]cacheCell, cacheSize)
|
||||
@@ -147,6 +180,7 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
c.DType = dtype
|
||||
c.cellRanges = make(map[int]cellRange)
|
||||
c.backend = backend
|
||||
c.maxBatch = maxBatch
|
||||
}
|
||||
|
||||
func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||
@@ -180,10 +214,10 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
slog.Warn("unable to find a kv cache slot", "cache", c)
|
||||
return err
|
||||
}
|
||||
|
||||
c.curCellRange = newRange()
|
||||
for i, pos := range batch.Positions {
|
||||
seq := batch.Sequences[i]
|
||||
|
||||
@@ -194,19 +228,12 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
seqRange = newRange()
|
||||
}
|
||||
|
||||
if c.curLoc+i > seqRange.max {
|
||||
seqRange.max = c.curLoc + i
|
||||
}
|
||||
if seqRange.max > c.curCellRange.max {
|
||||
c.curCellRange.max = seqRange.max
|
||||
}
|
||||
seqRange.min = min(seqRange.min, c.curLoc+i)
|
||||
c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
|
||||
|
||||
seqRange.max = max(seqRange.max, c.curLoc+i)
|
||||
c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
|
||||
|
||||
if c.curLoc+i < seqRange.min {
|
||||
seqRange.min = c.curLoc + i
|
||||
}
|
||||
if seqRange.min < c.curCellRange.min {
|
||||
c.curCellRange.min = seqRange.min
|
||||
}
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
} else {
|
||||
@@ -248,7 +275,16 @@ func (c *Causal) findStartLoc() (int, error) {
|
||||
}
|
||||
|
||||
func (c *Causal) updateSlidingWindow() {
|
||||
if c.windowSize == math.MaxInt32 {
|
||||
c.curCellRange = newRange()
|
||||
|
||||
if c.swaMemorySize == math.MaxInt32 {
|
||||
for _, seq := range c.curSequences {
|
||||
if seqRange, ok := c.cellRanges[seq]; ok {
|
||||
c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
|
||||
c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -278,12 +314,16 @@ func (c *Causal) updateSlidingWindow() {
|
||||
|
||||
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
if c.cells[i].pos < pos-c.windowSize {
|
||||
if c.cells[i].pos < pos-c.swaMemorySize {
|
||||
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
} else {
|
||||
newRange.min = min(newRange.min, i)
|
||||
newRange.max = max(newRange.max, i)
|
||||
}
|
||||
if c.cells[i].pos >= pos-c.swaWindowSize {
|
||||
c.curCellRange.min = min(c.curCellRange.min, i)
|
||||
c.curCellRange.max = max(c.curCellRange.max, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,7 +363,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||
c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
||||
c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
|
||||
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
}
|
||||
}
|
||||
@@ -338,9 +378,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||
ctx.Forward(maskTensor.Copy(ctx, out))
|
||||
maskTensor = out
|
||||
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||
}
|
||||
|
||||
return maskTensor
|
||||
@@ -481,6 +519,8 @@ func (c *Causal) defrag() {
|
||||
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
|
||||
c.updateSlidingWindow()
|
||||
}
|
||||
|
||||
func (c *Causal) SetLayer(layer int) {
|
||||
@@ -606,7 +646,7 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
}
|
||||
|
||||
func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
if c.windowSize == math.MaxInt32 {
|
||||
if c.swaMemorySize == math.MaxInt32 {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -628,8 +668,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
lastWindowStart := max(0, last-c.windowSize)
|
||||
posWindowStart := max(0, pos-c.windowSize)
|
||||
lastWindowStart := max(0, last-c.swaMemorySize)
|
||||
posWindowStart := max(0, pos-c.swaWindowSize)
|
||||
|
||||
return posWindowStart >= lastWindowStart
|
||||
}
|
||||
@@ -639,48 +679,64 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
return ErrNotSupported
|
||||
}
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
seqRange := c.cellRanges[seq]
|
||||
size := seqRange.max - seqRange.min + 1
|
||||
|
||||
offsets := make([]int32, size)
|
||||
for i := range offsets {
|
||||
cell := c.cells[seqRange.min+i]
|
||||
for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
|
||||
size := min(seqRange.max-start+1, c.maxBatch)
|
||||
offsets := make([]int32, size)
|
||||
|
||||
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||
offsets[i] = offset
|
||||
var batchFirst, batchLast int
|
||||
|
||||
batchFirst = -1
|
||||
for i := range offsets {
|
||||
cell := c.cells[start+i]
|
||||
|
||||
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||
offsets[i] = offset
|
||||
if batchFirst < 0 {
|
||||
batchFirst = i
|
||||
}
|
||||
batchLast = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||
|
||||
for i, key := range c.keys {
|
||||
if key == nil {
|
||||
if batchFirst < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
offsets = offsets[batchFirst : batchLast+1]
|
||||
|
||||
key = key.View(ctx, rowSize*seqRange.min,
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
size,
|
||||
)
|
||||
ctx := c.backend.NewContext()
|
||||
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||
|
||||
roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||
if err != nil {
|
||||
return err
|
||||
for i, key := range c.keys {
|
||||
if key == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
key = key.View(ctx, rowSize*(start+batchFirst),
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
len(offsets),
|
||||
)
|
||||
|
||||
roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||
if err != nil {
|
||||
ctx.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Forward(roped.Copy(ctx, key))
|
||||
}
|
||||
|
||||
ctx.Forward(roped.Copy(ctx, key))
|
||||
ctx.Compute()
|
||||
ctx.Close()
|
||||
}
|
||||
|
||||
ctx.Compute()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -60,6 +60,8 @@ func TestSWA(t *testing.T) {
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
x := float32(math.Inf(-1))
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "FirstBatch",
|
||||
@@ -69,7 +71,12 @@ func TestSWA(t *testing.T) {
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
expectedMask: []float32{
|
||||
0, x, x, x,
|
||||
0, 0, x, x,
|
||||
x, 0, 0, x,
|
||||
x, x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
@@ -79,7 +86,53 @@ func TestSWA(t *testing.T) {
|
||||
pos: []int32{4, 5},
|
||||
expected: []float32{5, 6, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
|
||||
expectedMask: []float32{
|
||||
0, x, x, 0,
|
||||
0, 0, x, x,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestSWAMem(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewSWAMemCache(1, 3, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
x := float32(math.Inf(-1))
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{
|
||||
0, x, x, x,
|
||||
0, 0, x, x,
|
||||
x, 0, 0, x,
|
||||
x, x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{4, 5},
|
||||
expected: []float32{4, 5, 6},
|
||||
expectedShape: []int{1, 1, 3},
|
||||
expectedMask: []float32{
|
||||
0, 0, x,
|
||||
x, 0, 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -437,6 +490,70 @@ func TestCanResume(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanResumeSWAMem(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
windowSize := int32(4)
|
||||
memSize := int32(5)
|
||||
cache := NewSWAMemCache(windowSize, memSize, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3, 4, 5},
|
||||
Sequences: []int{0, 0, 0, 0, 0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// shift window by adding position 6
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{6, 7},
|
||||
Sequences: []int{0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
if cache.CanResume(0, 0) {
|
||||
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 1) {
|
||||
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 2) {
|
||||
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 3) {
|
||||
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 4) {
|
||||
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 5) {
|
||||
t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
|
||||
}
|
||||
if !cache.CanResume(0, 6) {
|
||||
t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
|
||||
}
|
||||
if !cache.CanResume(0, 7) {
|
||||
t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
|
||||
}
|
||||
}
|
||||
|
||||
type testBackend struct {
|
||||
ml.Backend
|
||||
}
|
||||
|
||||
2
llama/build-info.cpp
generated
vendored
2
llama/build-info.cpp
generated
vendored
@@ -1,4 +1,4 @@
|
||||
int LLAMA_BUILD_NUMBER = 0;
|
||||
char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618";
|
||||
char const *LLAMA_COMMIT = "e54d41befcc1575f4c898c5ff4ef43970cead75f";
|
||||
char const *LLAMA_COMPILER = "";
|
||||
char const *LLAMA_BUILD_TARGET = "";
|
||||
|
||||
@@ -1,23 +1,32 @@
|
||||
protect **/*.go
|
||||
include common/
|
||||
include common/base64.*
|
||||
include common/common.*
|
||||
include common/json-schema-to-grammar.*
|
||||
include common/json.*
|
||||
include common/log.*
|
||||
include common/sampling.*
|
||||
include common/stb_image.*
|
||||
include include/
|
||||
include include/llama.*
|
||||
include include/llama-*.*
|
||||
include tools/
|
||||
include tools/mtmd/
|
||||
include tools/mtmd/clip.*
|
||||
include tools/mtmd/clip-impl.*
|
||||
include tools/mtmd/llava.*
|
||||
include src/
|
||||
include src/llama.*
|
||||
include src/llama-*.*
|
||||
include src/unicode-data.*
|
||||
include src/unicode.*
|
||||
exclude *
|
||||
protect .rsync-filter
|
||||
protect *.go
|
||||
include /common/
|
||||
include /common/base64.*
|
||||
include /common/common.*
|
||||
include /common/json-schema-to-grammar.*
|
||||
include /common/json.*
|
||||
include /common/log.*
|
||||
include /common/sampling.*
|
||||
include /include/
|
||||
include /include/llama.*
|
||||
include /include/llama-*.*
|
||||
include /tools/
|
||||
include /tools/mtmd/
|
||||
include /tools/mtmd/*.h
|
||||
include /tools/mtmd/clip.cpp
|
||||
include /tools/mtmd/mtmd.cpp
|
||||
include /tools/mtmd/mtmd-audio.cpp
|
||||
include /tools/mtmd/mtmd-helper.cpp
|
||||
include /src/
|
||||
include /src/llama.*
|
||||
include /src/llama-*.*
|
||||
include /src/unicode-data.*
|
||||
include /src/unicode.*
|
||||
include /vendor/
|
||||
include /vendor/miniaudio/
|
||||
include /vendor/miniaudio/*.h
|
||||
include /vendor/nlohmann/
|
||||
include /vendor/nlohmann/*.hpp
|
||||
include /vendor/stb/
|
||||
include /vendor/stb/*.h
|
||||
hide *
|
||||
|
||||
221
llama/llama.cpp/common/common.cpp
vendored
221
llama/llama.cpp/common/common.cpp
vendored
@@ -203,6 +203,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
|
||||
DWORD p = NORMAL_PRIORITY_CLASS;
|
||||
switch (prio) {
|
||||
case GGML_SCHED_PRIO_LOW: p = BELOW_NORMAL_PRIORITY_CLASS; break;
|
||||
case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break;
|
||||
case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break;
|
||||
case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break;
|
||||
@@ -228,6 +229,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
|
||||
int p = 0;
|
||||
switch (prio) {
|
||||
case GGML_SCHED_PRIO_LOW: p = 5; break;
|
||||
case GGML_SCHED_PRIO_NORMAL: p = 0; break;
|
||||
case GGML_SCHED_PRIO_MEDIUM: p = -5; break;
|
||||
case GGML_SCHED_PRIO_HIGH: p = -10; break;
|
||||
@@ -443,9 +445,37 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
|
||||
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
|
||||
bool has_suffix = string_ends_with(str, suffix);
|
||||
if (has_suffix) {
|
||||
str = str.substr(0, str.size() - suffix.size());
|
||||
}
|
||||
return has_suffix;
|
||||
}
|
||||
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const char text_last_char = str.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const auto current_partial = stop.substr(0, char_index + 1);
|
||||
if (string_ends_with(str, current_partial)) {
|
||||
return str.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string & s) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
return std::regex_replace(s, special_chars, "\\$0");
|
||||
return std::regex_replace(s, special_chars, "\\$&");
|
||||
}
|
||||
|
||||
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
|
||||
@@ -685,11 +715,17 @@ bool fs_validate_filename(const std::string & filename) {
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
filename_utf32 = converter.from_bytes(filename);
|
||||
@@ -746,6 +782,9 @@ bool fs_validate_filename(const std::string & filename) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#include <iostream>
|
||||
|
||||
|
||||
// returns true if successful, false otherwise
|
||||
bool fs_create_directory_with_parents(const std::string & path) {
|
||||
#ifdef _WIN32
|
||||
@@ -763,9 +802,16 @@ bool fs_create_directory_with_parents(const std::string & path) {
|
||||
// process path from front to back, procedurally creating directories
|
||||
while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
|
||||
const std::wstring subpath = wpath.substr(0, pos_slash);
|
||||
const wchar_t * test = subpath.c_str();
|
||||
|
||||
const bool success = CreateDirectoryW(test, NULL);
|
||||
pos_slash += 1;
|
||||
|
||||
// skip the drive letter, in some systems it can return an access denied error
|
||||
if (subpath.length() == 2 && subpath[1] == ':') {
|
||||
continue;
|
||||
}
|
||||
|
||||
const bool success = CreateDirectoryW(subpath.c_str(), NULL);
|
||||
|
||||
if (!success) {
|
||||
const DWORD error = GetLastError();
|
||||
|
||||
@@ -779,8 +825,6 @@ bool fs_create_directory_with_parents(const std::string & path) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
pos_slash += 1;
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -830,7 +874,7 @@ std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else {
|
||||
@@ -876,31 +920,6 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (params.reranking) {
|
||||
bool ok = true;
|
||||
|
||||
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
}
|
||||
}
|
||||
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
@@ -910,7 +929,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
return iparams;
|
||||
}
|
||||
|
||||
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
|
||||
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
||||
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
||||
params.ctx_shift = false;
|
||||
}
|
||||
@@ -942,6 +961,35 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
}
|
||||
}
|
||||
|
||||
if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) {
|
||||
bool ok = true;
|
||||
|
||||
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
|
||||
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
|
||||
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
|
||||
|
||||
if (!has_eos && !has_sep) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
} else if (!has_eos) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
|
||||
} else if (!has_sep) {
|
||||
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
}
|
||||
}
|
||||
|
||||
// load and optionally apply lora adapters
|
||||
for (auto & la : params.lora_adapters) {
|
||||
llama_adapter_lora_ptr lora;
|
||||
@@ -966,15 +1014,21 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias.push_back({i, -INFINITY});
|
||||
}
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
if (params.sampling.penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
@@ -1017,7 +1071,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
if (llama_model_has_decoder(model)) {
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
|
||||
}
|
||||
llama_kv_self_clear(lctx);
|
||||
llama_memory_clear(llama_get_memory(lctx), true);
|
||||
llama_synchronize(lctx);
|
||||
llama_perf_context_reset(lctx);
|
||||
llama_set_warmup(lctx, false);
|
||||
@@ -1068,6 +1122,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||
mparams.use_mmap = params.use_mmap;
|
||||
mparams.use_mlock = params.use_mlock;
|
||||
mparams.check_tensors = params.check_tensors;
|
||||
mparams.use_extra_bufts = !params.no_extra_bufts;
|
||||
|
||||
if (params.kv_overrides.empty()) {
|
||||
mparams.kv_overrides = NULL;
|
||||
@@ -1083,6 +1138,9 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
|
||||
}
|
||||
|
||||
mparams.progress_callback = params.load_progress_callback;
|
||||
mparams.progress_callback_user_data = params.load_progress_callback_user_data;
|
||||
|
||||
return mparams;
|
||||
}
|
||||
|
||||
@@ -1114,11 +1172,8 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.op_offload = !params.no_op_offload;
|
||||
|
||||
if (params.reranking) {
|
||||
cparams.embeddings = true;
|
||||
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
||||
}
|
||||
cparams.swa_full = params.swa_full;
|
||||
cparams.kv_unified = params.kv_unified;
|
||||
|
||||
cparams.type_k = params.cache_type_k;
|
||||
cparams.type_v = params.cache_type_v;
|
||||
@@ -1252,6 +1307,9 @@ std::vector<llama_token> common_tokenize(
|
||||
int n_tokens = text.length() + 2 * add_special;
|
||||
std::vector<llama_token> result(n_tokens);
|
||||
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
||||
if (n_tokens == std::numeric_limits<int32_t>::min()) {
|
||||
throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit");
|
||||
}
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
||||
@@ -1306,81 +1364,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
|
||||
return text;
|
||||
}
|
||||
|
||||
//
|
||||
// KV cache utils
|
||||
//
|
||||
|
||||
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
|
||||
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
|
||||
|
||||
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
|
||||
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
|
||||
|
||||
llama_kv_cache_view_cell * c_curr = view.cells;
|
||||
llama_seq_id * cs_curr = view.cells_sequences;
|
||||
|
||||
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
||||
if (i % row_size == 0) {
|
||||
printf("\n%5d: ", i);
|
||||
}
|
||||
int seq_count = 0;
|
||||
for (int j = 0; j < view.n_seq_max; j++) {
|
||||
if (cs_curr[j] >= 0) { seq_count++; }
|
||||
}
|
||||
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
|
||||
}
|
||||
|
||||
printf("\n=== Done dumping\n");
|
||||
}
|
||||
|
||||
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
|
||||
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
|
||||
|
||||
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
|
||||
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
|
||||
|
||||
std::unordered_map<llama_seq_id, size_t> seqs;
|
||||
llama_kv_cache_view_cell * c_curr = view.cells;
|
||||
llama_seq_id * cs_curr = view.cells_sequences;
|
||||
|
||||
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
||||
for (int j = 0; j < view.n_seq_max; j++) {
|
||||
if (cs_curr[j] < 0) { continue; }
|
||||
if (seqs.find(cs_curr[j]) == seqs.end()) {
|
||||
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
|
||||
const size_t sz = seqs.size();
|
||||
seqs[cs_curr[j]] = sz;
|
||||
}
|
||||
}
|
||||
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
|
||||
}
|
||||
|
||||
printf("=== Sequence legend: ");
|
||||
for (const auto & it : seqs) {
|
||||
printf("%zu=%d, ", it.second, it.first);
|
||||
}
|
||||
printf("'+'=other sequence ids");
|
||||
|
||||
c_curr = view.cells;
|
||||
cs_curr = view.cells_sequences;
|
||||
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
||||
if (i % row_size == 0) {
|
||||
printf("\n%5d: ", i);
|
||||
}
|
||||
for (int j = 0; j < view.n_seq_max; j++) {
|
||||
if (cs_curr[j] >= 0) {
|
||||
const auto & it = seqs.find(cs_curr[j]);
|
||||
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
|
||||
} else {
|
||||
putchar('.');
|
||||
}
|
||||
}
|
||||
putchar(' ');
|
||||
}
|
||||
|
||||
printf("\n=== Done dumping\n");
|
||||
}
|
||||
|
||||
//
|
||||
// Embedding utils
|
||||
//
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package common
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++11
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/../include
|
||||
// #cgo CXXFLAGS: -std=c++17
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/../vendor
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/../../../ml/backend/ggml/ggml/include
|
||||
import "C"
|
||||
|
||||
79
llama/llama.cpp/common/common.h
vendored
79
llama/llama.cpp/common/common.h
vendored
@@ -6,7 +6,9 @@
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
|
||||
#ifdef _WIN32
|
||||
@@ -75,10 +77,11 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_SERVER,
|
||||
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
|
||||
LLAMA_EXAMPLE_EXPORT_LORA,
|
||||
LLAMA_EXAMPLE_LLAVA,
|
||||
LLAMA_EXAMPLE_MTMD,
|
||||
LLAMA_EXAMPLE_LOOKUP,
|
||||
LLAMA_EXAMPLE_PARALLEL,
|
||||
LLAMA_EXAMPLE_TTS,
|
||||
LLAMA_EXAMPLE_DIFFUSION,
|
||||
|
||||
LLAMA_EXAMPLE_COUNT,
|
||||
};
|
||||
@@ -114,7 +117,7 @@ enum common_grammar_trigger_type {
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
|
||||
};
|
||||
|
||||
struct common_grammar_trigger {
|
||||
@@ -175,7 +178,8 @@ struct common_params_sampling {
|
||||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||
std::set<llama_token> preserved_tokens;
|
||||
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
// print the parameters into a string
|
||||
std::string print() const;
|
||||
@@ -197,6 +201,10 @@ struct common_params_speculative {
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
@@ -212,9 +220,26 @@ struct common_params_vocoder {
|
||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||
};
|
||||
|
||||
struct common_params_diffusion {
|
||||
int32_t steps = 128;
|
||||
bool visual_mode = false;
|
||||
|
||||
float eps = 0; // epsilon for timesteps
|
||||
int32_t block_length = 0; // block length for generation
|
||||
|
||||
int32_t algorithm = 4; // default algorithm: low-confidence
|
||||
float alg_temp = 0.0f; // algorithm temperature
|
||||
|
||||
float cfg_scale = 0; // classifier-free guidance scale
|
||||
bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0
|
||||
};
|
||||
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE,
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
|
||||
COMMON_REASONING_FORMAT_AUTO,
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
|
||||
COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
|
||||
};
|
||||
|
||||
struct common_params {
|
||||
@@ -262,6 +287,7 @@ struct common_params {
|
||||
struct common_params_sampling sampling;
|
||||
struct common_params_speculative speculative;
|
||||
struct common_params_vocoder vocoder;
|
||||
struct common_params_diffusion diffusion;
|
||||
|
||||
struct common_params_model model;
|
||||
|
||||
@@ -290,6 +316,7 @@ struct common_params {
|
||||
int32_t verbosity = 0;
|
||||
int32_t control_vector_layer_start = -1; // layer range for control vector
|
||||
int32_t control_vector_layer_end = -1; // layer range for control vector
|
||||
bool offline = false;
|
||||
|
||||
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
|
||||
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
|
||||
@@ -322,17 +349,19 @@ struct common_params {
|
||||
bool flash_attn = false; // flash attention
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool ctx_shift = true; // context shift on inifinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
bool kv_unified = false; // enable unified KV cache
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool verbose_prompt = false; // print prompt tokens before generation
|
||||
bool display_prompt = true; // print prompt before generation
|
||||
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
|
||||
bool no_kv_offload = false; // disable KV offloading
|
||||
bool warmup = true; // warmup run
|
||||
bool check_tensors = false; // validate tensor data
|
||||
bool no_op_offload = false; // globally disable offload host tensor operations to device
|
||||
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
|
||||
|
||||
bool single_turn = false; // single turn chat conversation
|
||||
|
||||
@@ -352,7 +381,7 @@ struct common_params {
|
||||
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
||||
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
|
||||
std::string embd_sep = "\n"; // separator of embeddings
|
||||
bool reranking = false; // enable reranking support on server
|
||||
std::string cls_sep = "\t"; // separator of classification sequences
|
||||
|
||||
// server params
|
||||
int32_t port = 8080; // server listens on this network port
|
||||
@@ -363,16 +392,21 @@ struct common_params {
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = ""; // NOLINT
|
||||
std::string api_prefix = ""; // NOLINT
|
||||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
||||
int reasoning_budget = -1;
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
|
||||
std::string ssl_file_key = ""; // NOLINT
|
||||
std::string ssl_file_cert = ""; // NOLINT
|
||||
|
||||
std::map<std::string, std::string> default_template_kwargs;
|
||||
|
||||
// "advanced" endpoints are disabled by default for better security
|
||||
bool webui = true;
|
||||
bool endpoint_slots = false;
|
||||
@@ -407,10 +441,12 @@ struct common_params {
|
||||
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
|
||||
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
|
||||
int32_t i_chunk = 0; // start processing from this chunk
|
||||
int8_t imat_dat = 0; // whether the legacy imatrix.dat format should be output (gguf <= 0 < dat)
|
||||
|
||||
bool process_output = false; // collect data for the output tensor
|
||||
bool compute_ppl = true; // whether to compute perplexity
|
||||
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
|
||||
bool process_output = false; // collect data for the output tensor
|
||||
bool compute_ppl = true; // whether to compute perplexity
|
||||
bool show_statistics = false; // show imatrix statistics per tensor
|
||||
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
|
||||
|
||||
// cvector-generator params
|
||||
int n_pca_batch = 100;
|
||||
@@ -426,6 +462,11 @@ struct common_params {
|
||||
|
||||
// common params
|
||||
std::string out_file; // output filename for all example programs
|
||||
// optional callback for model loading progress and cancellation:
|
||||
// called with a progress value between 0.0 and 1.0.
|
||||
// return false from callback to abort model loading or true to continue
|
||||
llama_progress_callback load_progress_callback = NULL;
|
||||
void * load_progress_callback_user_data = NULL;
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
@@ -503,10 +544,10 @@ static bool string_starts_with(const std::string & str,
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
}
|
||||
|
||||
static bool string_ends_with(const std::string & str,
|
||||
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
// While we wait for C++20's std::string::ends_with...
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||
void string_process_escapes(std::string & input);
|
||||
@@ -615,16 +656,6 @@ std::string common_detokenize(
|
||||
const std::vector<llama_token> & tokens,
|
||||
bool special = true);
|
||||
|
||||
//
|
||||
// KV cache utils
|
||||
//
|
||||
|
||||
// Dump the KV cache view with the number of sequences per cell.
|
||||
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
|
||||
|
||||
// Dump the KV cache view showing individual sequences in each cell (long output).
|
||||
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
|
||||
|
||||
//
|
||||
// Embedding utils
|
||||
//
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
@@ -40,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items
|
||||
return result;
|
||||
}
|
||||
|
||||
/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
|
||||
class string_view {
|
||||
const std::string & _str;
|
||||
const size_t _start;
|
||||
const size_t _end;
|
||||
public:
|
||||
string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
|
||||
|
||||
size_t size() const {
|
||||
return _end - _start;
|
||||
}
|
||||
|
||||
size_t length() const {
|
||||
return size();
|
||||
}
|
||||
|
||||
operator std::string() const {
|
||||
return str();
|
||||
}
|
||||
|
||||
std::string str() const {
|
||||
return _str.substr(_start, _end - _start);
|
||||
}
|
||||
|
||||
string_view substr(size_t pos, size_t len = std::string::npos) const {
|
||||
return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
|
||||
}
|
||||
|
||||
char operator[](size_t pos) const {
|
||||
auto index = _start + pos;
|
||||
if (index >= _end) {
|
||||
throw std::out_of_range("string_view index out of range");
|
||||
}
|
||||
return _str[_start + pos];
|
||||
}
|
||||
|
||||
bool operator==(const string_view & other) const {
|
||||
std::string this_str = *this;
|
||||
std::string other_str = other;
|
||||
return this_str == other_str;
|
||||
}
|
||||
};
|
||||
|
||||
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
|
||||
auto has_min = min_value != std::numeric_limits<int>::min();
|
||||
auto has_max = max_value != std::numeric_limits<int>::max();
|
||||
@@ -111,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
||||
}
|
||||
out << "}";
|
||||
};
|
||||
std::function<void(const string_view &, const string_view &)> uniform_range =
|
||||
[&](const string_view & from, const string_view & to) {
|
||||
std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
|
||||
[&](const std::string_view & from, const std::string_view & to) {
|
||||
size_t i = 0;
|
||||
while (i < from.length() && i < to.length() && from[i] == to[i]) {
|
||||
i++;
|
||||
}
|
||||
if (i > 0) {
|
||||
out << "\"" << from.substr(0, i).str() << "\"";
|
||||
out << "\"" << from.substr(0, i) << "\"";
|
||||
}
|
||||
if (i < from.length() && i < to.length()) {
|
||||
if (i > 0) {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
||||
bool force_gbnf = false);
|
||||
|
||||
15
llama/llama.cpp/common/sampling.cpp
vendored
15
llama/llama.cpp/common/sampling.cpp
vendored
@@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
std::vector<std::string> patterns_at_start;
|
||||
std::vector<std::string> trigger_patterns;
|
||||
std::vector<std::string> patterns_anywhere;
|
||||
std::vector<llama_token> trigger_tokens;
|
||||
for (const auto & trigger : params.grammar_triggers) {
|
||||
@@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
|
||||
patterns_anywhere.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
trigger_patterns.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||
@@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> trigger_patterns;
|
||||
if (!patterns_at_start.empty()) {
|
||||
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
if (!patterns_anywhere.empty()) {
|
||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
|
||||
425
llama/llama.cpp/include/llama.h
vendored
425
llama/llama.cpp/include/llama.h
vendored
@@ -61,59 +61,23 @@ extern "C" {
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
struct llama_sampler;
|
||||
struct llama_kv_cache;
|
||||
|
||||
typedef struct llama_memory_i * llama_memory_t;
|
||||
|
||||
struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
|
||||
|
||||
typedef int32_t llama_pos;
|
||||
typedef int32_t llama_token;
|
||||
typedef int32_t llama_seq_id;
|
||||
|
||||
enum llama_vocab_type {
|
||||
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
|
||||
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
|
||||
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
||||
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
||||
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
|
||||
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
|
||||
};
|
||||
|
||||
// pre-tokenization types
|
||||
enum llama_vocab_pre_type {
|
||||
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
|
||||
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
|
||||
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
|
||||
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
|
||||
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
|
||||
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
|
||||
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
|
||||
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
|
||||
LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
|
||||
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
|
||||
LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
|
||||
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
|
||||
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
|
||||
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
|
||||
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
|
||||
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
|
||||
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
|
||||
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
|
||||
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
|
||||
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
|
||||
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
|
||||
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
|
||||
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
|
||||
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
|
||||
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
|
||||
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
||||
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
|
||||
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
|
||||
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
|
||||
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
|
||||
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
|
||||
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
|
||||
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
|
||||
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
|
||||
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
|
||||
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
||||
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
||||
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
|
||||
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
|
||||
LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
@@ -188,6 +152,7 @@ extern "C" {
|
||||
//LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack
|
||||
LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors
|
||||
|
||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
};
|
||||
@@ -240,18 +205,21 @@ extern "C" {
|
||||
|
||||
typedef bool (*llama_progress_callback)(float progress, void * user_data);
|
||||
|
||||
// Input data for llama_decode
|
||||
// Input data for llama_encode/llama_decode
|
||||
// A llama_batch object can contain input about one or many sequences
|
||||
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
|
||||
//
|
||||
// - token : the token ids of the input (used when embd is NULL)
|
||||
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
||||
// - pos : the positions of the respective token in the sequence
|
||||
// (if set to NULL, the token position will be tracked automatically by llama_decode)
|
||||
// (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
|
||||
// - seq_id : the sequence to which the respective token belongs
|
||||
// (if set to NULL, the sequence ID will be assumed to be 0)
|
||||
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
||||
// (if set to NULL, only the logits for last token will be returned)
|
||||
// (if set to NULL:
|
||||
// - if embeddings: all tokens are output
|
||||
// - if not: only the last token is output
|
||||
// )
|
||||
//
|
||||
typedef struct llama_batch {
|
||||
int32_t n_tokens;
|
||||
@@ -261,7 +229,7 @@ extern "C" {
|
||||
llama_pos * pos;
|
||||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
} llama_batch;
|
||||
|
||||
enum llama_model_kv_override_type {
|
||||
@@ -317,10 +285,11 @@ extern "C" {
|
||||
const struct llama_model_kv_override * kv_overrides;
|
||||
|
||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool check_tensors; // validate model tensor data
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool check_tensors; // validate model tensor data
|
||||
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
|
||||
};
|
||||
|
||||
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
||||
@@ -345,7 +314,7 @@ extern "C" {
|
||||
float yarn_beta_fast; // YaRN low correction dim
|
||||
float yarn_beta_slow; // YaRN high correction dim
|
||||
uint32_t yarn_orig_ctx; // YaRN original context size
|
||||
float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
|
||||
float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
@@ -361,10 +330,16 @@ extern "C" {
|
||||
|
||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||
bool embeddings; // if true, extract embeddings (together with logits)
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
|
||||
bool no_perf; // whether to measure performance timings
|
||||
bool op_offload; // whether to offload host tensor operations to device
|
||||
bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // use flash attention [EXPERIMENTAL]
|
||||
bool no_perf; // measure performance timings
|
||||
bool op_offload; // offload host tensor operations to device
|
||||
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
|
||||
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
|
||||
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
@@ -381,6 +356,7 @@ extern "C" {
|
||||
void * imatrix; // pointer to importance matrix data
|
||||
void * kv_overrides; // pointer to vector containing overrides
|
||||
void * tensor_types; // pointer to vector containing tensor types
|
||||
void * prune_layers; // pointer to vector containing layer indices to prune
|
||||
} llama_model_quantize_params;
|
||||
|
||||
typedef struct llama_logit_bias {
|
||||
@@ -470,6 +446,7 @@ extern "C" {
|
||||
LLAMA_API int64_t llama_time_us(void);
|
||||
|
||||
LLAMA_API size_t llama_max_devices(void);
|
||||
LLAMA_API size_t llama_max_parallel_sequences(void);
|
||||
|
||||
LLAMA_API bool llama_supports_mmap (void);
|
||||
LLAMA_API bool llama_supports_mlock (void);
|
||||
@@ -489,9 +466,11 @@ extern "C" {
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
||||
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
|
||||
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
|
||||
|
||||
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
|
||||
|
||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||
|
||||
@@ -500,10 +479,18 @@ extern "C" {
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
||||
// Returns the number of classifier outputs (only valid for classifier models)
|
||||
// Undefined behavior for non-classifier models
|
||||
LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
|
||||
|
||||
// Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
|
||||
LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
|
||||
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
|
||||
|
||||
LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
|
||||
@@ -552,6 +539,9 @@ extern "C" {
|
||||
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
|
||||
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
|
||||
|
||||
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
|
||||
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
|
||||
|
||||
// Returns 0 on success
|
||||
LLAMA_API uint32_t llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
@@ -604,210 +594,190 @@ extern "C" {
|
||||
int32_t il_end);
|
||||
|
||||
//
|
||||
// KV cache
|
||||
// Memory
|
||||
//
|
||||
|
||||
// TODO: start using struct llama_kv_cache
|
||||
|
||||
// Information associated with an individual cell in the KV cache view.
|
||||
struct llama_kv_cache_view_cell {
|
||||
// The position for this cell. Takes KV cache shifts into account.
|
||||
// May be negative if the cell is not populated.
|
||||
llama_pos pos;
|
||||
};
|
||||
|
||||
// An updateable view of the KV cache.
|
||||
struct llama_kv_cache_view {
|
||||
// Number of KV cache cells. This will be the same as the context size.
|
||||
int32_t n_cells;
|
||||
|
||||
// Maximum number of sequences that can exist in a cell. It's not an error
|
||||
// if there are more sequences in a cell than this value, however they will
|
||||
// not be visible in the view cells_sequences.
|
||||
int32_t n_seq_max;
|
||||
|
||||
// Number of tokens in the cache. For example, if there are two populated
|
||||
// cells, the first with 1 sequence id in it and the second with 2 sequence
|
||||
// ids then you'll have 3 tokens.
|
||||
int32_t token_count;
|
||||
|
||||
// Number of populated cache cells.
|
||||
int32_t used_cells;
|
||||
|
||||
// Maximum contiguous empty slots in the cache.
|
||||
int32_t max_contiguous;
|
||||
|
||||
// Index to the start of the max_contiguous slot range. Can be negative
|
||||
// when cache is full.
|
||||
int32_t max_contiguous_idx;
|
||||
|
||||
// Information for an individual cell.
|
||||
struct llama_kv_cache_view_cell * cells;
|
||||
|
||||
// The sequences for each cell. There will be n_seq_max items per cell.
|
||||
llama_seq_id * cells_sequences;
|
||||
};
|
||||
|
||||
// Create an empty KV cache view. (use only for debugging purposes)
|
||||
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
|
||||
|
||||
// Free a KV cache view. (use only for debugging purposes)
|
||||
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
|
||||
|
||||
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
||||
// TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
|
||||
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
|
||||
|
||||
///
|
||||
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
||||
"use llama_kv_self_n_tokens instead");
|
||||
|
||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
|
||||
"use llama_kv_self_used_cells instead");
|
||||
|
||||
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
||||
LLAMA_API void llama_kv_self_clear(
|
||||
struct llama_context * ctx);
|
||||
// Clear the memory contents
|
||||
// If data == true, the data buffers will also be cleared together with the metadata
|
||||
LLAMA_API void llama_memory_clear(
|
||||
llama_memory_t mem,
|
||||
bool data);
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
// seq_id < 0 : match any sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API bool llama_kv_self_seq_rm(
|
||||
LLAMA_API bool llama_memory_seq_rm(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
|
||||
// Copy all tokens that belong to the specified sequence to another sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_memory_seq_cp(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
|
||||
// Removes all tokens that do not belong to the specified sequence
|
||||
LLAMA_API void llama_memory_seq_keep(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_memory_seq_add(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta);
|
||||
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_memory_seq_div(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d);
|
||||
|
||||
// Returns the smallest position present in the memory for the specified sequence
|
||||
// This is typically non-zero only for SWA caches
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_memory_seq_pos_min(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Returns the largest position present in the memory for the specified sequence
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_memory_seq_pos_max(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Check if the memory supports shifting
|
||||
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
|
||||
|
||||
//
|
||||
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
|
||||
//
|
||||
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||
DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
|
||||
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
|
||||
|
||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||
DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
|
||||
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
|
||||
|
||||
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_clear(
|
||||
struct llama_context * ctx),
|
||||
"Use llama_memory_clear() instead");
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
// seq_id < 0 : match any sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
llama_pos p1),
|
||||
"Use llama_memory_seq_rm() instead");
|
||||
|
||||
// Copy all tokens that belong to the specified sequence to another sequence
|
||||
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_self_seq_cp(
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
llama_pos p1),
|
||||
"Use llama_memory_seq_cp() instead");
|
||||
|
||||
// Removes all tokens that do not belong to the specified sequence
|
||||
LLAMA_API void llama_kv_self_seq_keep(
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_keep() instead");
|
||||
|
||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_self_update()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_self_seq_add(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta);
|
||||
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_self_update()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_self_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d);
|
||||
|
||||
// Returns the largest position present in the KV cache for the specified sequence
|
||||
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_self_update()
|
||||
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
|
||||
|
||||
// Check if the context supports KV cache shifting
|
||||
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_clear(
|
||||
struct llama_context * ctx),
|
||||
"use llama_kv_self_clear instead");
|
||||
|
||||
DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"use llama_kv_self_seq_rm instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"use llama_kv_self_seq_cp instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"use llama_kv_self_seq_keep instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta),
|
||||
"use llama_kv_self_seq_add instead");
|
||||
"Use llama_memory_seq_add() instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d),
|
||||
"use llama_kv_self_seq_div instead");
|
||||
"Use llama_memory_seq_div() instead");
|
||||
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
|
||||
// Returns the smallest position present in the KV cache for the specified sequence
|
||||
// This is typically non-zero only for SWA caches
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"use llama_kv_self_seq_pos_max instead");
|
||||
"Use llama_memory_seq_pos_min() instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
|
||||
"use llama_kv_self_defrag instead");
|
||||
// Returns the largest position present in the KV cache for the specified sequence
|
||||
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
|
||||
// Return -1 if the sequence is empty
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"Use llama_memory_seq_pos_max() instead");
|
||||
|
||||
DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
|
||||
"use llama_kv_self_can_shift instead");
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
|
||||
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
|
||||
"use llama_kv_self_update instead");
|
||||
// Check if the context supports KV cache shifting
|
||||
DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
|
||||
"use llama_memory_can_shift() instead");
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
|
||||
"simply remove this call, updates are applied lazily on the next llama_decode()");
|
||||
|
||||
//
|
||||
// State / sessions
|
||||
//
|
||||
|
||||
// Returns the *actual* size in bytes of the state
|
||||
// (logits, embedding and kv_cache)
|
||||
// (logits, embedding and memory)
|
||||
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
|
||||
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
|
||||
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
|
||||
@@ -863,12 +833,12 @@ extern "C" {
|
||||
size_t n_token_count),
|
||||
"use llama_state_save_file instead");
|
||||
|
||||
// Get the exact size needed to copy the KV cache of a single sequence
|
||||
// Get the exact size needed to copy the state of a single sequence
|
||||
LLAMA_API size_t llama_state_seq_get_size(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Copy the KV cache of a single sequence into the specified buffer
|
||||
// Copy the state of a single sequence into the specified buffer
|
||||
LLAMA_API size_t llama_state_seq_get_data(
|
||||
struct llama_context * ctx,
|
||||
uint8_t * dst,
|
||||
@@ -934,18 +904,23 @@ extern "C" {
|
||||
// For encode-decoder contexts, processes the batch using the encoder.
|
||||
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
// < 0 - error. the memory state is restored to the state before this call
|
||||
LLAMA_API int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
// Process a batch of tokens.
|
||||
// Requires KV cache.
|
||||
// Requires the context to have a memory.
|
||||
// For encode-decoder contexts, processes the batch using the decoder.
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
// Upon fatal-error or abort, the ubatches that managed to be been processed will remain in the memory state of the context
|
||||
// To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
|
||||
// Upon other return values, the memory state is restored to the state before this call
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// 2 - aborted (processed ubatches will remain in the context's memory)
|
||||
// -1 - invalid input batch
|
||||
// < -1 - fatal error (processed ubatches will remain in the context's memory)
|
||||
LLAMA_API int32_t llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
@@ -961,8 +936,8 @@ extern "C" {
|
||||
// Get the number of threads used for prompt and batch processing (multiple token).
|
||||
LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
|
||||
|
||||
// Set whether the model is in embeddings mode or not
|
||||
// If true, embeddings will be returned but logits will not
|
||||
// Set whether the context outputs embeddings or not
|
||||
// TODO: rename to avoid confusion with llama_get_embeddings()
|
||||
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
|
||||
|
||||
// Set whether to use causal attention or not
|
||||
@@ -986,6 +961,7 @@ extern "C" {
|
||||
// in the order they have appeared in the batch.
|
||||
// Rows: number of tokens for which llama_batch.logits[i] != 0
|
||||
// Cols: n_vocab
|
||||
// TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
|
||||
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
||||
|
||||
// Logits for the ith token. For positive indices, Equivalent to:
|
||||
@@ -1000,6 +976,7 @@ extern "C" {
|
||||
// in the order they have appeared in the batch.
|
||||
// shape: [n_outputs*n_embd]
|
||||
// Otherwise, returns NULL.
|
||||
// TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
|
||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
|
||||
// Get the embeddings for the ith token. For positive indices, Equivalent to:
|
||||
@@ -1011,7 +988,7 @@ extern "C" {
|
||||
|
||||
// Get the embeddings for a sequence id
|
||||
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
|
||||
// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
|
||||
// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
|
||||
// otherwise: float[n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||
|
||||
@@ -1038,9 +1015,11 @@ extern "C" {
|
||||
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
|
||||
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
|
||||
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
|
||||
LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask
|
||||
|
||||
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
|
||||
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
|
||||
LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
|
||||
|
||||
LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
|
||||
@@ -1084,6 +1063,7 @@ extern "C" {
|
||||
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
|
||||
/// @return Returns the number of tokens on success, no more than n_tokens_max
|
||||
/// @return Returns a negative number on failure - the number of tokens that would have been returned
|
||||
/// @return Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
|
||||
/// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
|
||||
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
|
||||
/// as plaintext. Does not insert a leading space.
|
||||
@@ -1421,6 +1401,7 @@ extern "C" {
|
||||
|
||||
int32_t n_p_eval;
|
||||
int32_t n_eval;
|
||||
int32_t n_reused; // number of times a ggml compute graph had been reused
|
||||
};
|
||||
|
||||
struct llama_perf_sampler_data {
|
||||
|
||||
599
llama/llama.cpp/src/llama-arch.cpp
vendored
599
llama/llama.cpp/src/llama-arch.cpp
vendored
@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_BERT, "bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
|
||||
{ LLM_ARCH_NEO_BERT, "neo-bert" },
|
||||
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
|
||||
{ LLM_ARCH_BLOOM, "bloom" },
|
||||
{ LLM_ARCH_STABLELM, "stablelm" },
|
||||
@@ -33,6 +34,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_PHI3, "phi3" },
|
||||
{ LLM_ARCH_PHIMOE, "phimoe" },
|
||||
{ LLM_ARCH_PLAMO, "plamo" },
|
||||
{ LLM_ARCH_PLAMO2, "plamo2" },
|
||||
{ LLM_ARCH_CODESHELL, "codeshell" },
|
||||
{ LLM_ARCH_ORION, "orion" },
|
||||
{ LLM_ARCH_INTERNLM2, "internlm2" },
|
||||
@@ -41,8 +43,12 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_GEMMA, "gemma" },
|
||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||
{ LLM_ARCH_GEMMA3, "gemma3" },
|
||||
{ LLM_ARCH_GEMMA3N, "gemma3n" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
{ LLM_ARCH_MAMBA2, "mamba2" },
|
||||
{ LLM_ARCH_JAMBA, "jamba" },
|
||||
{ LLM_ARCH_FALCON_H1, "falcon-h1" },
|
||||
{ LLM_ARCH_XVERSE, "xverse" },
|
||||
{ LLM_ARCH_COMMAND_R, "command-r" },
|
||||
{ LLM_ARCH_COHERE2, "cohere2" },
|
||||
@@ -56,23 +62,38 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||
{ LLM_ARCH_CHATGLM, "chatglm" },
|
||||
{ LLM_ARCH_GLM4, "glm4" },
|
||||
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
|
||||
{ LLM_ARCH_BITNET, "bitnet" },
|
||||
{ LLM_ARCH_T5, "t5" },
|
||||
{ LLM_ARCH_T5ENCODER, "t5encoder" },
|
||||
{ LLM_ARCH_JAIS, "jais" },
|
||||
{ LLM_ARCH_NEMOTRON, "nemotron" },
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_EXAONE4, "exaone4" },
|
||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
|
||||
{ LLM_ARCH_RWKV7, "rwkv7" },
|
||||
{ LLM_ARCH_ARWKV7, "arwkv7" },
|
||||
{ LLM_ARCH_GRANITE, "granite" },
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
{ LLM_ARCH_SOLAR, "solar" },
|
||||
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
||||
{ LLM_ARCH_PLM, "plm" },
|
||||
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
||||
{ LLM_ARCH_DOTS1, "dots1" },
|
||||
{ LLM_ARCH_ARCEE, "arcee" },
|
||||
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
|
||||
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
|
||||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
{ LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" },
|
||||
{ LLM_ARCH_SMOLLM3, "smollm3" },
|
||||
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
|
||||
{ LLM_ARCH_LFM2, "lfm2" },
|
||||
{ LLM_ARCH_DREAM, "dream" },
|
||||
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
|
||||
{ LLM_ARCH_LLADA, "llada" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -109,6 +130,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
|
||||
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
|
||||
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
|
||||
{ LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
|
||||
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
|
||||
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
|
||||
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
|
||||
@@ -166,6 +188,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
|
||||
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
|
||||
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
|
||||
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
|
||||
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
|
||||
|
||||
{ LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
|
||||
@@ -176,6 +199,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
|
||||
{ LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" },
|
||||
|
||||
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
|
||||
|
||||
{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
|
||||
|
||||
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
|
||||
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
|
||||
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
|
||||
@@ -194,13 +221,13 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
|
||||
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
|
||||
{ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
|
||||
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
|
||||
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
|
||||
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
|
||||
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
|
||||
@@ -244,6 +271,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ARCEE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LLAMA4,
|
||||
{
|
||||
@@ -450,6 +495,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
|
||||
{ LLM_TENSOR_POS_EMBD, "position_embd" },
|
||||
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
@@ -494,6 +540,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_NEO_BERT,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
|
||||
{ LLM_TENSOR_CLS, "cls" },
|
||||
{ LLM_TENSOR_CLS_OUT, "cls.output" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
{
|
||||
@@ -735,6 +796,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PLAMO2,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
{ LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
|
||||
{ LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
|
||||
{ LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_CODESHELL,
|
||||
{
|
||||
@@ -894,6 +985,42 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GEMMA3N,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
{ LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
|
||||
{ LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
|
||||
{ LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
|
||||
{ LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" },
|
||||
{ LLM_TENSOR_ALTUP_PROJ, "altup_proj" },
|
||||
{ LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" },
|
||||
{ LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" },
|
||||
{ LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" },
|
||||
{ LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" },
|
||||
{ LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" },
|
||||
{ LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" },
|
||||
{ LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" },
|
||||
{ LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" },
|
||||
{ LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" },
|
||||
{ LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" },
|
||||
{ LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_STARCODER2,
|
||||
{
|
||||
@@ -928,6 +1055,77 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_MAMBA2,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_JAMBA,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
|
||||
{ LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_FALCON_H1,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_XVERSE,
|
||||
{
|
||||
@@ -1198,6 +1396,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GLM4_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
// NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number)
|
||||
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
|
||||
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
|
||||
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
|
||||
{ LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" },
|
||||
{ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
|
||||
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_BITNET,
|
||||
{
|
||||
@@ -1321,6 +1553,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_EXAONE4,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
}
|
||||
},
|
||||
{
|
||||
LLM_ARCH_RWKV6,
|
||||
{
|
||||
@@ -1483,6 +1735,46 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GRANITE_HYBRID,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
// mamba(2) ssm layers
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
// attention layers
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
// dense FFN
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
// moe FFN
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
// shared expert
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -1570,6 +1862,231 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DOTS1,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
}
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ERNIE4_5_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_HUNYUAN_DENSE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_SMOLLM3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_OPENAI_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LFM2,
|
||||
{
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
|
||||
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
|
||||
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
}
|
||||
},
|
||||
{
|
||||
LLM_ARCH_SMALLTHINKER,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DREAM,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LLADA,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
@@ -1609,6 +2126,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}},
|
||||
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
@@ -1654,7 +2172,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
|
||||
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
|
||||
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
|
||||
{LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
@@ -1698,6 +2220,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
// altup / laurel (gemma 3n)
|
||||
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_LAUREL_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
// this tensor is loaded for T5, but never used
|
||||
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
||||
{LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
@@ -1717,13 +2256,30 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
|
||||
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
|
||||
// These tensors only exist in the last layer(s) and are treated as output tensors
|
||||
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
};
|
||||
|
||||
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
|
||||
|
||||
std::string LLM_KV::operator()(llm_kv kv) const {
|
||||
return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
|
||||
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
|
||||
if (suffix != nullptr) {
|
||||
name += ".";
|
||||
name += suffix;
|
||||
}
|
||||
|
||||
return name;
|
||||
}
|
||||
|
||||
std::string LLM_TN_IMPL::str() const {
|
||||
@@ -1762,3 +2318,40 @@ llm_arch llm_arch_from_string(const std::string & name) {
|
||||
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
|
||||
return LLM_TENSOR_INFOS.at(tensor);
|
||||
}
|
||||
|
||||
bool llm_arch_is_recurrent(const llm_arch & arch) {
|
||||
switch (arch) {
|
||||
case LLM_ARCH_MAMBA:
|
||||
case LLM_ARCH_MAMBA2:
|
||||
case LLM_ARCH_RWKV6:
|
||||
case LLM_ARCH_RWKV6QWEN2:
|
||||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_arch_is_hybrid(const llm_arch & arch) {
|
||||
switch (arch) {
|
||||
case LLM_ARCH_JAMBA:
|
||||
case LLM_ARCH_FALCON_H1:
|
||||
case LLM_ARCH_PLAMO2:
|
||||
case LLM_ARCH_GRANITE_HYBRID:
|
||||
case LLM_ARCH_LFM2:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_arch_is_diffusion(const llm_arch & arch) {
|
||||
switch (arch) {
|
||||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_LLADA:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
63
llama/llama.cpp/src/llama-arch.h
vendored
63
llama/llama.cpp/src/llama-arch.h
vendored
@@ -24,6 +24,7 @@ enum llm_arch {
|
||||
LLM_ARCH_BERT,
|
||||
LLM_ARCH_NOMIC_BERT,
|
||||
LLM_ARCH_NOMIC_BERT_MOE,
|
||||
LLM_ARCH_NEO_BERT,
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
LLM_ARCH_BLOOM,
|
||||
LLM_ARCH_STABLELM,
|
||||
@@ -37,6 +38,7 @@ enum llm_arch {
|
||||
LLM_ARCH_PHI3,
|
||||
LLM_ARCH_PHIMOE,
|
||||
LLM_ARCH_PLAMO,
|
||||
LLM_ARCH_PLAMO2,
|
||||
LLM_ARCH_CODESHELL,
|
||||
LLM_ARCH_ORION,
|
||||
LLM_ARCH_INTERNLM2,
|
||||
@@ -45,8 +47,12 @@ enum llm_arch {
|
||||
LLM_ARCH_GEMMA,
|
||||
LLM_ARCH_GEMMA2,
|
||||
LLM_ARCH_GEMMA3,
|
||||
LLM_ARCH_GEMMA3N,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_MAMBA2,
|
||||
LLM_ARCH_JAMBA,
|
||||
LLM_ARCH_FALCON_H1,
|
||||
LLM_ARCH_XVERSE,
|
||||
LLM_ARCH_COMMAND_R,
|
||||
LLM_ARCH_COHERE2,
|
||||
@@ -60,23 +66,38 @@ enum llm_arch {
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_CHATGLM,
|
||||
LLM_ARCH_GLM4,
|
||||
LLM_ARCH_GLM4_MOE,
|
||||
LLM_ARCH_BITNET,
|
||||
LLM_ARCH_T5,
|
||||
LLM_ARCH_T5ENCODER,
|
||||
LLM_ARCH_JAIS,
|
||||
LLM_ARCH_NEMOTRON,
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_EXAONE4,
|
||||
LLM_ARCH_RWKV6,
|
||||
LLM_ARCH_RWKV6QWEN2,
|
||||
LLM_ARCH_RWKV7,
|
||||
LLM_ARCH_ARWKV7,
|
||||
LLM_ARCH_GRANITE,
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_GRANITE_HYBRID,
|
||||
LLM_ARCH_CHAMELEON,
|
||||
LLM_ARCH_SOLAR,
|
||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
LLM_ARCH_PLM,
|
||||
LLM_ARCH_BAILINGMOE,
|
||||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_ARCEE,
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
LLM_ARCH_ERNIE4_5_MOE,
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
LLM_ARCH_HUNYUAN_DENSE,
|
||||
LLM_ARCH_SMOLLM3,
|
||||
LLM_ARCH_OPENAI_MOE,
|
||||
LLM_ARCH_LFM2,
|
||||
LLM_ARCH_DREAM,
|
||||
LLM_ARCH_SMALLTHINKER,
|
||||
LLM_ARCH_LLADA,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -113,6 +134,7 @@ enum llm_kv {
|
||||
LLM_KV_EXPERT_WEIGHTS_NORM,
|
||||
LLM_KV_EXPERT_GATING_FUNC,
|
||||
LLM_KV_MOE_EVERY_N_LAYERS,
|
||||
LLM_KV_NEXTN_PREDICT_LAYERS,
|
||||
LLM_KV_POOLING_TYPE,
|
||||
LLM_KV_LOGIT_SCALE,
|
||||
LLM_KV_DECODER_START_TOKEN_ID,
|
||||
@@ -170,6 +192,7 @@ enum llm_kv {
|
||||
LLM_KV_SSM_CONV_KERNEL,
|
||||
LLM_KV_SSM_STATE_SIZE,
|
||||
LLM_KV_SSM_TIME_STEP_RANK,
|
||||
LLM_KV_SSM_GROUP_COUNT,
|
||||
LLM_KV_SSM_DT_B_C_RMS,
|
||||
|
||||
LLM_KV_WKV_HEAD_SIZE,
|
||||
@@ -192,13 +215,13 @@ enum llm_kv {
|
||||
LLM_KV_TOKENIZER_MASK_ID,
|
||||
LLM_KV_TOKENIZER_ADD_BOS,
|
||||
LLM_KV_TOKENIZER_ADD_EOS,
|
||||
LLM_KV_TOKENIZER_ADD_SEP,
|
||||
LLM_KV_TOKENIZER_ADD_PREFIX,
|
||||
LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
|
||||
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
|
||||
LLM_KV_TOKENIZER_HF_JSON,
|
||||
LLM_KV_TOKENIZER_RWKV,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
|
||||
LLM_KV_TOKENIZER_FIM_PRE_ID,
|
||||
LLM_KV_TOKENIZER_FIM_SUF_ID,
|
||||
LLM_KV_TOKENIZER_FIM_MID_ID,
|
||||
@@ -215,6 +238,10 @@ enum llm_kv {
|
||||
LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
|
||||
LLM_KV_CONVNEXT_BLOCK_COUNT,
|
||||
|
||||
LLM_KV_CLASSIFIER_OUTPUT_LABELS,
|
||||
|
||||
LLM_KV_SHORTCONV_L_CACHE,
|
||||
|
||||
// deprecated:
|
||||
LLM_KV_TOKENIZER_PREFIX_ID,
|
||||
LLM_KV_TOKENIZER_SUFFIX_ID,
|
||||
@@ -241,6 +268,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_OUT_NORM,
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_ATTN_ROT_EMBD,
|
||||
LLM_TENSOR_ATTN_SINKS,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
@@ -265,12 +293,32 @@ enum llm_tensor {
|
||||
LLM_TENSOR_LAYER_OUT_NORM,
|
||||
LLM_TENSOR_POST_ATTN_NORM,
|
||||
LLM_TENSOR_POST_MLP_NORM,
|
||||
LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
|
||||
LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
|
||||
LLM_TENSOR_PER_LAYER_INP_GATE, // gemma3n
|
||||
LLM_TENSOR_PER_LAYER_PROJ, // gemma3n
|
||||
LLM_TENSOR_PER_LAYER_PROJ_NORM, // gemma3n
|
||||
LLM_TENSOR_PER_LAYER_POST_NORM, // gemma3n
|
||||
LLM_TENSOR_ALTUP_PROJ, // gemma3n
|
||||
LLM_TENSOR_ALTUP_UNEMBD_PROJ, // gemma3n
|
||||
LLM_TENSOR_ALTUP_CORRECT_COEF, // gemma3n
|
||||
LLM_TENSOR_ALTUP_CORRECT_SCALE, // gemma3n
|
||||
LLM_TENSOR_ALTUP_PREDICT_COEF, // gemma3n
|
||||
LLM_TENSOR_ALTUP_ROUTER, // gemma3n
|
||||
LLM_TENSOR_ALTUP_ROUTER_NORM, // gemma3n
|
||||
LLM_TENSOR_LAUREL_L, // gemma3n
|
||||
LLM_TENSOR_LAUREL_R, // gemma3n
|
||||
LLM_TENSOR_LAUREL_POST_NORM, // gemma3n
|
||||
LLM_TENSOR_SSM_IN,
|
||||
LLM_TENSOR_SSM_CONV1D,
|
||||
LLM_TENSOR_SSM_X,
|
||||
LLM_TENSOR_SSM_DT,
|
||||
LLM_TENSOR_SSM_DT_NORM,
|
||||
LLM_TENSOR_SSM_A,
|
||||
LLM_TENSOR_SSM_B_NORM,
|
||||
LLM_TENSOR_SSM_C_NORM,
|
||||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_NORM,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_TIME_MIX_W0,
|
||||
LLM_TENSOR_TIME_MIX_W1,
|
||||
@@ -365,6 +413,15 @@ enum llm_tensor {
|
||||
LLM_TENSOR_POS_NET_ATTN_K,
|
||||
LLM_TENSOR_POS_NET_ATTN_V,
|
||||
LLM_TENSOR_POS_NET_ATTN_OUT,
|
||||
LLM_TENSOR_SHORTCONV_CONV,
|
||||
LLM_TENSOR_SHORTCONV_INPROJ,
|
||||
LLM_TENSOR_SHORTCONV_OUTPROJ,
|
||||
LLM_TENSOR_NEXTN_EH_PROJ,
|
||||
LLM_TENSOR_NEXTN_EMBED_TOKENS,
|
||||
LLM_TENSOR_NEXTN_ENORM,
|
||||
LLM_TENSOR_NEXTN_HNORM,
|
||||
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
|
||||
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
|
||||
};
|
||||
|
||||
enum llm_tensor_layer {
|
||||
@@ -438,3 +495,7 @@ const char * llm_arch_name(llm_arch arch);
|
||||
llm_arch llm_arch_from_string(const std::string & name);
|
||||
|
||||
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
|
||||
|
||||
bool llm_arch_is_recurrent(const llm_arch & arch);
|
||||
bool llm_arch_is_hybrid (const llm_arch & arch);
|
||||
bool llm_arch_is_diffusion(const llm_arch & arch);
|
||||
|
||||
1089
llama/llama.cpp/src/llama-batch.cpp
vendored
1089
llama/llama.cpp/src/llama-batch.cpp
vendored
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user