Merge branch 'ollama:main' into main

This commit is contained in:
likelovewant
2025-02-26 13:46:00 +08:00
committed by GitHub
72 changed files with 6573 additions and 411 deletions

View File

@@ -111,13 +111,13 @@ jobs:
- os: windows
arch: amd64
preset: 'CUDA 12'
install: https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_551.61_windows.exe
cuda-version: '12.4'
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
cuda-version: '12.8'
- os: windows
arch: amd64
preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe
rocm-version: '6.1'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -160,6 +160,10 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:

View File

@@ -81,7 +81,7 @@ jobs:
install: https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010'
runs-on: windows
steps:
@@ -140,6 +140,13 @@ jobs:
env:
CMAKE_GENERATOR: Ninja
go_mod_tidy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: check that 'go mod tidy' is clean
run: go mod tidy --diff || (echo "Please run 'go mod tidy'." && exit 1)
test:
strategy:
matrix:
@@ -147,15 +154,83 @@ jobs:
runs-on: ${{ matrix.os }}
env:
CGO_ENABLED: '1'
GOEXPERIMENT: 'synctest'
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
- name: checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
- name: cache restore
uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
# Note: unlike the other setups, this is only grabbing the mod download
# cache, rather than the whole mod directory, as the download cache
# contains zips that can be unpacked in parallel faster than they can be
# fetched and extracted by tar
path: |
~/.cache/go-build
~/go/pkg/mod/cache
~\AppData\Local\go-build
# NOTE: The -3- here should be incremented when the scheme of data to be
# cached changes (e.g. path above changes).
key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
restore-keys: |
${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}
${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-
- name: Setup Go
uses: actions/setup-go@v5
with:
# The caching strategy of setup-go is less than ideal, and wastes
# time by not saving artifacts due to small failures like the linter
# complaining, etc. This means subsequent have to rebuild their world
# again until all checks pass. For instance, if you mispell a word,
# you're punished until you fix it. This is more hostile than
# helpful.
cache: false
go-version-file: go.mod
# TODO(bmizerany): replace this heavy tool with just the
# tools/checks/binaries we want and then make them all run in parallel
# across jobs, not on a single tiny vm on Github Actions.
- uses: golangci/golangci-lint-action@v6
with:
args: --timeout 10m0s -v
- run: go test ./...
- name: go test
# Do not skip tests in the face of linter errors, or 'go mod tidy'
# checks, which are secondary to the tests. Tests trump linters.
if: always()
run: go test -count=1 -benchtime=1x ./...
# It is tempting to run this in a platform independent way, but the past
# shows this codebase will see introductions of platform specific code
# generation, and so we need to check this per platform to ensure we
# don't abuse go generate on specific platforms.
- name: check that 'go generate' is clean
run: |
go generate ./...
git diff --name-only --exit-code || (echo "Please run 'go generate ./...'." && exit 1)
- name: cache save
# Always save the cache, even if the job fails. The artifacts produced
# during the building of test binaries are not all for naught. They can
# be used to speed up subsequent runs.
if: always()
uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
# Note: unlike the other setups, this is only grabbing the mod download
# cache, rather than the whole mod directory, as the download cache
# contains zips that can be unpacked in parallel faster than they can be
# fetched and extracted by tar
path: |
~/.cache/go-build
~/go/pkg/mod/cache
~\AppData\Local\go-build
# NOTE: The -3- here should be incremented when the scheme of data to be
# cached changes (e.g. path above changes).
key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
patches:
runs-on: ubuntu-latest

2
.gitignore vendored
View File

@@ -7,7 +7,6 @@
0
dist
build
ollama
.cache
*.exe
.idea
@@ -16,3 +15,4 @@ test_data
__debug_bin*
llama/build
llama/vendor
/ollama

View File

@@ -6,8 +6,6 @@ linters:
- bidichk
- bodyclose
- containedctx
- contextcheck
- errcheck
- gocheckcompilerdirectives
- gofmt
- gofumpt
@@ -23,10 +21,11 @@ linters:
- staticcheck
- tenv
- unconvert
- unused
- usestdlibvars
- wastedassign
- whitespace
disable:
- usestdlibvars
- errcheck
linters-settings:
staticcheck:
checks:
@@ -39,5 +38,4 @@ severity:
- gofmt
- goimports
- intrange
- usestdlibvars
severity: info

View File

@@ -28,7 +28,7 @@
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "60;61;62;70;72;75;80;86;87;89;90;90a"
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;62;70;72;75;80;86;87;89;90;90a"
}
},
{

View File

@@ -2,22 +2,24 @@
ARG FLAVOR=${TARGETARCH}
ARG ROCMVERSION=6.1.2
ARG ROCMVERSION=6.3.3
ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.2.0
ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2
FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCMVERSION}-complete AS base-amd64
RUN sed -i -e 's/mirror.centos.org/vault.centos.org/g' -e 's/^#.*baseurl=http/baseurl=http/g' -e 's/^mirrorlist=http/#mirrorlist=http/g' /etc/yum.repos.d/*.repo \
&& yum install -y yum-utils devtoolset-10-gcc devtoolset-10-gcc-c++ \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo \
&& curl -s -L https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz | tar -Jx -C /usr/local/bin --strip-components 1
ENV PATH=/opt/rh/devtoolset-10/root/usr/bin:/opt/rh/devtoolset-11/root/usr/bin:$PATH
# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
FROM --platform=linux/arm64 rockylinux:8 AS base-arm64
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache
RUN yum install -y yum-utils epel-release \
&& yum install -y clang ccache \
&& dnf install -y clang ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
ENV CC=clang CXX=clang++
@@ -29,9 +31,8 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ENV LDFLAGS=-s
FROM base AS cpu
# amd64 uses gcc which requires devtoolset-11 for AVX extensions while arm64 uses clang
RUN if [ "$(uname -m)" = "x86_64" ]; then yum install -y devtoolset-11-gcc devtoolset-11-gcc-c++; fi
ENV PATH=/opt/rh/devtoolset-11/root/usr/bin:$PATH
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \
&& cmake --build --parallel --preset 'CPU' \
@@ -39,7 +40,7 @@ RUN --mount=type=cache,target=/root/.ccache \
FROM base AS cuda-11
ARG CUDA11VERSION=11.3
RUN yum install -y cuda-toolkit-${CUDA11VERSION//./-}
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
@@ -47,8 +48,8 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-12
ARG CUDA12VERSION=12.4
RUN yum install -y cuda-toolkit-${CUDA12VERSION//./-}
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \
@@ -56,6 +57,7 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \
&& cmake --build --parallel --preset 'ROCm 6' \
@@ -104,7 +106,7 @@ COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6
FROM --platform=linux/arm64 scratch AS rocm
FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
FROM ${FLAVOR} AS archive

View File

@@ -404,6 +404,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
### Cloud

View File

@@ -132,7 +132,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf *bytes.Buffer
var buf io.Reader
if data != nil {
bts, err := json.Marshal(data)
if err != nil {

View File

@@ -1,6 +1,13 @@
package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
@@ -43,3 +50,206 @@ func TestClientFromEnvironment(t *testing.T) {
})
}
}
// testError represents an internal error type with status code and message
// this is used since the error response from the server is not a standard error struct
type testError struct {
message string
statusCode int
}
func (e testError) Error() string {
return e.message
}
func TestClientStream(t *testing.T) {
testCases := []struct {
name string
responses []any
wantErr string
}{
{
name: "immediate error response",
responses: []any{
testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
},
wantErr: "test error message",
},
{
name: "error after successful chunks, ok response",
responses: []any{
ChatResponse{Message: Message{Content: "partial response 1"}},
ChatResponse{Message: Message{Content: "partial response 2"}},
testError{
message: "mid-stream error",
statusCode: http.StatusOK,
},
},
wantErr: "mid-stream error",
},
{
name: "successful stream completion",
responses: []any{
ChatResponse{Message: Message{Content: "chunk 1"}},
ChatResponse{Message: Message{Content: "chunk 2"}},
ChatResponse{
Message: Message{Content: "final chunk"},
Done: true,
DoneReason: "stop",
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
t.Fatal("expected http.Flusher")
}
w.Header().Set("Content-Type", "application/x-ndjson")
for _, resp := range tc.responses {
if errResp, ok := resp.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
flusher.Flush()
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var receivedChunks []ChatResponse
err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
var resp ChatResponse
if err := json.Unmarshal(chunk, &resp); err != nil {
return fmt.Errorf("failed to unmarshal chunk: %w", err)
}
receivedChunks = append(receivedChunks, resp)
return nil
})
if tc.wantErr != "" {
if err == nil {
t.Fatal("expected error but got nil")
}
if !strings.Contains(err.Error(), tc.wantErr) {
t.Errorf("expected error containing %q, got %v", tc.wantErr, err)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
}{
{
name: "immediate error response",
response: testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
},
{
name: "server error response",
response: testError{
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
},
{
name: "successful response",
response: struct {
ID string `json:"id"`
Success bool `json:"success"`
}{
ID: "msg_123",
Success: true,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(tc.response); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var resp struct {
ID string `json:"id"`
Success bool `json:"success"`
}
err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
if tc.wantErr != "" {
if err == nil {
t.Fatalf("got nil, want error %q", tc.wantErr)
}
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
return
}
if err != nil {
t.Fatalf("got error %q, want nil", err)
}
if expectedResp, ok := tc.response.(struct {
ID string `json:"id"`
Success bool `json:"success"`
}); ok {
if resp.ID != expectedResp.ID {
t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
}
if resp.Success != expectedResp.Success {
t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
}
}
})
}
}

View File

@@ -10,6 +10,8 @@ import (
"strconv"
"strings"
"time"
"github.com/ollama/ollama/envconfig"
)
// StatusError is an error with an HTTP status code and message.
@@ -609,7 +611,7 @@ func DefaultOptions() Options {
Runner: Runner{
// options set when the model is loaded
NumCtx: 2048,
NumCtx: int(envconfig.ContextLength()),
NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumThread: 0, // let the runtime decide

View File

@@ -10,6 +10,7 @@ import (
"os"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/spf13/cobra"
@@ -490,6 +491,96 @@ func TestPushHandler(t *testing.T) {
}
}
func TestListHandler(t *testing.T) {
tests := []struct {
name string
args []string
serverResponse []api.ListModelResponse
expectedError string
expectedOutput string
}{
{
name: "list all models",
args: []string{},
serverResponse: []api.ListModelResponse{
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)},
},
expectedOutput: "NAME ID SIZE MODIFIED \n" +
"model1 sha256:abc12 1.0 KB 24 hours ago \n" +
"model2 sha256:def45 2.0 KB 2 days ago \n",
},
{
name: "filter models by prefix",
args: []string{"model1"},
serverResponse: []api.ListModelResponse{
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)},
},
expectedOutput: "NAME ID SIZE MODIFIED \n" +
"model1 sha256:abc12 1.0 KB 24 hours ago \n",
},
{
name: "server error",
args: []string{},
expectedError: "server error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/tags" || r.Method != http.MethodGet {
t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path)
http.Error(w, "not found", http.StatusNotFound)
return
}
if tt.expectedError != "" {
http.Error(w, tt.expectedError, http.StatusInternalServerError)
return
}
response := api.ListResponse{Models: tt.serverResponse}
if err := json.NewEncoder(w).Encode(response); err != nil {
t.Fatal(err)
}
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.SetContext(context.TODO())
// Capture stdout
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
err := ListHandler(cmd, tt.args)
// Restore stdout and get output
w.Close()
os.Stdout = oldStdout
output, _ := io.ReadAll(r)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if got := string(output); got != tt.expectedOutput {
t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
}
})
}
}
func TestCreateHandler(t *testing.T) {
tests := []struct {
name string

View File

@@ -57,7 +57,8 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
}
}
if gpuInfo.computeMajor < 6 || gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
return "v11"
}
return "v12"

View File

@@ -41,20 +41,11 @@ Install prerequisites:
- [CMake](https://cmake.org/download/)
- [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) including the Native Desktop Workload
- (Optional) AMD GPU support
- [ROCm](https://rocm.github.io/install.html)
- [ROCm](https://rocm.docs.amd.com/en/latest/)
- [Ninja](https://github.com/ninja-build/ninja/releases)
- (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
> [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake.
> [!IMPORTANT]
> ROCm is not compatible with Visual Studio CMake generators. Use `-GNinja` when configuring the project.
> [!IMPORTANT]
> CUDA is only compatible with Visual Studio CMake generators.
Then, configure and build the project:
```shell
@@ -62,6 +53,14 @@ cmake -B build
cmake --build build --config Release
```
> [!IMPORTANT]
> Building for ROCm requires additional flags:
> ```
> cmake -B build -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
> cmake --build build --config Release
> ```
Lastly, run Ollama:
```shell

View File

@@ -73,6 +73,10 @@ curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
## Linux docker
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
## NVIDIA GPU Discovery
When Ollama starts up, it takes inventory of the GPUs present in the system to determine compatibility and how much VRAM is available. Sometimes this discovery can fail to find your GPUs. In general, running the latest driver will yield the best results.
@@ -100,8 +104,6 @@ On linux, AMD GPU access typically requires `video` and/or `render` group member
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
- `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems
- `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported

View File

@@ -53,8 +53,8 @@ func Host() *url.URL {
}
}
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) {
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
func AllowedOrigins() (origins []string) {
if s := Var("OLLAMA_ORIGINS"); s != "" {
origins = strings.Split(s, ",")
}
@@ -167,6 +167,8 @@ var (
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
// Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 2048)
)
func String(s string) func() string {
@@ -249,9 +251,10 @@ func AsMap() map[string]EnvVar {
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 2048)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
// Informational

View File

@@ -134,7 +134,7 @@ func TestOrigins(t *testing.T) {
t.Run(tt.value, func(t *testing.T) {
t.Setenv("OLLAMA_ORIGINS", tt.value)
if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" {
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
}
})
@@ -272,3 +272,19 @@ func TestVar(t *testing.T) {
})
}
}
func TestContextLength(t *testing.T) {
cases := map[string]uint{
"": 2048,
"4096": 4096,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", k)
if i := ContextLength(); i != v {
t.Errorf("%s: expected %d, got %d", k, v, i)
}
})
}
}

View File

@@ -207,11 +207,26 @@ func (t Tensor) block() (n int) {
func (t Tensor) blockSize() uint64 {
switch t.Kind {
case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
case
0, // F32
1, // F16
24, // I8
25, // I16
26, // I32
27, // I64
28, // F64
30: // BF16
return 1
case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
case
2, // Q4_0
3, // Q4_1
6, // Q5_0
7, // Q5_1
8, // Q8_0
9, // Q8_1
20: // IQ4_NL
return 32
default: // All others
default:
return 256
}
}
@@ -235,7 +250,7 @@ func (t Tensor) typeSize() uint64 {
case 8: // Q8_0
return 2 + blockSize
case 9: // Q8_1
return 4 + 4 + blockSize
return 2 + 2 + blockSize
case 10: // Q2_K
return blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
@@ -247,7 +262,7 @@ func (t Tensor) typeSize() uint64 {
case 14: // Q6_K
return blockSize/2 + blockSize/4 + blockSize/16 + 2
case 15: // Q8_K
return 2 + blockSize + 2*blockSize/16
return 4 + blockSize + 2*blockSize/16
case 16: // IQ2_XXS
return 2 + 2*blockSize/8
case 17: // IQ2_XS
@@ -276,6 +291,8 @@ func (t Tensor) typeSize() uint64 {
return 8
case 29: // IQ1_M
return blockSize/8 + blockSize/16 + blockSize/32
case 30: // BF16
return 2
default:
return 0
}

View File

@@ -3,6 +3,7 @@ package ggml
import (
"maps"
"slices"
"strconv"
"strings"
"testing"
@@ -157,3 +158,55 @@ func TestTensorLayers(t *testing.T) {
})
}
}
// ref: https://github.com/ggml-org/llama.cpp/blob/a82c9e7c23ef6db48cebfa194dc9cebbc4ac3552/ggml/src/ggml.c#L572
func TestTensorTypes(t *testing.T) {
cases := []struct {
kind uint32
blockSize uint64
typeSize uint64
}{
{0, 1, 4},
{1, 1, 2},
{2, 32, 18},
{3, 32, 20},
{6, 32, 22},
{7, 32, 24},
{8, 32, 34},
{9, 32, 36},
{10, 256, 84},
{11, 256, 110},
{12, 256, 144},
{13, 256, 176},
{14, 256, 210},
{15, 256, 292},
{16, 256, 66},
{17, 256, 74},
{18, 256, 98},
{19, 256, 50},
{20, 32, 18},
{21, 256, 110},
{22, 256, 82},
{23, 256, 136},
{24, 1, 1},
{25, 1, 2},
{26, 1, 4},
{27, 1, 8},
{28, 1, 8},
{29, 256, 56},
{30, 1, 2},
}
for _, tt := range cases {
t.Run(strconv.Itoa(int(tt.kind)), func(t *testing.T) {
tensor := Tensor{Kind: tt.kind}
if tensor.blockSize() != tt.blockSize {
t.Errorf("unexpected block size: got=%d want=%d", tensor.blockSize(), tt.blockSize)
}
if tensor.typeSize() != tt.typeSize {
t.Errorf("unexpected type size: got=%d want=%d", tensor.typeSize(), tt.typeSize)
}
})
}
}

2
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/ollama/ollama
go 1.23.4
go 1.24
require (
github.com/containerd/console v1.0.3

View File

@@ -4,17 +4,23 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500
Subject: [PATCH] use std::filesystem::path instead of wstring
---
ggml/src/ggml-backend-reg.cpp | 116 ++++++++++++----------------------
1 file changed, 40 insertions(+), 76 deletions(-)
ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++--------------------
1 file changed, 58 insertions(+), 86 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 84b21dd8..de78feae 100644
index 84b21dd8..e35a6936 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -72,16 +72,6 @@
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
@@ -66,26 +66,6 @@
#include "ggml-kompute.h"
#endif
-// disable C++17 deprecation warning for std::codecvt_utf8
-#if defined(__clang__)
-# pragma clang diagnostic push
-# pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#endif
-
-static std::wstring utf8_to_utf16(const std::string & str) {
- std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
- return converter.from_bytes(str);
@@ -25,10 +31,14 @@ index 84b21dd8..de78feae 100644
- return converter.to_bytes(str);
-}
-
#if defined(__clang__)
# pragma clang diagnostic pop
#endif
@@ -96,12 +86,12 @@ struct dl_handle_deleter {
-#if defined(__clang__)
-# pragma clang diagnostic pop
-#endif
-
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
@@ -96,7 +76,7 @@ struct dl_handle_deleter {
}
};
@@ -37,24 +47,44 @@ index 84b21dd8..de78feae 100644
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
- HMODULE handle = LoadLibraryW(path.c_str());
+ HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode);
@@ -129,8 +119,8 @@ struct dl_handle_deleter {
@@ -129,8 +109,8 @@ struct dl_handle_deleter {
}
};
-static void * dl_load_library(const std::wstring & path) {
- dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
+static void * dl_load_library(const std::filesystem::path & path) {
+ dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
+ dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
@@ -222,11 +212,11 @@ struct ggml_backend_registry {
@@ -141,6 +121,25 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
#endif
+static std::string path_to_string(const std::filesystem::path & path)
+{
+#ifdef _WIN32
+ const std::wstring wstr = path.wstring();
+ const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr);
+ if (size_needed <= 0) {
+ return std::string();
+ }
+
+ // size_needed includes the null terminator
+ std::string str(size_needed - 1, '\0');
+ WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr);
+ return str;
+#else
+ return path.string();
+#endif
+}
+
+
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
struct ggml_backend_reg_entry {
@@ -222,11 +221,11 @@ struct ggml_backend_registry {
);
}
@@ -64,49 +94,49 @@ index 84b21dd8..de78feae 100644
if (!handle) {
if (!silent) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path.string().c_str());
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -234,7 +224,7 @@ struct ggml_backend_registry {
@@ -234,7 +233,7 @@ struct ggml_backend_registry {
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (score_fn && score_fn() == 0) {
if (!silent) {
- GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
+ GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path.string().c_str());
+ GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -242,7 +232,7 @@ struct ggml_backend_registry {
@@ -242,7 +241,7 @@ struct ggml_backend_registry {
auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
if (!backend_init_fn) {
if (!silent) {
- GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
+ GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path.string().c_str());
+ GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -251,16 +241,16 @@ struct ggml_backend_registry {
@@ -251,16 +250,16 @@ struct ggml_backend_registry {
if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
if (!silent) {
if (!reg) {
- GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
+ GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path.string().c_str());
+ GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str());
} else {
GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
- __func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
+ __func__, path.string().c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
+ __func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
}
}
return nullptr;
}
- GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
+ GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path.string().c_str());
+ GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str());
register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
@@ -396,14 +386,14 @@ ggml_backend_t ggml_backend_init_best(void) {
@@ -396,14 +395,14 @@ ggml_backend_t ggml_backend_init_best(void) {
// Dynamic loading
ggml_backend_reg_t ggml_backend_load(const char * path) {
@@ -123,7 +153,7 @@ index 84b21dd8..de78feae 100644
#if defined(__APPLE__)
// get executable path
std::vector<char> path;
@@ -415,15 +405,9 @@ static std::wstring get_executable_path() {
@@ -415,15 +414,9 @@ static std::wstring get_executable_path() {
}
path.resize(size);
}
@@ -141,7 +171,7 @@ index 84b21dd8..de78feae 100644
std::vector<char> path(1024);
while (true) {
// get executable path
@@ -436,76 +420,56 @@ static std::wstring get_executable_path() {
@@ -436,76 +429,55 @@ static std::wstring get_executable_path() {
break;
}
if (len < (ssize_t) path.size()) {
@@ -179,11 +209,11 @@ index 84b21dd8..de78feae 100644
-static std::wstring backend_filename_prefix() {
-#ifdef _WIN32
- return L"ggml-";
+ return std::filesystem::path(path.data()).parent_path();
#else
-#else
- return L"libggml-";
+ return {};
+ return std::filesystem::path(path.data()).parent_path();
#endif
+ return {};
}
-static std::wstring backend_filename_suffix() {
@@ -234,7 +264,7 @@ index 84b21dd8..de78feae 100644
for (const auto & search_path : search_paths) {
if (!fs::exists(search_path)) {
continue;
@@ -514,31 +478,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
for (const auto & entry : dir_it) {
try {
if (entry.is_regular_file()) {
@@ -247,20 +277,20 @@ index 84b21dd8..de78feae 100644
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, entry.path().string().c_str());
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) {
- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, entry.path().string().c_str());
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
int s = score_fn();
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, entry.path().string().c_str(), s);
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
if (s > best_score) {
best_score = s;
- best_path = entry.path().wstring();
@@ -270,11 +300,11 @@ index 84b21dd8..de78feae 100644
}
} catch (const std::exception & e) {
- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, entry.path().string().c_str(), e.what());
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
}
}
}
@@ -546,7 +510,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
if (best_score == 0) {
// try to load the base backend
for (const auto & search_path : search_paths) {

View File

@@ -0,0 +1,24 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Tue, 18 Feb 2025 14:47:21 -0800
Subject: [PATCH] remove amx
---
ggml/src/CMakeLists.txt | 4 ----
1 file changed, 4 deletions(-)
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 72b488dd..50828717 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -293,10 +293,6 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
- if (NOT MSVC)
- # MSVC doesn't support AMX
- ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
- endif()
else ()
ggml_add_cpu_backend_variant_impl("")
endif()

View File

@@ -1 +0,0 @@
LLAMACPP_BASE_COMMIT=46e3556e01b824e52395fb050b29804b6cff2a7c

View File

@@ -17,6 +17,7 @@ import (
func TestEstimateGPULayers(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", "1")
t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
t.Setenv("OLLAMA_CONTEXT_LENGTH", "2048")
modelName := "dummy"
f, err := os.CreateTemp(t.TempDir(), modelName)

View File

@@ -26,9 +26,24 @@ type Backend interface {
SystemInfo() string
}
var backends = make(map[string]func(*os.File) (Backend, error))
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
// MainGPU is the index of the primary GPU to use
MainGPU int
// NumGPULayers is the number of layers to offload to GPUs
NumGPULayers int
// TensorSplit is the fraction of the model to offload to each GPU
TensorSplit []float32
}
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@@ -36,9 +51,9 @@ func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
backends[name] = f
}
func NewBackend(f *os.File) (Backend, error) {
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(f)
return backend(f, params)
}
return nil, fmt.Errorf("unsupported backend")
@@ -96,6 +111,26 @@ type Tensor interface {
Copy(ctx Context, t2 Tensor) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}
type number interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |

View File

@@ -82,9 +82,11 @@ type Backend struct {
meta *fs.GGML
cpus, gpus []Context
tensors map[string]*Context
sched *C.struct_ggml_backend_sched
}
func New(r *os.File) (ml.Backend, error) {
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1)
if err != nil {
return nil, err
@@ -182,10 +184,24 @@ func New(r *os.File) (ml.Backend, error) {
return nil, err
}
backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus))
bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus))
for i, c := range append(gpus, cpus...) {
backends[i] = c.backend
bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
}
return &Backend{
meta: meta,
cpus: cpus,
gpus: gpus,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
C.int(len(backends)),
C.size_t(max(8192, len(meta.Tensors().Items())*5)),
true,
),
}, nil
}
@@ -219,31 +235,23 @@ func (b *Backend) NewContext() ml.Context {
})
backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
bufts := make([]*C.struct_ggml_backend_buffer_type, len(b.gpus)+len(b.cpus))
for i, c := range append(b.gpus, b.cpus...) {
backends[i] = c.backend
bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
}
return &Context{
b: b,
ctx: c,
backend: backends[0],
nodes: nodes,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
C.int(len(backends)),
C.size_t(nodes),
true,
),
}
}
type Context struct {
b *Backend
ctx *C.struct_ggml_context
backend *C.struct_ggml_backend
sched *C.struct_ggml_backend_sched
graph *C.struct_ggml_cgraph
nodes int
}
@@ -257,12 +265,13 @@ func (c *Context) Forward(t ml.Tensor) {
}
func (c *Context) Compute(tensors ...ml.Tensor) {
C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
C.ggml_backend_sched_reset(c.b.sched)
needSync := true
sync := func() {
if needSync {
C.ggml_backend_sched_synchronize(c.sched)
C.ggml_backend_sched_synchronize(c.b.sched)
needSync = false
}
}
@@ -350,7 +359,6 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
func (c *Context) Close() {
if c != nil {
C.ggml_backend_sched_free(c.sched)
C.ggml_free(c.ctx)
}
}
@@ -477,7 +485,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
}
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
}
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -643,6 +651,21 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {
kqMask = mask.(*Tensor).t
}
kq := key.MulmatFullPrec(ctx, t)
kq = &Tensor{
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
func (b *Backend) SystemInfo() string {
var compiler string
switch C.get_compiler() {

View File

@@ -293,10 +293,6 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
if (NOT MSVC)
# MSVC doesn't support AMX
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
endif()
else ()
ggml_add_cpu_backend_variant_impl("")
endif()

View File

@@ -66,16 +66,6 @@
#include "ggml-kompute.h"
#endif
// disable C++17 deprecation warning for std::codecvt_utf8
#if defined(__clang__)
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#endif
#if defined(__clang__)
# pragma clang diagnostic pop
#endif
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
@@ -91,7 +81,7 @@ static dl_handle * dl_load_library(const std::filesystem::path & path) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
HMODULE handle = LoadLibraryW(path.c_str());
SetErrorMode(old_mode);
@@ -120,7 +110,7 @@ struct dl_handle_deleter {
};
static void * dl_load_library(const std::filesystem::path & path) {
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
@@ -131,6 +121,25 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
#endif
static std::string path_to_string(const std::filesystem::path & path)
{
#ifdef _WIN32
const std::wstring wstr = path.wstring();
const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr);
if (size_needed <= 0) {
return std::string();
}
// size_needed includes the null terminator
std::string str(size_needed - 1, '\0');
WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr);
return str;
#else
return path.string();
#endif
}
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
struct ggml_backend_reg_entry {
@@ -216,7 +225,7 @@ struct ggml_backend_registry {
dl_handle_ptr handle { dl_load_library(path) };
if (!handle) {
if (!silent) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path.string().c_str());
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -224,7 +233,7 @@ struct ggml_backend_registry {
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (score_fn && score_fn() == 0) {
if (!silent) {
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path.string().c_str());
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -232,7 +241,7 @@ struct ggml_backend_registry {
auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
if (!backend_init_fn) {
if (!silent) {
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path.string().c_str());
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str());
}
return nullptr;
}
@@ -241,16 +250,16 @@ struct ggml_backend_registry {
if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
if (!silent) {
if (!reg) {
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path.string().c_str());
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str());
} else {
GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
__func__, path.string().c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
__func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
}
}
return nullptr;
}
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path.string().c_str());
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str());
register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
@@ -432,9 +441,8 @@ static std::filesystem::path get_executable_path() {
}
return std::filesystem::path(path.data()).parent_path();
#else
return {};
#endif
return {};
}
static std::string backend_filename_prefix() {
@@ -483,18 +491,18 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, entry.path().string().c_str());
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) {
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, entry.path().string().c_str());
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
int s = score_fn();
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, entry.path().string().c_str(), s);
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
if (s > best_score) {
best_score = s;
best_path = entry.path();
@@ -502,7 +510,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
}
}
} catch (const std::exception & e) {
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, entry.path().string().c_str(), e.what());
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
}
}
}

View File

@@ -1,4 +1,4 @@
// Code generated Fri Jan 10 13:05:45 PST 2025. DO NOT EDIT.
// Code generated by go generate. DO NOT EDIT.
#define GGML_COMMON_DECL_METAL
#define GGML_COMMON_IMPL_METAL
#if defined(GGML_METAL_EMBED_LIBRARY)

View File

@@ -2,7 +2,7 @@
package metal
//go:generate sh -c "{ echo // Code generated $(date). DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal"
//go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal"
// #cgo CPPFLAGS: -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include
// #cgo LDFLAGS: -framework Metal -framework MetalKit

59
ml/nn/attention.go Normal file
View File

@@ -0,0 +1,59 @@
package nn
import (
"fmt"
"github.com/ollama/ollama/ml"
)
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
// - mask: Optional attention mask that is added to the attention score. If
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}
if mask != nil && query.Dim(1) != mask.Dim(1) {
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
}
if key.Dim(1) != value.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
}
if mask != nil && key.Dim(1) != mask.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
} else {
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}

View File

@@ -70,14 +70,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string) (Model, error) {
func New(modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(r)
b, err := ml.NewBackend(r, params)
if err != nil {
return nil, err
}

View File

@@ -86,13 +86,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.MulmatFullPrec(ctx, q)
kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)
@@ -120,11 +115,19 @@ type Layer struct {
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
@@ -144,22 +147,26 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return hiddenState.Rows(ctx, outputs), nil
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {

View File

@@ -93,15 +93,13 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
return nil, err
}
// TODO: attention mask, cross attention mask
hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return hiddenState.Rows(ctx, outputs), nil
// TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {

View File

@@ -38,13 +38,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
@@ -74,11 +69,19 @@ type TextSelfAttentionDecoderLayer struct {
MLP *TextMLP
}
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
@@ -104,7 +107,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
var key, value ml.Tensor
var key, value, mask ml.Tensor
if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
@@ -117,19 +120,15 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
cache.Put(ctx, key, value)
} else {
key, value, _ = cache.Get(ctx)
key, value, mask = cache.Get(ctx)
}
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.Mulmat(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention)
@@ -145,7 +144,7 @@ type TextCrossAttentionDecoderLayer struct {
MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
}
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -161,14 +160,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
}
type TextDecoderLayer interface {
Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
}
type TextDecoder struct {
Layers []TextDecoderLayer
}
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
for i, layer := range d.Layers {
layerType := selfAttentionLayer
if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
@@ -179,7 +178,12 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
cache.SetLayerType(layerType)
if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
var lastLayerOutputs ml.Tensor
if i == len(d.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
}
}
@@ -205,9 +209,9 @@ type TextModel struct {
*TextModelOptions
}
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}

View File

@@ -25,6 +25,7 @@ import (
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
@@ -64,8 +65,8 @@ type Sequence struct {
// number of tokens to predict
numPredict int
// set of samplers to run on generated logits
samplers []sample.Sampler
// sampler with transforms to run on generated logits
sampler sample.Sampler
// channel to send back the embedding if embedding only
embedding chan []float32
@@ -92,7 +93,7 @@ type NewSequenceParams struct {
numPredict int
stop []string
numKeep int32
samplers []sample.Sampler
sampler sample.Sampler
embedding bool
}
@@ -135,7 +136,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplers: params.samplers,
sampler: params.sampler,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
@@ -392,13 +393,7 @@ func (s *Server) processBatch() error {
return fmt.Errorf("failed to decode batch: %w", err)
}
f32s := modelOutput.Floats()
// TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s
logits := make([]float64, len(f32s))
for i, f32 := range f32s {
logits[i] = float64(f32)
}
logits := modelOutput.Floats()
for i, seq := range s.seqs {
if seq == nil {
@@ -432,14 +427,12 @@ func (s *Server) processBatch() error {
}
// sample a token
vocabSize := len(f32s) / len(options.Outputs)
tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...)
if err != nil {
return err
}
vocabSize := len(logits) / len(options.Outputs)
// TODO(jessegross): Sampler will output a single int32 in the future
token := int32(tokens[0])
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
return fmt.Errorf("failed to sample token: %w", err)
}
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
@@ -564,27 +557,6 @@ type CompletionResponse struct {
Timings Timings `json:"timings"`
}
func getSamplers(_ CompletionRequest) []sample.Sampler {
// TODO(jessegross): Waiting for sampling code
/*samplingParams.TopK = req.TopK
samplingParams.TopP = req.TopP
samplingParams.MinP = req.MinP
samplingParams.TypicalP = req.TypicalP
samplingParams.Temp = req.Temperature
samplingParams.RepeatLastN = req.RepeatLastN
samplingParams.PenaltyRepeat = req.RepeatPenalty
samplingParams.PenaltyFreq = req.FrequencyPenalty
samplingParams.PenaltyPresent = req.PresencePenalty
samplingParams.Mirostat = req.Mirostat
samplingParams.MirostatTau = req.MirostatTau
samplingParams.MirostatEta = req.MirostatEta
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar*/
return []sample.Sampler{sample.Greedy()}
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req CompletionRequest
req.Options = Options(api.DefaultOptions())
@@ -603,11 +575,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
sampler, err := sample.NewSampler(
req.Temperature,
req.TopK,
req.TopP,
req.MinP,
req.Seed,
)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
return
}
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: int32(req.NumKeep),
samplers: getSamplers(req),
sampler: sampler,
embedding: false,
})
if err != nil {
@@ -801,6 +785,7 @@ func (m *multiLPath) String() string {
func (s *Server) loadModel(
mpath string,
params ml.BackendParams,
lpath multiLPath,
parallel int,
kvCacheType string,
@@ -808,12 +793,12 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(mpath)
s.model, err = model.New(mpath, params)
if err != nil {
panic(err)
}
slog.Info("system", "info", s.model.Backend().SystemInfo() /* "threads", *threads */)
slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads)
// TODO(jessegross): LoRA loading
if lpath.String() != "" {
@@ -843,17 +828,17 @@ func Execute(args []string) error {
mpath := fs.String("model", "", "Path to model binary file")
parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
batchSize := fs.Int("batch-size", 512, "Batch size")
_ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
_ = fs.Int("main-gpu", 0, "Main GPU")
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
_ = fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on")
_ = fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
_ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
_ = fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
var lpaths multiLPath
@@ -890,15 +875,11 @@ func Execute(args []string) error {
}
// TODO(jessegross): Parameters that need to be implemented:
// n-gpu-layers
// main-gpu
// flash-attn
// threads
// no-mmap
// mlock
// tensor-split
/*var tensorSplitFloats []float32
var tensorSplitFloats []float32
if *tensorSplit != "" {
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
@@ -907,10 +888,17 @@ func Execute(args []string) error {
f, _ := strconv.ParseFloat(s, 32)
tensorSplitFloats = append(tensorSplitFloats, float32(f))
}
}*/
}
params := ml.BackendParams{
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
}
server.ready.Add(1)
go server.loadModel(*mpath, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)

View File

@@ -1,13 +0,0 @@
package sample
import "gonum.org/v1/gonum/floats"
type greedy struct{}
func Greedy() Sampler {
return greedy{}
}
func (s greedy) Sample(t []float64) ([]float64, error) {
return []float64{float64(floats.MaxIdx(t))}, nil
}

View File

@@ -1,74 +0,0 @@
package sample
import (
"slices"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/stat/sampleuv"
)
type Sampler interface {
Sample([]float64) ([]float64, error)
}
type Temperature float64
func (s Temperature) Sample(t []float64) ([]float64, error) {
floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
return t, nil
}
type softmax struct{}
func Softmax() Sampler {
return softmax{}
}
func (softmax) Sample(t []float64) ([]float64, error) {
return t, nil
}
type TopK int
func (s TopK) Sample(t []float64) ([]float64, error) {
return t, nil
}
type TopP float32
func (s TopP) Sample(t []float64) ([]float64, error) {
return t, nil
}
type MinP float32
func (s MinP) Sample(t []float64) ([]float64, error) {
return t, nil
}
type weighed struct{}
func Weighed() Sampler {
return weighed{}
}
func (s weighed) Sample(t []float64) ([]float64, error) {
w := sampleuv.NewWeighted(t, nil)
if v, ok := w.Take(); ok {
return []float64{float64(v)}, nil
}
return t, nil
}
func Sample(floats []float64, samplers ...Sampler) ([]float64, error) {
var err error
for _, sampler := range samplers {
floats, err = sampler.Sample(floats)
if err != nil {
return nil, err
}
}
return floats, nil
}

139
sample/samplers.go Normal file
View File

@@ -0,0 +1,139 @@
package sample
import (
"errors"
"math"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/stat/sampleuv"
)
type Sampler interface {
Sample([]float32) (int32, error)
}
type weighted struct {
src rand.Source
transforms []Transform
}
// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
func Weighted(seed *uint64, transforms ...Transform) Sampler {
var src rand.Source
if seed != nil {
src = rand.NewSource(*seed)
}
return weighted{src: src, transforms: transforms}
}
func (s weighted) Sample(logits []float32) (int32, error) {
logits64 := make([]float64, len(logits))
for i, v := range logits {
logits64[i] = float64(v)
}
for _, t := range s.transforms {
logits64 = t.Apply(logits64)
}
logitsCopy := make([]float64, 0, len(logits))
indices := make([]int, 0, len(logits))
for i, logit := range logits64 {
if !math.IsInf(logit, -1) {
logitsCopy = append(logitsCopy, logit)
indices = append(indices, i)
}
}
if len(logitsCopy) == 0 {
return -1, errors.New("no valid logits found for weighed sampling")
}
probs := softmax(logitsCopy)
w := sampleuv.NewWeighted(probs, s.src)
if idx, ok := w.Take(); ok {
return int32(indices[idx]), nil
}
return -1, errors.New("weighed sampler failed, no valid token found")
}
type greedy struct {
transforms []Transform
}
func Greedy(transforms ...Transform) Sampler {
return greedy{transforms: transforms}
}
func (s greedy) Sample(logits []float32) (int32, error) {
logits64 := make([]float64, len(logits))
for i, v := range logits {
logits64[i] = float64(v)
}
for _, t := range s.transforms {
logits64 = t.Apply(logits64)
}
var maxIdx int
var maxLogit float64
for i, logit := range logits64 {
if logit > maxLogit {
maxLogit = logit
maxIdx = i
}
}
if maxLogit == math.Inf(-1) {
return -1, errors.New("no valid logits found for greedy sampling")
}
return int32(maxIdx), nil
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
transforms := []Transform{}
if temperature < 0 || temperature > 2 {
return nil, errors.New("temperature must be between 0 and 2")
}
if temperature != 0 {
transforms = append(transforms, Temperature(temperature))
}
if topK != 0 {
if topK <= 0 {
return nil, errors.New("topK must be greater than 0")
}
transforms = append(transforms, TopK(topK))
}
if topP != 0 {
if topP < 0 || topP >= 1 {
return nil, errors.New("topP must be between 0 and 1")
}
transforms = append(transforms, TopP(topP))
}
if minP != 0 {
if minP < 0 || minP >= 1 {
return nil, errors.New("minP must be between 0 and 1")
}
transforms = append(transforms, MinP(minP))
}
if len(transforms) == 0 {
return nil, errors.New("at least one transform is required")
}
if temperature == 0 {
return Greedy(transforms...), nil
}
if seed != 0 {
seed64 := uint64(seed)
return Weighted(&seed64, transforms...), nil
}
return Weighted(nil, transforms...), nil
}

238
sample/samplers_test.go Normal file
View File

@@ -0,0 +1,238 @@
package sample
import (
"math"
"math/rand/v2"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWeighted(t *testing.T) {
got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
if err != nil {
t.Error(err)
return
}
want := int32(1)
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
if err == nil {
t.Error("expected error for no valid tokens, got index", got)
}
seed := uint64(42)
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
if err != nil {
t.Error(err)
return
}
// With seed 42, we expect a consistent sample
want = int32(3) // This will be deterministic due to the seed
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
}
type testTransform struct {
id int
callOrder *[]int
}
func (ts *testTransform) Apply(logits []float64) []float64 {
if ts.callOrder != nil {
*ts.callOrder = append(*ts.callOrder, ts.id)
}
return logits
}
func TestSample(t *testing.T) {
input := []float32{1, 2, 3, 4}
var callOrder []int
mock1 := &testTransform{
id: 1,
callOrder: &callOrder,
}
mock2 := &testTransform{
id: 2,
callOrder: &callOrder,
}
mock3 := &testTransform{
id: 3,
callOrder: &callOrder,
}
got, err := Greedy(mock1, mock2, mock3).Sample(input)
if err != nil {
t.Error(err)
return
}
want := int32(3) // Greedy sampler should pick highest logit
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
wantOrder := []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}
callOrder = nil
_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
if err != nil {
t.Error(err)
return
}
wantOrder = []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}
}
func TestNewSampler(t *testing.T) {
tests := []struct {
name string
temperature float32
topK int
topP float32
minP float32
seed int
wantErr bool
}{
{
name: "no transforms",
wantErr: true,
},
{
name: "temperature",
temperature: 0.5,
wantErr: false,
},
{
name: "invalid temperature negative",
temperature: -1,
wantErr: true,
},
{
name: "invalid temperature too high",
temperature: 2.1,
wantErr: true,
},
{
name: "top k",
topK: 10,
wantErr: false,
},
{
name: "invalid top k negative",
topK: -1,
wantErr: true,
},
{
name: "top p",
topP: 0.9,
wantErr: false,
},
{
name: "invalid top p negative",
topP: -0.1,
wantErr: true,
},
{
name: "invalid top p one",
topP: 1.0,
wantErr: true,
},
{
name: "min p",
minP: 0.2,
wantErr: false,
},
{
name: "invalid min p negative",
minP: -0.1,
wantErr: true,
},
{
name: "invalid min p one",
minP: 1.0,
wantErr: true,
},
{
name: "seed",
seed: 42,
wantErr: true, // seed alone is not valid without other transforms
},
{
name: "default values",
temperature: 0.8,
topK: 40,
topP: 0.9,
minP: 0.0,
seed: 0,
wantErr: false,
},
{
name: "all zeroes",
temperature: 0.0,
topK: 0,
topP: 0.0,
minP: 0.0,
seed: 0,
wantErr: true, // all zeroes means no transforms
},
{
name: "all transforms",
temperature: 0.8,
topK: 50,
topP: 0.95,
minP: 0.1,
seed: 42,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
if (err != nil) != tt.wantErr {
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func BenchmarkSample(b *testing.B) {
transforms := []Transform{
Temperature(0.5),
TopK(10),
TopP(0.9),
MinP(0.2),
}
samplers := map[string]Sampler{
"Greedy": Greedy(transforms...),
"Weighted": Weighted(nil, transforms...),
}
logits := make([]float32, 1<<16)
for i := range logits {
logits[i] = rand.Float32()
}
for name, s := range samplers {
b.Run(name, func(b *testing.B) {
b.ResetTimer()
for range b.N {
if _, err := s.Sample(logits); err != nil {
b.Error(err)
}
}
})
}
}

120
sample/transforms.go Normal file
View File

@@ -0,0 +1,120 @@
package sample
import (
"cmp"
"math"
"slices"
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
)
type Transform interface {
Apply([]float64) []float64
}
// TODO(parthsareen): potentially cache softmax values
func softmax(logits []float64) []float64 {
var sum float64
probs := make([]float64, len(logits))
for i, v := range logits {
probs[i] = math.Exp(v)
sum += probs[i]
}
for i := range probs {
probs[i] /= sum
}
return probs
}
type Temperature float64
func (t Temperature) Apply(logits []float64) []float64 {
temp := math.Max(float64(t), 1e-7)
// subtracting max logit to avoid under/overflow
maxLogit := slices.Max(logits)
for i := range logits {
logits[i] = (logits[i] - maxLogit) / temp
}
return logits
}
type logitMap struct {
index int
logit float64
}
type TopK int
// TODO(parthsareen): avoid having to check all logits after this transform
func (k TopK) Apply(logits []float64) []float64 {
if int(k) >= len(logits) {
return logits
}
q := pq.NewWith(func(a, b logitMap) int {
return -cmp.Compare(a.logit, b.logit)
})
for i, logit := range logits {
q.Enqueue(logitMap{index: i, logit: logit})
}
validLogits := make(map[int]float64)
for range k {
logitMap, _ := q.Dequeue()
validLogits[logitMap.index] = logitMap.logit
}
for i := range logits {
if _, ok := validLogits[i]; !ok {
logits[i] = math.Inf(-1)
}
}
return logits
}
type TopP float64
func (p TopP) Apply(logits []float64) []float64 {
probs := softmax(logits)
indices := make([]int, len(probs))
for i := range indices {
indices[i] = i
}
// sort in descending order
slices.SortFunc(indices, func(i, j int) int {
return cmp.Compare(probs[j], probs[i])
})
var sum float64
for i, idx := range indices {
sum += probs[idx]
if sum > float64(p) {
for _, idx := range indices[i+1:] {
logits[idx] = math.Inf(-1)
}
break
}
}
return logits
}
type MinP float64
func (p MinP) Apply(logits []float64) []float64 {
probs := softmax(logits)
threshold := slices.Max(probs) * float64(p)
for i, prob := range probs {
if prob < threshold {
logits[i] = math.Inf(-1)
}
}
return logits
}

80
sample/transforms_test.go Normal file
View File

@@ -0,0 +1,80 @@
package sample
import (
"math"
"math/rand/v2"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestTemperature(t *testing.T) {
got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
want := []float64{-4, -10, 0, -14, -6, -12, -8}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
}
func TestSoftmax(t *testing.T) {
got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("probs mismatch (-want +got):\n%s", diff)
}
}
func TestTopK(t *testing.T) {
got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
want = []float64{-3, -2, -1, 0, 1, 2, 4}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
}
func TestTopP(t *testing.T) {
got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
}
func TestMinP(t *testing.T) {
got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
}
func BenchmarkTransform(b *testing.B) {
transforms := map[string]Transform{
"Temperature": Temperature(0.5),
"TopK": TopK(10),
"TopP": TopP(0.9),
"MinP": MinP(0.2),
}
logits := make([]float64, 1<<16)
for i := range logits {
logits[i] = rand.Float64()
}
for name, transform := range transforms {
b.Run(name, func(b *testing.B) {
b.ResetTimer()
for range b.N {
transform.Apply(logits)
}
})
}
}

View File

@@ -28,7 +28,7 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then
${LOAD_OR_PUSH} \
--platform=linux/amd64 \
${OLLAMA_COMMON_BUILD_ARGS} \
--target runtime-rocm \
--build-arg FLAVOR=rocm \
-f Dockerfile \
-t ${FINAL_IMAGE_REPO}:$VERSION-rocm \
.

View File

@@ -22,8 +22,34 @@ docker buildx build \
-f Dockerfile \
.
# buildx behavior changes for single vs. multiplatform
if echo $PLATFORM | grep "," > /dev/null ; then
mv -f ./dist/linux_*64/ollama* ./dist/
rmdir ./dist/linux_*64
if echo $PLATFORM | grep "amd64" > /dev/null; then
outDir="./dist"
if echo $PLATFORM | grep "," > /dev/null ; then
outDir="./dist/linux_amd64"
fi
docker buildx build \
--output type=local,dest=${outDir} \
--platform=linux/amd64 \
${OLLAMA_COMMON_BUILD_ARGS} \
--build-arg FLAVOR=rocm \
--target archive \
-f Dockerfile \
.
fi
# buildx behavior changes for single vs. multiplatform
echo "Compressing linux tar bundles..."
if echo $PLATFORM | grep "," > /dev/null ; then
tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
elif echo $PLATFORM | grep "arm64" > /dev/null ; then
tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz
tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz
elif echo $PLATFORM | grep "amd64" > /dev/null ; then
tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz
tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz
fi

View File

@@ -26,6 +26,9 @@ function checkEnv() {
$MSVC_INSTALL=(Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation
$env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0]
}
if (-Not (get-command -ErrorAction silent ninja)) {
$script:NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe) | split-path -parent
}
# Locate CUDA versions
# Note: this assumes every version found will be built
$cudaList=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\" -ea 'silentlycontinue')
@@ -75,6 +78,7 @@ function checkEnv() {
} else {
write-host "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree"
}
$script:JOBS=((Get-CimInstance Win32_ComputerSystem).NumberOfLogicalProcessors)
}
@@ -83,51 +87,57 @@ function buildOllama() {
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0
# Default first, then conditionall ROCm and cuda v11
write-host "Building Default native backend libraries"
$env:CMAKE_GENERATOR="ninja"
& cmake --preset Default
& cmake --fresh --preset CPU --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset Default -j 12
& cmake --build --preset CPU --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component CPU --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build -j 12
# TODO - add steps for v11 and ROCm
#
# if ("$script:CUDA_DIRS".Contains("v11") -and "$script:CUDA_DIRS".Contains("v12")) {
# # We assume the default is v12, so override for v11
# $origCUDA_PATH=$env:CUDA_PATH
# $hashEnv = @{}
# Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
# $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $v11="$_" }}
# write-host "$v11"
# # $env:CUDA_PATH=$hashEnv[$v11]
# # $env:CUDACXX=$hashEnv[$v11]+"\bin\nvcc.exe"
# $env:CUDAToolkit_ROOT=$hashEnv[$v11]
# # ls env:
# write-host "Building CUDA v11 backend libraries"
# & cmake --preset "CUDA 11"
# $env:CUDA_PATH=$origCUDA_PATH
# exit(1)
# if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
# # & cmake --build --preset "CUDA 11" -j 12
# # if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
# }
# if ($env:HIP_PATH) {
# write-host "Building ROCm backend libraries"
# $env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe"
# $env:HIP_PLATFORM="amd"
# $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
# & cmake --preset "ROCm"
# $env:HIPCXX=""
# $env:HIP_PLATFORM=""
# $env:CMAKE_PREFIX_PATH=""
# if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
# & cmake --build --preset "ROCm" -j 12
# if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
# }
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v11")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $v11="$_" }}
$env:CUDAToolkit_ROOT=$hashEnv[$v11]
write-host "Building CUDA v11 backend libraries"
# Note: cuda v11 requires msvc 2019 so force the older generator
# to avoid 2022 (or newer) from being used as the default
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 11" --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
if ("$script:CUDA_DIRS".Contains("v12")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
write-host "Building CUDA v12 backend libraries"
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 12" --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
if ($env:HIP_PATH) {
write-host "Building ROCm backend libraries"
if ($null -ne $script:NINJA_DIR) {
$env:PATH="$script:NINJA_DIR;$env:PATH"
}
$env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe"
$env:HIP_PLATFORM="amd"
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
& cmake --fresh --preset "ROCm 6" -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
$env:HIPCXX=""
$env:HIP_PLATFORM=""
$env:CMAKE_PREFIX_PATH=""
& cmake --build --preset "ROCm" --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
} else {
write-host "Skipping generate step with OLLAMA_SKIP_GENERATE set"
}

544
server/internal/cache/blob/cache.go vendored Normal file
View File

@@ -0,0 +1,544 @@
// Package blob implements a content-addressable disk cache for blobs and
// manifests.
package blob
import (
"bytes"
"crypto/sha256"
"errors"
"fmt"
"hash"
"io"
"io/fs"
"iter"
"os"
"path/filepath"
"strings"
"time"
"github.com/ollama/ollama/server/internal/internal/names"
)
// Entry contains metadata about a blob in the cache.
type Entry struct {
Digest Digest
Size int64
Time time.Time // when added to the cache
}
// DiskCache caches blobs and manifests on disk.
//
// The cache is rooted at a directory, which is created if it does not exist.
//
// Blobs are stored in the "blobs" subdirectory, and manifests are stored in the
// "manifests" subdirectory. A example directory structure might look like:
//
// <dir>/
// blobs/
// sha256-<digest> - <blob data>
// manifests/
// <host>/
// <namespace>/
// <name>/
// <tag> - <manifest data>
//
// The cache is safe for concurrent use.
//
// Name casing is preserved in the cache, but is not significant when resolving
// names. For example, "Foo" and "foo" are considered the same name.
//
// The cache is not safe for concurrent use. It guards concurrent writes, but
// does not prevent duplicated effort. Because blobs are immutable, duplicate
// writes should result in the same file being written to disk.
type DiskCache struct {
// Dir specifies the top-level directory where blobs and manifest
// pointers are stored.
dir string
now func() time.Time
testHookBeforeFinalWrite func(f *os.File)
}
// PutString is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
}
// Open opens a cache rooted at the given directory. If the directory does not
// exist, it is created. If the directory is not a directory, an error is
// returned.
func Open(dir string) (*DiskCache, error) {
if dir == "" {
return nil, errors.New("blob: empty directory name")
}
info, err := os.Stat(dir)
if err == nil && !info.IsDir() {
return nil, fmt.Errorf("%q is not a directory", dir)
}
if err := os.MkdirAll(dir, 0o777); err != nil {
return nil, err
}
subdirs := []string{"blobs", "manifests"}
for _, subdir := range subdirs {
if err := os.MkdirAll(filepath.Join(dir, subdir), 0o777); err != nil {
return nil, err
}
}
// TODO(bmizerany): support shards
c := &DiskCache{
dir: dir,
now: time.Now,
}
return c, nil
}
func readAndSum(filename string, limit int64) (data []byte, _ Digest, err error) {
f, err := os.Open(filename)
if err != nil {
return nil, Digest{}, err
}
defer f.Close()
h := sha256.New()
r := io.TeeReader(f, h)
data, err = io.ReadAll(io.LimitReader(r, limit))
if err != nil {
return nil, Digest{}, err
}
var d Digest
h.Sum(d.sum[:0])
return data, d, nil
}
//lint:ignore U1000 used for debugging purposes as needed in tests
var debug = false
// debugger returns a function that can be used to add a step to the error message.
// The error message will be a list of steps that were taken before the error occurred.
// The steps are added in the order they are called.
//
// To set the error message, call the returned function with an empty string.
//
//lint:ignore U1000 used for debugging purposes as needed in tests
func debugger(err *error) func(step string) {
if !debug {
return func(string) {}
}
var steps []string
return func(step string) {
if step == "" && *err != nil {
*err = fmt.Errorf("%q: %w", steps, *err)
return
}
steps = append(steps, step)
if len(steps) > 100 {
// shift hints in case of a bug that causes a lot of hints
copy(steps, steps[1:])
steps = steps[:100]
}
}
}
// Resolve resolves a name to a digest. The name is expected to
// be in either of the following forms:
//
// @<digest>
// <name>
// <name>
//
// If a digest is provided, it is returned as is and nothing else happens.
//
// If a name is provided for a manifest that exists in the cache, the digest
// of the manifest is returned. If there is no manifest in the cache, it
// returns [fs.ErrNotExist].
//
// To cover the case where a manifest may change without the cache knowing
// (e.g. it was reformatted or modified by hand), the manifest data read and
// hashed is passed to a PutBytes call to ensure that the manifest is in the
// blob store. This is done to ensure that future calls to [Get] succeed in
// these cases.
//
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
func (c *DiskCache) Resolve(name string) (Digest, error) {
name, digest := splitNameDigest(name)
if digest != "" {
return ParseDigest(digest)
}
// We want to address manifests files by digest using Get. That requires
// them to be blobs. This cannot be directly accomplished by looking in
// the blob store because manifests can change without Ollama knowing
// (e.g. a user modifies a manifests by hand then pushes it to update
// their model). We also need to support the blob caches inherited from
// older versions of Ollama, which do not store manifests in the blob
// store, so for these cases, we need to handle adding the manifests to
// the blob store, just in time.
//
// So now we read the manifests file, hash it, and copy it to the blob
// store if it's not already there.
//
// This should be cheap because manifests are small, and accessed
// infrequently.
file, err := c.manifestPath(name)
if err != nil {
return Digest{}, err
}
data, d, err := readAndSum(file, 1<<20)
if err != nil {
return Digest{}, err
}
// Ideally we'd read the "manifest" file as a manifest to the blob file,
// but we are not changing this yet, so copy the manifest to the blob
// store so it can be addressed by digest subsequent calls to Get.
if err := PutBytes(c, d, data); err != nil {
return Digest{}, err
}
return d, nil
}
// Put writes a new blob to the cache, identified by its digest. The operation
// reads content from r, which must precisely match both the specified size and
// digest.
//
// Concurrent write safety is achieved through file locking. The implementation
// guarantees write integrity by enforcing size limits and content validation
// before allowing the file to reach its final state.
func (c *DiskCache) Put(d Digest, r io.Reader, size int64) error {
return c.copyNamedFile(c.GetFile(d), r, d, size)
}
// Import imports a blob from the provided reader into the cache. It reads the
// entire content of the reader, calculates its digest, and stores it in the
// cache.
//
// Import should be considered unsafe for use with untrusted data, such as data
// read from a network. The caller is responsible for ensuring the integrity of
// the data being imported.
func (c *DiskCache) Import(r io.Reader, size int64) (Digest, error) {
// users that want to change the temp dir can set TEMPDIR.
f, err := os.CreateTemp("", "blob-")
if err != nil {
return Digest{}, err
}
defer os.Remove(f.Name())
// Copy the blob to a temporary file.
h := sha256.New()
r = io.TeeReader(r, h)
n, err := io.Copy(f, r)
if err != nil {
return Digest{}, err
}
if n != size {
return Digest{}, fmt.Errorf("blob: expected %d bytes, got %d", size, n)
}
// Check the digest.
var d Digest
h.Sum(d.sum[:0])
if err := f.Close(); err != nil {
return Digest{}, err
}
name := c.GetFile(d)
// Rename the temporary file to the final file.
if err := os.Rename(f.Name(), name); err != nil {
return Digest{}, err
}
os.Chtimes(name, c.now(), c.now()) // mainly for tests
return d, nil
}
// Get retrieves a blob from the cache using the provided digest. The operation
// fails if the digest is malformed or if any errors occur during blob
// retrieval.
func (c *DiskCache) Get(d Digest) (Entry, error) {
name := c.GetFile(d)
info, err := os.Stat(name)
if err != nil {
return Entry{}, err
}
if info.Size() == 0 {
return Entry{}, fs.ErrNotExist
}
return Entry{
Digest: d,
Size: info.Size(),
Time: info.ModTime(),
}, nil
}
// Link creates a symbolic reference in the cache that maps the provided name
// to a blob identified by its digest, making it retrievable by name using
// [Resolve].
//
// It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error {
manifest, err := c.manifestPath(name)
if err != nil {
return err
}
f, err := os.OpenFile(c.GetFile(d), os.O_RDONLY, 0)
if err != nil {
return err
}
defer f.Close()
// TODO(bmizerany): test this happens only if the blob was found to
// avoid leaving debris
if err := os.MkdirAll(filepath.Dir(manifest), 0o777); err != nil {
return err
}
info, err := f.Stat()
if err != nil {
return err
}
// Copy manifest to cache directory.
return c.copyNamedFile(manifest, f, d, info.Size())
}
// Unlink removes the any link for name. If the link does not exist, nothing
// happens, and no error is returned.
//
// It returns an error if the name is invalid or if the link removal encounters
// any issues.
func (c *DiskCache) Unlink(name string) error {
manifest, err := c.manifestPath(name)
if err != nil {
return err
}
err = os.Remove(manifest)
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return err
}
// GetFile returns the absolute path to the file, in the cache, for the given
// digest. It does not check if the file exists.
//
// The returned path should not be stored, used outside the lifetime of the
// cache, or interpreted in any way.
func (c *DiskCache) GetFile(d Digest) string {
filename := fmt.Sprintf("sha256-%x", d.sum)
return absJoin(c.dir, "blobs", filename)
}
// Links returns a sequence of links in the cache in lexical order.
func (c *DiskCache) Links() iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
for path, err := range c.links() {
if err != nil {
yield("", err)
return
}
if !yield(pathToName(path), nil) {
return
}
}
}
}
// pathToName converts a path to a name. It is the inverse of nameToPath. The
// path is assumed to be in filepath.ToSlash format.
func pathToName(s string) string {
s = strings.TrimPrefix(s, "manifests/")
rr := []rune(s)
for i := len(rr) - 1; i > 0; i-- {
if rr[i] == '/' {
rr[i] = ':'
return string(rr)
}
}
return s
}
// manifestPath finds the first manifest file on disk that matches the given
// name using a case-insensitive comparison. If no manifest file is found, it
// returns the path where the manifest file would be if it existed.
//
// If two manifest files exists on disk that match the given name using a
// case-insensitive comparison, the one that sorts first, lexically, is
// returned.
func (c *DiskCache) manifestPath(name string) (string, error) {
np, err := nameToPath(name)
if err != nil {
return "", err
}
maybe := filepath.Join("manifests", np)
for l, err := range c.links() {
if err != nil {
return "", err
}
if strings.EqualFold(maybe, l) {
return filepath.Join(c.dir, l), nil
}
}
return filepath.Join(c.dir, maybe), nil
}
// links returns a sequence of links in the cache in lexical order.
func (c *DiskCache) links() iter.Seq2[string, error] {
// TODO(bmizerany): reuse empty dirnames if exist
return func(yield func(string, error) bool) {
fsys := os.DirFS(c.dir)
manifests, err := fs.Glob(fsys, "manifests/*/*/*/*")
if err != nil {
yield("", err)
return
}
for _, manifest := range manifests {
if !yield(manifest, nil) {
return
}
}
}
}
type checkWriter struct {
d Digest
size int64
n int64
h hash.Hash
f *os.File
err error
testHookBeforeFinalWrite func(*os.File)
}
func (w *checkWriter) seterr(err error) error {
if w.err == nil {
w.err = err
}
return err
}
// Write writes p to the underlying hash and writer. The last write to the
// underlying writer is guaranteed to be the last byte of p as verified by the
// hash.
func (w *checkWriter) Write(p []byte) (int, error) {
_, err := w.h.Write(p)
if err != nil {
return 0, w.seterr(err)
}
nextSize := w.n + int64(len(p))
if nextSize == w.size {
// last write. check hash.
sum := w.h.Sum(nil)
if !bytes.Equal(sum, w.d.sum[:]) {
return 0, w.seterr(fmt.Errorf("file content changed underfoot"))
}
if w.testHookBeforeFinalWrite != nil {
w.testHookBeforeFinalWrite(w.f)
}
}
if nextSize > w.size {
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
}
n, err := w.f.Write(p)
w.n += int64(n)
return n, w.seterr(err)
}
// copyNamedFile copies file into name, expecting it to have the given Digest
// and size, if that file is not present already.
func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size int64) error {
info, err := os.Stat(name)
if err == nil && info.Size() == size {
// File already exists with correct size. This is good enough.
// We can skip expensive hash checks.
//
// TODO: Do the hash check, but give caller a way to skip it.
return nil
}
// Copy file to cache directory.
mode := os.O_RDWR | os.O_CREATE
if err == nil && info.Size() > size { // shouldn't happen but fix in case
mode |= os.O_TRUNC
}
f, err := os.OpenFile(name, mode, 0o666)
if err != nil {
return err
}
defer f.Close()
if size == 0 {
// File now exists with correct size.
// Only one possible zero-length file, so contents are OK too.
// Early return here makes sure there's a "last byte" for code below.
return nil
}
// From here on, if any of the I/O writing the file fails,
// we make a best-effort attempt to truncate the file f
// before returning, to avoid leaving bad bytes in the file.
// Copy file to f, but also into h to double-check hash.
cw := &checkWriter{
d: out,
size: size,
h: sha256.New(),
f: f,
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
}
n, err := io.Copy(cw, file)
if err != nil {
f.Truncate(0)
return err
}
if n < size {
f.Truncate(0)
return io.ErrUnexpectedEOF
}
if err := f.Close(); err != nil {
// Data might not have been written,
// but file may look like it is the right size.
// To be extra careful, remove cached file.
os.Remove(name)
return err
}
os.Chtimes(name, c.now(), c.now()) // mainly for tests
return nil
}
func splitNameDigest(s string) (name, digest string) {
i := strings.LastIndexByte(s, '@')
if i < 0 {
return s, ""
}
return s[:i], s[i+1:]
}
var errInvalidName = errors.New("invalid name")
func nameToPath(name string) (_ string, err error) {
if strings.Contains(name, "@") {
// TODO(bmizerany): HACK: Fix names.Parse to validate.
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
return "", errInvalidName
}
n := names.Parse(name)
if !n.IsFullyQualified() {
return "", errInvalidName
}
return filepath.Join(n.Host(), n.Namespace(), n.Model(), n.Tag()), nil
}
func absJoin(pp ...string) string {
abs, err := filepath.Abs(filepath.Join(pp...))
if err != nil {
// Likely a bug bug or a bad OS problem. Just panic.
panic(err)
}
return abs
}

685
server/internal/cache/blob/cache_test.go vendored Normal file
View File

@@ -0,0 +1,685 @@
package blob
import (
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"
"github.com/ollama/ollama/server/internal/internal/testutil"
)
func init() {
debug = true
}
var epoch = func() time.Time {
d := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
if d.IsZero() {
panic("time zero")
}
return d
}()
func TestOpenErrors(t *testing.T) {
exe, err := os.Executable()
if err != nil {
panic(err)
}
cases := []struct {
dir string
err string
}{
{t.TempDir(), ""},
{"", "empty directory name"},
{exe, "not a directory"},
}
for _, tt := range cases {
t.Run(tt.dir, func(t *testing.T) {
_, err := Open(tt.dir)
if tt.err == "" {
if err != nil {
t.Fatal(err)
}
return
}
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), tt.err) {
t.Fatalf("err = %v, want %q", err, tt.err)
}
})
}
}
func TestGetFile(t *testing.T) {
t.Chdir(t.TempDir())
c, err := Open(".")
if err != nil {
t.Fatal(err)
}
d := mkdigest("1")
got := c.GetFile(d)
cleaned := filepath.Clean(got)
if cleaned != got {
t.Fatalf("got is unclean: %q", got)
}
if !filepath.IsAbs(got) {
t.Fatal("got is not absolute")
}
abs, _ := filepath.Abs(c.dir)
if !strings.HasPrefix(got, abs) {
t.Fatalf("got is not local to %q", c.dir)
}
}
func TestBasic(t *testing.T) {
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
now := epoch
c.now = func() time.Time { return now }
checkEntry := entryChecker(t, c)
checkFailed := func(err error) {
if err == nil {
t.Helper()
t.Fatal("expected error")
}
}
_, err = c.Resolve("invalid")
checkFailed(err)
_, err = c.Resolve("h/n/m:t")
checkFailed(err)
dx := mkdigest("x")
d, err := c.Resolve(fmt.Sprintf("h/n/m:t@%s", dx))
if err != nil {
t.Fatal(err)
}
if d != dx {
t.Fatalf("d = %v, want %v", d, dx)
}
_, err = c.Get(Digest{})
checkFailed(err)
// not committed yet
_, err = c.Get(dx)
checkFailed(err)
err = PutBytes(c, dx, "!")
checkFailed(err)
err = PutBytes(c, dx, "x")
if err != nil {
t.Fatal(err)
}
checkEntry(dx, 1, now)
t0 := now
now = now.Add(1*time.Hour + 1*time.Minute)
err = PutBytes(c, dx, "x")
if err != nil {
t.Fatal(err)
}
// check not updated
checkEntry(dx, 1, t0)
}
type sleepFunc func(d time.Duration) time.Time
func openTester(t *testing.T) (*DiskCache, sleepFunc) {
t.Helper()
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
now := epoch
c.now = func() time.Time { return now }
return c, func(d time.Duration) time.Time {
now = now.Add(d)
return now
}
}
func TestManifestPath(t *testing.T) {
check := testutil.Checker(t)
c, sleep := openTester(t)
d1 := mkdigest("1")
err := PutBytes(c, d1, "1")
check(err)
err = c.Link("h/n/m:t", d1)
check(err)
t0 := sleep(0)
sleep(1 * time.Hour)
err = c.Link("h/n/m:t", d1) // nop expected
check(err)
file := must(c.manifestPath("h/n/m:t"))
info, err := os.Stat(file)
check(err)
testutil.CheckTime(t, info.ModTime(), t0)
}
func TestManifestExistsWithoutBlob(t *testing.T) {
t.Chdir(t.TempDir())
check := testutil.Checker(t)
c, err := Open(".")
check(err)
checkEntry := entryChecker(t, c)
man := must(c.manifestPath("h/n/m:t"))
os.MkdirAll(filepath.Dir(man), 0o777)
testutil.WriteFile(t, man, "1")
got, err := c.Resolve("h/n/m:t")
check(err)
want := mkdigest("1")
if got != want {
t.Fatalf("got = %v, want %v", got, want)
}
e, err := c.Get(got)
check(err)
checkEntry(got, 1, e.Time)
}
func TestPut(t *testing.T) {
c, sleep := openTester(t)
check := testutil.Checker(t)
checkEntry := entryChecker(t, c)
d := mkdigest("hello, world")
err := PutBytes(c, d, "hello")
if err == nil {
t.Fatal("expected error")
}
got, err := c.Get(d)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("expected error, got %v", got)
}
// Put a valid blob
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob with content that does not hash to the digest
err = PutBytes(c, d, "hello")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Put the valid blob back and check it
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob that errors during Read
err = c.Put(d, &errOnBangReader{s: "!"}, 1)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Put valid blob back and check it
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob with mismatched size
err = c.Put(d, strings.NewReader("hello, world"), 11)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Final byte does not match the digest (testing commit phase)
err = PutBytes(c, d, "hello, world$")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
reset := c.setTestHookBeforeFinalWrite(func(f *os.File) {
// change mode to read-only
f.Truncate(0)
f.Chmod(0o400)
f.Close()
f1, err := os.OpenFile(f.Name(), os.O_RDONLY, 0)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { f1.Close() })
*f = *f1
})
defer reset()
err = PutBytes(c, d, "hello, world")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
reset()
}
func TestImport(t *testing.T) {
c, _ := openTester(t)
checkEntry := entryChecker(t, c)
want := mkdigest("x")
got, err := c.Import(strings.NewReader("x"), 1)
if err != nil {
t.Fatal(err)
}
if want != got {
t.Fatalf("digest = %v, want %v", got, want)
}
checkEntry(want, 1, epoch)
got, err = c.Import(strings.NewReader("x"), 1)
if err != nil {
t.Fatal(err)
}
if want != got {
t.Fatalf("digest = %v, want %v", got, want)
}
checkEntry(want, 1, epoch)
}
func (c *DiskCache) setTestHookBeforeFinalWrite(h func(*os.File)) (reset func()) {
old := c.testHookBeforeFinalWrite
c.testHookBeforeFinalWrite = h
return func() { c.testHookBeforeFinalWrite = old }
}
func TestPutGetZero(t *testing.T) {
c, sleep := openTester(t)
check := testutil.Checker(t)
checkEntry := entryChecker(t, c)
d := mkdigest("x")
err := PutBytes(c, d, "x")
check(err)
checkEntry(d, 1, sleep(0))
err = os.Truncate(c.GetFile(d), 0)
check(err)
_, err = c.Get(d)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want fs.ErrNotExist", err)
}
}
func TestPutZero(t *testing.T) {
c, _ := openTester(t)
d := mkdigest("x")
err := c.Put(d, strings.NewReader("x"), 0) // size == 0 (not size of content)
testutil.Check(t, err)
checkNotExists(t, c, d)
}
func TestCommit(t *testing.T) {
check := testutil.Checker(t)
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
checkEntry := entryChecker(t, c)
now := epoch
c.now = func() time.Time { return now }
d1 := mkdigest("1")
err = c.Link("h/n/m:t", d1)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want fs.ErrNotExist", err)
}
err = PutBytes(c, d1, "1")
check(err)
err = c.Link("h/n/m:t", d1)
check(err)
got, err := c.Resolve("h/n/m:t")
check(err)
if got != d1 {
t.Fatalf("d = %v, want %v", got, d1)
}
// commit again, more than 1 byte
d2 := mkdigest("22")
err = PutBytes(c, d2, "22")
check(err)
err = c.Link("h/n/m:t", d2)
check(err)
checkEntry(d2, 2, now)
filename := must(c.manifestPath("h/n/m:t"))
data, err := os.ReadFile(filename)
check(err)
if string(data) != "22" {
t.Fatalf("data = %q, want %q", data, "22")
}
t0 := now
now = now.Add(1 * time.Hour)
err = c.Link("h/n/m:t", d2) // same contents; nop
check(err)
info, err := os.Stat(filename)
check(err)
testutil.CheckTime(t, info.ModTime(), t0)
}
func TestManifestInvalidBlob(t *testing.T) {
c, _ := openTester(t)
d := mkdigest("1")
err := c.Link("h/n/m:t", d)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
err = PutBytes(c, d, "1")
testutil.Check(t, err)
err = os.WriteFile(c.GetFile(d), []byte("invalid"), 0o666)
if err != nil {
t.Fatal(err)
}
err = c.Link("h/n/m:t", d)
if !strings.Contains(err.Error(), "underfoot") {
t.Fatalf("err = %v, want error to contain %q", err, "underfoot")
}
}
func TestManifestNameReuse(t *testing.T) {
t.Run("case-insensitive", func(t *testing.T) {
// This should run on all file system types.
testManifestNameReuse(t)
})
t.Run("case-sensitive", func(t *testing.T) {
useCaseInsensitiveTempDir(t)
testManifestNameReuse(t)
})
}
func testManifestNameReuse(t *testing.T) {
check := testutil.Checker(t)
c, _ := openTester(t)
d1 := mkdigest("1")
err := PutBytes(c, d1, "1")
check(err)
err = c.Link("h/n/m:t", d1)
check(err)
d2 := mkdigest("22")
err = PutBytes(c, d2, "22")
check(err)
err = c.Link("H/N/M:T", d2)
check(err)
var g [2]Digest
g[0], err = c.Resolve("h/n/m:t")
check(err)
g[1], err = c.Resolve("H/N/M:T")
check(err)
w := [2]Digest{d2, d2}
if g != w {
t.Fatalf("g = %v, want %v", g, w)
}
var got []string
for l, err := range c.links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
}
want := []string{"manifests/h/n/m/t"}
if !slices.Equal(got, want) {
t.Fatalf("got = %v, want %v", got, want)
}
// relink with different case
err = c.Unlink("h/n/m:t")
check(err)
err = c.Link("h/n/m:T", d1)
check(err)
got = got[:0]
for l, err := range c.links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
}
// we should have only one link that is same case as the last link
want = []string{"manifests/h/n/m/T"}
if !slices.Equal(got, want) {
t.Fatalf("got = %v, want %v", got, want)
}
}
func TestManifestFile(t *testing.T) {
cases := []struct {
in string
want string
}{
{"", ""},
// valid names
{"h/n/m:t", "/manifests/h/n/m/t"},
{"hh/nn/mm:tt", "/manifests/hh/nn/mm/tt"},
{"%/%/%/%", ""},
// already a path
{"h/n/m/t", ""},
// refs are not names
{"h/n/m:t@sha256-1", ""},
{"m@sha256-1", ""},
{"n/m:t@sha256-1", ""},
}
c, _ := openTester(t)
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
got, err := c.manifestPath(tt.in)
if err != nil && tt.want != "" {
t.Fatalf("unexpected error: %v", err)
}
if err == nil && tt.want == "" {
t.Fatalf("expected error")
}
dir := filepath.ToSlash(c.dir)
got = filepath.ToSlash(got)
got = strings.TrimPrefix(got, dir)
if got != tt.want {
t.Fatalf("got = %q, want %q", got, tt.want)
}
})
}
}
func TestNames(t *testing.T) {
c, _ := openTester(t)
check := testutil.Checker(t)
check(PutBytes(c, mkdigest("1"), "1"))
check(PutBytes(c, mkdigest("2"), "2"))
check(c.Link("h/n/m:t", mkdigest("1")))
check(c.Link("h/n/m:u", mkdigest("2")))
var got []string
for l, err := range c.Links() {
if err != nil {
t.Fatal(err)
}
got = append(got, l)
}
want := []string{"h/n/m:t", "h/n/m:u"}
if !slices.Equal(got, want) {
t.Fatalf("got = %v, want %v", got, want)
}
}
func mkdigest(s string) Digest {
return Digest{sha256.Sum256([]byte(s))}
}
func checkNotExists(t *testing.T, c *DiskCache, d Digest) {
t.Helper()
_, err := c.Get(d)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want fs.ErrNotExist", err)
}
}
func entryChecker(t *testing.T, c *DiskCache) func(Digest, int64, time.Time) {
t.Helper()
return func(d Digest, size int64, mod time.Time) {
t.Helper()
t.Run("checkEntry:"+d.String(), func(t *testing.T) {
t.Helper()
defer func() {
if t.Failed() {
dumpCacheContents(t, c)
}
}()
e, err := c.Get(d)
if size == 0 && errors.Is(err, fs.ErrNotExist) {
err = nil
}
if err != nil {
t.Fatal(err)
}
if e.Digest != d {
t.Errorf("e.Digest = %v, want %v", e.Digest, d)
}
if e.Size != size {
t.Fatalf("e.Size = %v, want %v", e.Size, size)
}
testutil.CheckTime(t, e.Time, mod)
info, err := os.Stat(c.GetFile(d))
if err != nil {
t.Fatal(err)
}
if info.Size() != size {
t.Fatalf("info.Size = %v, want %v", info.Size(), size)
}
testutil.CheckTime(t, info.ModTime(), mod)
})
}
}
func must[T any](v T, err error) T {
if err != nil {
panic(err)
}
return v
}
func TestNameToPath(t *testing.T) {
_, err := nameToPath("h/n/m:t")
if err != nil {
t.Fatal(err)
}
}
type errOnBangReader struct {
s string
n int
}
func (e *errOnBangReader) Read(p []byte) (int, error) {
if len(p) < 1 {
return 0, io.ErrShortBuffer
}
if e.n >= len(p) {
return 0, io.EOF
}
if e.s[e.n] == '!' {
return 0, errors.New("bang")
}
p[0] = e.s[e.n]
e.n++
return 1, nil
}
func dumpCacheContents(t *testing.T, c *DiskCache) {
t.Helper()
var b strings.Builder
fsys := os.DirFS(c.dir)
fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
t.Helper()
if err != nil {
return err
}
info, err := d.Info()
if err != nil {
return err
}
// Format like ls:
//
// ; ls -la
// drwxr-xr-x 224 Jan 13 14:22 blob/sha256-123
// drwxr-xr-x 224 Jan 13 14:22 manifest/h/n/m
fmt.Fprintf(&b, " %s % 4d %s %s\n",
info.Mode(),
info.Size(),
info.ModTime().Format("Jan 2 15:04"),
path,
)
return nil
})
t.Log()
t.Logf("cache contents:\n%s", b.String())
}

View File

@@ -0,0 +1,93 @@
package blob
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
func isCaseSensitive(dir string) bool {
defer func() {
os.Remove(filepath.Join(dir, "_casecheck"))
}()
exists := func(file string) bool {
_, err := os.Stat(file)
return err == nil
}
file := filepath.Join(dir, "_casecheck")
FILE := filepath.Join(dir, "_CASECHECK")
if exists(file) || exists(FILE) {
panic(fmt.Sprintf("_casecheck already exists in %q; remove and try again.", dir))
}
err := os.WriteFile(file, nil, 0o666)
if err != nil {
panic(err)
}
return !exists(FILE)
}
func isCI() bool {
return os.Getenv("CI") != ""
}
const volumeHint = `
Unable to locate case-insensitive TMPDIR on darwin.
To run tests, create the case-insensitive volume /Volumes/data:
$ sudo diskutil apfs addVolume disk1 APFSX data -mountpoint /Volumes/data
or run with:
CI=1 go test ./...
`
// useCaseInsensitiveTempDir sets TMPDIR to a case-insensitive directory
// can find one, otherwise it skips the test if the CI environment variable is
// set, or GOOS is not darwin.
func useCaseInsensitiveTempDir(t *testing.T) bool {
if isCaseSensitive(os.TempDir()) {
// Use the default temp dir if it is already case-sensitive.
return true
}
if runtime.GOOS == "darwin" {
// If darwin, check for the special case-sensitive volume and
// use it if available.
const volume = "/Volumes/data"
_, err := os.Stat(volume)
if err == nil {
tmpdir := filepath.Join(volume, "tmp")
os.MkdirAll(tmpdir, 0o700)
t.Setenv("TMPDIR", tmpdir)
return true
}
if isCI() {
// Special case darwin in CI; it is not case-sensitive
// by default, and we will be testing other platforms
// that are case-sensitive, so we'll have the test
// being skipped covered there.
t.Skip("Skipping test in CI for darwin; TMPDIR is not case-insensitive.")
}
}
if !isCI() {
// Require devs to always tests with a case-insensitive TMPDIR.
// TODO(bmizerany): Print platform-specific instructions or
// link to docs on that topic.
lines := strings.Split(volumeHint, "\n")
for _, line := range lines {
t.Log(line)
}
}
return false
}

95
server/internal/cache/blob/digest.go vendored Normal file
View File

@@ -0,0 +1,95 @@
package blob
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"slices"
"strings"
)
var ErrInvalidDigest = errors.New("invalid digest")
// Digest is a blob identifier that is the SHA-256 hash of a blob's content.
//
// It is comparable and can be used as a map key.
type Digest struct {
sum [32]byte
}
// ParseDigest parses a digest from a string. If the string is not a valid
// digest, a call to the returned digest's IsValid method will return false.
//
// The input string may be in one of two forms:
//
// - ("sha256-<hex>"), where <hex> is a 64-character hexadecimal string.
// - ("sha256:<hex>"), where <hex> is a 64-character hexadecimal string.
//
// The [Digest.String] method will return the canonical form of the
// digest, "sha256:<hex>".
func ParseDigest[S ~[]byte | ~string](v S) (Digest, error) {
s := string(v)
i := strings.IndexAny(s, ":-")
var zero Digest
if i < 0 {
return zero, ErrInvalidDigest
}
prefix, sum := s[:i], s[i+1:]
if prefix != "sha256" || len(sum) != 64 {
return zero, ErrInvalidDigest
}
var d Digest
_, err := hex.Decode(d.sum[:], []byte(sum))
if err != nil {
return zero, ErrInvalidDigest
}
return d, nil
}
func DigestFromBytes[S ~[]byte | ~string](v S) Digest {
return Digest{sha256.Sum256([]byte(v))}
}
// String returns the string representation of the digest in the conventional
// form "sha256:<hex>".
func (d Digest) String() string {
return fmt.Sprintf("sha256:%x", d.sum[:])
}
func (d Digest) Short() string {
return fmt.Sprintf("%x", d.sum[:4])
}
func (d Digest) Compare(other Digest) int {
return slices.Compare(d.sum[:], other.sum[:])
}
// IsValid returns true if the digest is valid, i.e. if it is the SHA-256 hash
// of some content.
func (d Digest) IsValid() bool {
return d != (Digest{})
}
// MarshalText implements the encoding.TextMarshaler interface. It returns an
// error if [Digest.IsValid] returns false.
func (d Digest) MarshalText() ([]byte, error) {
return []byte(d.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface, and only
// works for a zero digest. If [Digest.IsValid] returns true, it returns an
// error.
func (d *Digest) UnmarshalText(text []byte) error {
if *d != (Digest{}) {
return errors.New("digest: illegal UnmarshalText on valid digest")
}
v, err := ParseDigest(string(text))
if err != nil {
return err
}
*d = v
return nil
}

View File

@@ -0,0 +1,63 @@
package blob
import (
"encoding/json"
"testing"
)
func TestParseDigest(t *testing.T) {
cases := []struct {
in string
valid bool
}{
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
// too short
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
// too long
{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
// invalid prefix
{"sha255-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
{"sha255:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
{"sha256!0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
// invalid hex
{"sha256-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
{"sha256:XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
}
for _, tt := range cases {
got, err := ParseDigest(tt.in)
if tt.valid && err != nil {
t.Errorf("ParseDigest(%q) = %v, %v; want valid", tt.in, got, err)
}
want := "sha256:" + tt.in[7:]
if tt.valid && got.String() != want {
t.Errorf("ParseDigest(%q).String() = %q, want %q", tt.in, got.String(), want)
}
}
}
func TestDigestMarshalText(t *testing.T) {
const s = `"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
var d Digest
if err := json.Unmarshal([]byte(s), &d); err != nil {
t.Errorf("json.Unmarshal: %v", err)
}
out, err := json.Marshal(d)
if err != nil {
t.Errorf("json.Marshal: %v", err)
}
want := `"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
if string(out) != want {
t.Errorf("json.Marshal: got %s, want %s", out, want)
}
if err := json.Unmarshal([]byte(`"invalid"`), &Digest{}); err == nil {
t.Errorf("json.Unmarshal: expected error")
}
}

View File

@@ -0,0 +1,78 @@
package chunks
import (
"fmt"
"iter"
"strconv"
"strings"
)
type Chunk struct {
Start, End int64
}
func New(start, end int64) Chunk {
return Chunk{start, end}
}
// ParseRange parses a string in the form "unit=range" where unit is a string
// and range is a string in the form "start-end". It returns the unit and the
// range as a Chunk.
func ParseRange(s string) (unit string, _ Chunk, _ error) {
unit, r, _ := strings.Cut(s, "=")
if r == "" {
return unit, Chunk{}, nil
}
c, err := Parse(r)
if err != nil {
return "", Chunk{}, err
}
return unit, c, err
}
// Parse parses a string in the form "start-end" and returns the Chunk.
func Parse(s string) (Chunk, error) {
startStr, endStr, _ := strings.Cut(s, "-")
start, err := strconv.ParseInt(startStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid start: %v", err)
}
end, err := strconv.ParseInt(endStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid end: %v", err)
}
if start > end {
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
}
return Chunk{start, end}, nil
}
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
// the range [0, size), in order.
func Of(size, chunkSize int64) iter.Seq[Chunk] {
return func(yield func(Chunk) bool) {
for start := int64(0); start < size; start += chunkSize {
end := min(start+chunkSize-1, size-1)
if !yield(Chunk{start, end}) {
break
}
}
}
}
// Count returns the number of Chunks of size chunkSize needed to cover the
// range [0, size).
func Count(size, chunkSize int64) int64 {
return (size + chunkSize - 1) / chunkSize
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// String returns the string representation of the Chunk in the form
// "{start}-{end}".
func (c Chunk) String() string {
return fmt.Sprintf("%d-%d", c.Start, c.End)
}

View File

@@ -0,0 +1,65 @@
package chunks
import (
"slices"
"testing"
)
func TestOf(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want []Chunk
}{
{0, 1, nil},
{1, 1, []Chunk{{0, 0}}},
{1, 2, []Chunk{{0, 0}}},
{2, 1, []Chunk{{0, 0}, {1, 1}}},
{10, 9, []Chunk{{0, 8}, {9, 9}}},
}
for _, tt := range cases {
got := slices.Collect(Of(tt.total, tt.chunkSize))
if !slices.Equal(got, tt.want) {
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
}
}
}
func TestSize(t *testing.T) {
cases := []struct {
c Chunk
want int64
}{
{Chunk{0, 0}, 1},
{Chunk{0, 1}, 2},
{Chunk{3, 4}, 2},
}
for _, tt := range cases {
got := tt.c.Size()
if got != tt.want {
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
}
}
}
func TestCount(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want int64
}{
{0, 1, 0},
{1, 1, 1},
{1, 2, 1},
{2, 1, 2},
{10, 9, 2},
}
for _, tt := range cases {
got := Count(tt.total, tt.chunkSize)
if got != tt.want {
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
}
}
}

View File

@@ -0,0 +1,802 @@
// Package ollama provides a client for interacting with an Ollama registry
// which pushes and pulls model manifests and layers as defined by the
// [ollama.com/manifest].
package ollama
import (
"bufio"
"bytes"
"cmp"
"context"
"crypto"
"crypto/ed25519"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync/atomic"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names"
"github.com/ollama/ollama/server/internal/internal/syncs"
_ "embed"
)
// Errors
var (
// ErrManifestNotFound is returned when a manifest is not found in the
// cache or registry.
ErrManifestNotFound = errors.New("manifest not found")
// ErrManifestInvalid is returned when a manifest found in a local or
// remote cache is invalid.
ErrManifestInvalid = errors.New("invalid manifest")
// ErrMissingModel is returned when the model part of a name is missing
// or invalid.
ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
ErrCached = errors.New("cached")
)
// Defaults
const (
// DefaultChunkingThreshold is the threshold at which a layer should be
// split up into chunks when downloading.
DefaultChunkingThreshold = 128 << 20
// DefaultMaxChunkSize is the default maximum size of a chunk to
// download. It is configured based on benchmarks and aims to strike a
// balance between download speed and memory usage.
DefaultMaxChunkSize = 8 << 20
)
// DefaultCache returns a new disk cache for storing models. If the
// OLLAMA_MODELS environment variable is set, it uses that directory;
// otherwise, it uses $HOME/.ollama/models.
func DefaultCache() (*blob.DiskCache, error) {
dir := os.Getenv("OLLAMA_MODELS")
if dir == "" {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
dir = filepath.Join(home, ".ollama", "models")
}
return blob.Open(dir)
}
// Error is the standard error returned by Ollama APIs.
type Error struct {
Status int `json:"-"`
Code string `json:"code"`
Message string `json:"message"`
}
func (e *Error) Error() string {
return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
}
// UnmarshalJSON implements json.Unmarshaler.
func (e *Error) UnmarshalJSON(b []byte) error {
type E Error
var v struct{ Errors []E }
if err := json.Unmarshal(b, &v); err != nil {
return err
}
if len(v.Errors) == 0 {
return fmt.Errorf("no messages in error response: %s", string(b))
}
*e = Error(v.Errors[0]) // our registry only returns one error.
return nil
}
// TODO(bmizerany): make configurable on [Registry]
var defaultName = func() names.Name {
n := names.Parse("ollama.com/library/_:latest")
if !n.IsFullyQualified() {
panic("default name is not fully qualified")
}
return n
}()
// Registry is a client for performing push and pull operations against an
// Ollama registry.
type Registry struct {
// UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient.
UserAgent string
// Key is the key used to authenticate with the registry.
//
// Currently, only Ed25519 keys are supported.
Key crypto.PrivateKey
// HTTPClient is the HTTP client used to make requests to the registry.
//
// If nil, [http.DefaultClient] is used.
//
// As a quick note: If a Registry function that makes a call to a URL
// with the "https+insecure" scheme, the client will be cloned and the
// transport will be set to skip TLS verification, unless the client's
// Transport done not have a Clone method with the same signature as
// [http.Transport.Clone], which case, the call will fail.
HTTPClient *http.Client
// MaxStreams is the maximum number of concurrent streams to use when
// pushing or pulling models. If zero, the number of streams is
// determined by [runtime.GOMAXPROCS].
//
// Clients that want "unlimited" streams should set this to a large
// number.
MaxStreams int
// ChunkingThreshold is the maximum size of a layer to download in a single
// request. If zero, [DefaultChunkingThreshold] is used.
ChunkingThreshold int64
// MaxChunkSize is the maximum size of a chunk to download. If zero,
// the default is [DefaultMaxChunkSize].
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
}
// RegistryFromEnv returns a new Registry configured from the environment. The
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
// system's temporary directory.
//
// It returns an error if any configuration in the environment is invalid.
func RegistryFromEnv() (*Registry, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
if err != nil {
return nil, err
}
var rc Registry
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
if err != nil {
return nil, err
}
maxStreams := os.Getenv("OLLAMA_REGISTRY_MAXSTREAMS")
if maxStreams != "" {
var err error
rc.MaxStreams, err = strconv.Atoi(maxStreams)
if err != nil {
return nil, fmt.Errorf("invalid OLLAMA_REGISTRY_MAXSTREAMS: %w", err)
}
}
return &rc, nil
}
type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
From string
}
// parseName parses name using [names.ParseExtended] and then merges the name with the
// default name, and checks that the name is fully qualified. If a digest is
// present, it parse and returns it with the other fields as their zero values.
//
// It returns an error if the name is not fully qualified, or if the digest, if
// any, is invalid.
//
// The scheme is returned as provided by [names.ParseExtended].
func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) {
scheme, n, ds := names.ParseExtended(s)
n = names.Merge(n, defaultName)
if ds != "" {
// Digest is present. Validate it.
d, err = blob.ParseDigest(ds)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
}
// The name check is deferred until after the digest check because we
// say that digests take precedence over names, and so should there
// errors when being parsed.
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, ErrNameInvalid
}
scheme = cmp.Or(scheme, "https")
return scheme, n, d, nil
}
func (r *Registry) maxStreams() int {
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
// Large downloads require a writter stream, so ensure we have at least
// two streams to avoid a deadlock.
return max(n, 2)
}
func (r *Registry) maxChunkingThreshold() int64 {
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
}
// chunkSizeFor returns the chunk size for a layer of the given size. If the
// size is less than or equal to the max chunking threshold, the size is
// returned; otherwise, the max chunk size is returned.
func (r *Registry) maxChunkSize() int64 {
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
}
// Push pushes the model with the name in the cache to the remote registry.
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
if p == nil {
p = &PushParams{}
}
m, err := ResolveLocal(c, cmp.Or(p.From, name))
if err != nil {
return err
}
// Before much else happens, check layers at not null, the blobs exist,
// and the sizes match. This prevents long uploads followed by
// disappointment.
for _, l := range m.Layers {
if l == nil {
return fmt.Errorf("%w: null layer", ErrManifestInvalid)
}
info, err := c.Get(l.Digest)
if err != nil {
return fmt.Errorf("error getting %s: %w", l.Digest.Short(), err)
}
if info.Size != l.Size {
return fmt.Errorf("size mismatch for %s: %d != %d", l.Digest.Short(), info.Size, l.Size)
}
}
t := traceFromContext(ctx)
scheme, n, _, err := parseName(name)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
panic(err)
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var g errgroup.Group
g.SetLimit(r.maxStreams())
for _, l := range m.Layers {
var progress atomic.Int64
g.Go(func() (err error) {
defer func() { t.update(l, progress.Load(), err) }()
t.update(l, 0, nil)
startURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/uploads/?digest=%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
res, err := r.doOK(ctx, "POST", startURL, nil)
if err != nil {
return err
}
res.Body.Close()
f, err := os.Open(c.GetFile(l.Digest))
if err != nil {
return err
}
defer f.Close()
uploadURL := res.Header.Get("Location")
if uploadURL == "" {
t.update(l, l.Size, ErrCached)
return nil
}
req, err := r.newRequest(ctx, "PUT", uploadURL, f)
if err != nil {
return fmt.Errorf("invalid upload URL returned from registry: %q: %w", uploadURL, err)
}
req.ContentLength = l.Size
res, err = doOK(r.client(), req)
if err == nil {
res.Body.Close()
}
return err
})
}
if err := g.Wait(); err != nil {
return err
}
// Commit
path := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
n.Tag(),
)
res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data))
if err == nil {
res.Body.Close()
}
// TODO(bmizerany): add a "commit" trace event
return err
}
func canRetry(err error) bool {
var re *Error
if !errors.As(err, &re) {
return false
}
return re.Status >= 500
}
// Pull pulls the model with the given name from the remote registry into the
// cache.
//
// For layers larger then [Registry.MaxChunkSize], the layer is downloaded in
// chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := parseName(name)
if err != nil {
return err
}
m, err := r.Resolve(ctx, name)
if err != nil {
return err
}
if len(m.Layers) == 0 {
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}
exists := func(l *Layer) bool {
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
}
t := traceFromContext(ctx)
var g errgroup.Group
g.SetLimit(r.maxStreams())
for _, l := range m.Layers {
if exists(l) {
t.update(l, l.Size, ErrCached)
continue
}
blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
req, err := r.newRequest(ctx, "GET", blobURL, nil)
if err != nil {
t.update(l, 0, err)
continue
}
t.update(l, 0, nil)
if l.Size <= r.maxChunkingThreshold() {
g.Go(func() error {
res, err := doOK(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
err = c.Put(l.Digest, res.Body, l.Size)
if err == nil {
t.update(l, l.Size, nil)
}
return err
})
} else {
q := syncs.NewRelayReader()
g.Go(func() (err error) {
defer func() { q.CloseWithError(err) }()
return c.Put(l.Digest, q, l.Size)
})
var progress atomic.Int64
// We want to avoid extra round trips per chunk due to
// redirects from the registry to the blob store, so
// fire an initial request to get the final URL and
// then use that URL for the chunk requests.
req.Header.Set("Range", "bytes=0-0")
res, err := doOK(r.client(), req)
if err != nil {
return err
}
res.Body.Close()
req = res.Request.WithContext(req.Context())
streamNo := 0
tws := make([]*bufio.Writer, r.maxStreams()-1)
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
ticket := q.Take()
bufIdx := streamNo % len(tws)
streamNo++
g.Go(func() (err error) {
defer func() {
if err != nil {
q.CloseWithError(err)
}
ticket.Close()
t.update(l, progress.Load(), err)
}()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := func() error {
req := req.Clone(req.Context())
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := doOK(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
tw := tws[bufIdx]
if tw == nil {
tw = bufio.NewWriterSize(nil, int(r.maxChunkSize()))
tws[bufIdx] = tw
}
tw.Reset(ticket)
defer tw.Reset(nil) // release ticket
_, err = io.CopyN(tw, res.Body, chunk.Size())
if err != nil {
return maybeUnexpectedEOF(err)
}
if err := tw.Flush(); err != nil {
return err
}
total := progress.Add(chunk.Size())
if total >= l.Size {
q.Close()
}
return nil
}()
if !canRetry(err) {
return err
}
}
return nil
})
}
}
}
if err := g.Wait(); err != nil {
return err
}
// store the manifest blob
md := blob.DigestFromBytes(m.Data)
if err := blob.PutBytes(c, md, m.Data); err != nil {
return err
}
// commit the manifest with a link
return c.Link(m.Name, md)
}
// Manifest represents a [ollama.com/manifest].
type Manifest struct {
Name string `json:"-"` // the canonical name of the model
Data []byte `json:"-"` // the raw data of the manifest
Layers []*Layer `json:"layers"`
}
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
// Layer returns the layer with the given
// digest, or nil if not found.
func (m *Manifest) Layer(d blob.Digest) *Layer {
for _, l := range m.Layers {
if l.Digest == d {
return l
}
}
return nil
}
// MarshalJSON implements json.Marshaler.
//
// NOTE: It adds an empty config object to the manifest, which is required by
// the registry, but not used by the client. In the future, the config object
// will not be required by the registry and this will should be removed.
func (m Manifest) MarshalJSON() ([]byte, error) {
type M Manifest
v := struct {
M
// This is ignored, mostly, by the registry But, if not
// present, it will cause an error to be returned during the
// last phase of the commit which expects it, but does nothing
// with it. This will be fixed in a future release of
// ollama.com.
Config *Layer `json:"config"`
}{
M: M(m),
Config: &Layer{Digest: emptyDigest},
}
return json.Marshal(v)
}
// unmarshalManifest unmarshals the data into a manifest, and sets the name
// field to the string representation of the name.
//
// It panics if the name is not fully qualified. Callers should ensure the name
// is fully qualified before calling this function.
func unmarshalManifest(n names.Name, data []byte) (*Manifest, error) {
if !n.IsFullyQualified() {
panic(fmt.Sprintf("unmarshalManifest: name is not fully qualified: %s", n.String()))
}
var m Manifest
if err := json.Unmarshal(data, &m); err != nil {
return nil, err
}
m.Name = n.String()
m.Data = data
return &m, nil
}
// Layer is a layer in a model.
type Layer struct {
Digest blob.Digest `json:"digest"`
MediaType string `json:"mediaType"`
Size int64 `json:"size"`
}
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
// parsed using [names.ParseExtended] but the scheme is ignored.
func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name)
if err != nil {
return nil, err
}
if !d.IsValid() {
d, err = c.Resolve(n.String())
if err != nil {
return nil, err
}
}
data, err := os.ReadFile(c.GetFile(d))
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name)
}
return nil, err
}
m, err := unmarshalManifest(n, data)
if err != nil {
return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
}
return m, nil
}
// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := parseName(name)
if err != nil {
return nil, err
}
manifestURL := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", scheme, n.Host(), n.Namespace(), n.Model(), n.Tag())
if d.IsValid() {
manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
}
res, err := r.doOK(ctx, "GET", manifestURL, nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
data, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
// TODO(bmizerany): return digest here
m, err := unmarshalManifest(n, data)
if err != nil {
return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
}
return m, nil
}
func (r *Registry) client() *http.Client {
if r.HTTPClient != nil {
return r.HTTPClient
}
return http.DefaultClient
}
// newRequest constructs a new request, ready to use, with the given method,
// url, and body, presigned with client Key and UserAgent.
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}
if r.UserAgent != "" {
req.Header.Set("User-Agent", r.UserAgent)
}
if r.Key != nil {
token, err := makeAuthToken(r.Key)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)
}
return req, nil
}
// doOK makes a request with the given client and request, and returns the
// response if the status code is 200. If the status code is not 200, an Error
// is parsed from the response body and returned. If any other error occurs, it
// is returned.
func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
if r.URL.Scheme == "https+insecure" {
// TODO(bmizerany): clone client.Transport, set
// InsecureSkipVerify, etc.
type cloner interface {
Clone() *http.Transport
}
// Attempt to configure the transport to skip TLS verification
// if we can clone it, otherwise fall through and let the http
// client complain and the scheme being invalid.
x, ok := cmp.Or(c.Transport, http.DefaultTransport).(cloner)
if ok {
tr := x.Clone()
tr.TLSClientConfig = cmp.Or(tr.TLSClientConfig, &tls.Config{})
tr.TLSClientConfig.InsecureSkipVerify = true
cc := *c // shallow copy
cc.Transport = tr
c = &cc
r = r.Clone(r.Context())
r.URL.Scheme = "https"
// fall through
}
}
res, err := c.Do(r)
if err != nil {
return nil, err
}
if res.StatusCode/100 != 2 {
out, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var re Error
if err := json.Unmarshal(out, &re); err != nil {
// Use the raw body if we can't parse it as an error object.
re.Message = string(out)
}
re.Status = res.StatusCode
return nil, &re
}
return res, nil
}
// doOK is a convenience method for making a request with newRequest and
// passing it to doOK with r.client().
func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
req, err := r.newRequest(ctx, method, path, body)
if err != nil {
return nil, err
}
return doOK(r.client(), req)
}
// makeAuthToken creates an Ollama auth token for the given private key.
//
// NOTE: This format is OLD, overly complex, and should be replaced. We're
// inheriting it from the original Ollama client and ollama.com
// implementations, so we need to support it for now.
func makeAuthToken(key crypto.PrivateKey) (string, error) {
privKey, _ := key.(*ed25519.PrivateKey)
if privKey == nil {
return "", fmt.Errorf("unsupported private key type: %T", key)
}
url := fmt.Sprintf("https://ollama.com?ts=%d", time.Now().Unix())
// Part 1: the checkData (e.g. the URL with a timestamp)
// Part 2: the public key
pubKeyShort, err := func() ([]byte, error) {
sshPubKey, err := ssh.NewPublicKey(privKey.Public())
if err != nil {
return nil, err
}
pubKeyParts := bytes.Fields(ssh.MarshalAuthorizedKey(sshPubKey))
if len(pubKeyParts) < 2 {
return nil, fmt.Errorf("malformed public key: %q", pubKeyParts)
}
pubKeyShort := pubKeyParts[1]
return pubKeyShort, nil
}()
if err != nil {
return "", err
}
// Part 3: the signature
sig := ed25519.Sign(*privKey, []byte(checkData(url)))
// Assemble the token: <checkData>:<pubKey>:<signature>
var b strings.Builder
io.WriteString(&b, base64.StdEncoding.EncodeToString([]byte(url)))
b.WriteByte(':')
b.Write(pubKeyShort)
b.WriteByte(':')
io.WriteString(&b, base64.StdEncoding.EncodeToString(sig))
return b.String(), nil
}
// The original spec for Ollama tokens was to use the SHA256 of the zero
// string as part of the signature. I'm not sure why that was, but we still
// need it to verify the signature.
var zeroSum = func() string {
sha256sum := sha256.Sum256(nil)
x := base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))
return x
}()
// checkData takes a URL and creates the original string format of the
// data signature that is used by the ollama client to sign requests
func checkData(url string) string {
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
}
func maybeUnexpectedEOF(err error) error {
if errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}

View File

@@ -0,0 +1,656 @@
package ollama
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"math/rand/v2"
"net/http"
"net/http/httptest"
"os"
"path"
"reflect"
"slices"
"strings"
"testing"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/testutil"
)
func TestManifestMarshalJSON(t *testing.T) {
// All manifests should contain an "empty" config object.
var m Manifest
data, err := json.Marshal(m)
if err != nil {
t.Fatal(err)
}
if !bytes.Contains(data, []byte(`"config":{"digest":"sha256:`)) {
t.Error("expected manifest to contain empty config")
t.Fatalf("got:\n%s", string(data))
}
}
func link(c *blob.DiskCache, name string, manifest string) {
_, n, _, err := parseName(name)
if err != nil {
panic(err)
}
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
if err != nil {
panic(err)
}
if err := c.Link(n.String(), d); err != nil {
panic(err)
}
}
var errRoundTrip = errors.New("forced roundtrip error")
type recordRoundTripper http.HandlerFunc
func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
w := httptest.NewRecorder()
rr(w, req)
if w.Code == 499 {
return nil, errRoundTrip
}
resp := w.Result()
// For some reason, Response.Request is not set by httptest.NewRecorder, so we
// set it manually.
resp.Request = req
return w.Result(), nil
}
// newClient constructs a cache with predefined manifests for testing. The manifests are:
//
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
//
// Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()
c, err := blob.Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
mklayer := func(data string) *Layer {
return &Layer{
Digest: importBytes(t, c, data),
Size: int64(len(data)),
}
}
commit := func(name string, layers ...*Layer) {
t.Helper()
data, err := json.Marshal(&Manifest{Layers: layers})
if err != nil {
t.Fatal(err)
}
link(c, name, string(data))
}
link(c, "empty", "")
commit("zero")
commit("single", mklayer("exists"))
commit("multiple", mklayer("exists"), mklayer("present"))
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
commit("null", nil)
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
link(c, "invalid", "!!!!!")
rc := &Registry{
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
return rc, c
}
func okHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func checkErrCode(t *testing.T, err error, status int, code string) {
t.Helper()
var e *Error
if !errors.As(err, &e) || e.Status != status || e.Code != code {
t.Errorf("err = %v; want %v %v", err, status, code)
}
}
func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
d, err := c.Import(strings.NewReader(data), int64(len(data)))
if err != nil {
t.Fatal(err)
}
return d
}
func TestRegistryPushInvalidNames(t *testing.T) {
rc, c := newClient(t, nil)
cases := []struct {
name string
err error
}{
{"", ErrNameInvalid},
{"@", ErrNameInvalid},
{"@x", blob.ErrInvalidDigest},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
// Create a new registry and push a new image.
err := rc.Push(t.Context(), c, tt.name, nil)
if !errors.Is(err, tt.err) {
t.Errorf("err = %v; want %v", err, tt.err)
}
})
}
}
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
return WithTrace(ctx, t), t
}
func TestPushZero(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "empty", nil)
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
}
}
func TestPushSingle(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "single", nil)
testutil.Check(t, err)
}
func TestPushMultiple(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "multiple", nil)
testutil.Check(t, err)
}
func TestPushNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("unexpected request: %v", r)
})
err := rc.Push(t.Context(), c, "notfound", nil)
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
}
}
func TestPushNullLayer(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "null", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
}
func TestPushSizeMismatch(t *testing.T) {
rc, c := newClient(t, nil)
ctx, _ := withTraceUnexpected(t.Context())
got := rc.Push(ctx, c, "sizemismatch", nil)
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
t.Errorf("err = %v; want size mismatch", got)
}
}
func TestPushInvalid(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "invalid", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
}
func TestPushExistsAtRemote(t *testing.T) {
var pushed bool
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/uploads/") {
if !pushed {
// First push. Return an uploadURL.
pushed = true
w.Header().Set("Location", "http://blob.store/blobs/123")
return
}
w.WriteHeader(http.StatusAccepted)
return
}
io.Copy(io.Discard, r.Body)
w.WriteHeader(http.StatusOK)
})
rc.MaxStreams = 1 // prevent concurrent uploads
var errs []error
ctx := WithTrace(t.Context(), &Trace{
Update: func(_ *Layer, n int64, err error) {
// uploading one at a time so no need to lock
errs = append(errs, err)
},
})
check := testutil.Checker(t)
err := rc.Push(ctx, c, "single", nil)
check(err)
if !errors.Is(errors.Join(errs...), nil) {
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
}
err = rc.Push(ctx, c, "single", nil)
check(err)
}
func TestPushRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(500)
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
return
}
})
got := rc.Push(t.Context(), c, "single", nil)
checkErrCode(t, got, 500, "blob_error")
}
func TestPushLocationError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", ":///x")
w.WriteHeader(http.StatusAccepted)
})
got := rc.Push(t.Context(), c, "single", nil)
wantContains := "invalid upload URL"
if got == nil || !strings.Contains(got.Error(), wantContains) {
t.Errorf("err = %v; want to contain %v", got, wantContains)
}
}
func TestPushUploadRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Host == "blob.store" {
w.WriteHeader(499) // force RoundTrip error on upload
return
}
w.Header().Set("Location", "http://blob.store/blobs/123")
})
got := rc.Push(t.Context(), c, "single", nil)
if !errors.Is(got, errRoundTrip) {
t.Errorf("got = %v; want %v", got, errRoundTrip)
}
}
func TestPushUploadFileOpenError(t *testing.T) {
rc, c := newClient(t, okHandler)
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, _ int64, err error) {
// Remove the file just before it is opened for upload,
// but after the initial Stat that happens before the
// upload starts
os.Remove(c.GetFile(l.Digest))
},
})
got := rc.Push(ctx, c, "single", nil)
if !errors.Is(got, fs.ErrNotExist) {
t.Errorf("got = %v; want fs.ErrNotExist", got)
}
}
func TestPushCommitRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
panic("unexpected")
}
w.WriteHeader(499) // force RoundTrip error
})
err := rc.Push(t.Context(), c, "zero", nil)
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
}
func checkNotExist(t *testing.T, err error) {
t.Helper()
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v; want fs.ErrNotExist", err)
}
}
func TestRegistryPullInvalidName(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Pull(t.Context(), c, "://")
if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
}
}
func TestRegistryPullInvalidManifest(t *testing.T) {
cases := []string{
"",
"null",
"!!!",
`{"layers":[]}`,
}
for _, resp := range cases {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp)
})
err := rc.Pull(t.Context(), c, "x")
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err)
}
}
}
func TestRegistryPullNotCached(t *testing.T) {
check := testutil.Checker(t)
var c *blob.DiskCache
var rc *Registry
d := blob.DigestFromBytes("some data")
rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
io.WriteString(w, "some data")
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
})
// Confirm that the layer does not exist locally
_, err := ResolveLocal(c, "model")
checkNotExist(t, err)
_, err = c.Get(d)
checkNotExist(t, err)
err = rc.Pull(t.Context(), c, "model")
check(err)
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := ResolveLocal(c, "model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
// Confirm successful download
info, err := c.Get(d)
check(err)
if info.Digest != d {
t.Errorf("info.Digest = %v; want %v", info.Digest, d)
}
if info.Size != 9 {
t.Errorf("info.Size = %v; want %v", info.Size, 9)
}
data, err := os.ReadFile(c.GetFile(d))
check(err)
if string(data) != "some data" {
t.Errorf("data = %q; want %q", data, "exists")
}
}
func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists")
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called
return
}
if strings.Contains(r.URL.Path, "/manifests/") {
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
}
})
var errs []error
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
t.Logf("update %v %d %v", d, n, err)
reads = append(reads, n)
errs = append(errs, err)
},
})
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
err := rc.Pull(ctx, c, "single")
testutil.Check(t, err)
want := []int64{6}
if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached)
}
if !slices.Equal(reads, want) {
t.Errorf("pairs = %v; want %v", reads, want)
}
}
func TestRegistryPullManifestNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
err := rc.Pull(t.Context(), c, "notfound")
checkErrCode(t, err, 404, "")
}
func TestRegistryPullResolveRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
})
err := rc.Pull(t.Context(), c, "single")
checkErrCode(t, err, 500, "an_error")
}
func TestRegistryPullResolveRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error
return
}
})
err := rc.Pull(t.Context(), c, "single")
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
}
// TestRegistryPullMixedCachedNotCached tests that cached layers do not
// interfere with pulling layers that are not cached
func TestRegistryPullMixedCachedNotCached(t *testing.T) {
x := blob.DigestFromBytes("xxxxxx")
e := blob.DigestFromBytes("exists")
y := blob.DigestFromBytes("yyyyyy")
for i := range 10 {
t.Logf("iteration %d", i)
digests := []blob.Digest{x, e, y}
rand.Shuffle(len(digests), func(i, j int) {
digests[i], digests[j] = digests[j], digests[i]
})
manifest := fmt.Sprintf(`{
"layers": [
{"digest":"%s","size":6},
{"digest":"%s","size":6},
{"digest":"%s","size":6}
]
}`, digests[0], digests[1], digests[2])
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch path.Base(r.URL.Path) {
case "latest":
io.WriteString(w, manifest)
case x.String():
io.WriteString(w, "xxxxxx")
case e.String():
io.WriteString(w, "exists")
case y.String():
io.WriteString(w, "yyyyyy")
default:
panic(fmt.Sprintf("unexpected request: %v", r))
}
})
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("update %v %d %v", l, n, err)
},
})
// Check that we pull all layers that we can.
err := rc.Pull(ctx, c, "mixed")
if err != nil {
t.Fatal(err)
}
for _, d := range digests {
info, err := c.Get(d)
if err != nil {
t.Fatalf("Get(%v): %v", d, err)
}
if info.Size != 6 {
t.Errorf("info.Size = %v; want %v", info.Size, 6)
}
}
}
}
func TestRegistryPullChunking(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
return
}
if strings.Contains(r.URL.Path, "/blobs/") {
rng := r.Header.Get("Range")
if rng == "" {
http.Error(w, "missing range", http.StatusBadRequest)
return
}
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
if err != nil {
panic(err)
}
io.WriteString(w, "remote"[c.Start:c.End+1])
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
})
// Force chunking by setting the threshold to less than the size of the
// layer.
rc.ChunkingThreshold = 3
rc.MaxChunkSize = 3
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
if err != nil {
t.Errorf("update %v %d %v", d, n, err)
}
reads = append(reads, n)
},
})
err := rc.Pull(ctx, c, "remote")
testutil.Check(t, err)
want := []int64{0, 3, 6}
if !slices.Equal(reads, want) {
t.Errorf("reads = %v; want %v", reads, want)
}
}
func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t)
exists := blob.DigestFromBytes("exists")
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() {
w.WriteHeader(499) // should not hit manifest endpoint
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
})
_, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String())
check(err)
}
func TestInsecureSkipVerify(t *testing.T) {
exists := blob.DigestFromBytes("exists")
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
}))
defer s.Close()
const name = "ollama.com/library/insecure"
var rc Registry
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
_, err := rc.Resolve(t.Context(), url)
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
t.Errorf("err = %v; want cert verifiction failure", err)
}
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
_, err = rc.Resolve(t.Context(), url)
testutil.Check(t, err)
}
func TestCanRetry(t *testing.T) {
cases := []struct {
err error
want bool
}{
{nil, false},
{errors.New("x"), false},
{ErrCached, false},
{ErrManifestInvalid, false},
{ErrNameInvalid, false},
{&Error{Status: 100}, false},
{&Error{Status: 500}, true},
}
for _, tt := range cases {
if got := canRetry(tt.err); got != tt.want {
t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
}
}
}

View File

@@ -0,0 +1,48 @@
package ollama
import (
"context"
)
// Trace is a set of functions that are called to report progress during blob
// downloads and uploads.
type Trace struct {
// Update is called during [Registry.Push] and [Registry.Pull] to
// report the progress of blob uploads and downloads.
//
// It is called once at the beginning of the download with a zero n and
// then once per read operation with the number of bytes read so far,
// and an error if any.
//
// A function assigned must be safe for concurrent use. The function is
// called synchronously and so should not block or take long to run.
Update func(_ *Layer, n int64, _ error)
}
func (t *Trace) update(l *Layer, n int64, err error) {
if t.Update != nil {
t.Update(l, n, err)
}
}
type traceKey struct{}
// WithTrace returns a context derived from ctx that uses t to report trace
// events.
func WithTrace(ctx context.Context, t *Trace) context.Context {
return context.WithValue(ctx, traceKey{}, t)
}
var emptyTrace = &Trace{}
// traceFromContext returns the Trace associated with ctx, or an empty Trace if
// none is found.
//
// It never returns nil.
func traceFromContext(ctx context.Context) *Trace {
t, _ := ctx.Value(traceKey{}).(*Trace)
if t == nil {
return emptyTrace
}
return t
}

View File

@@ -0,0 +1,220 @@
// safetensors provides a reader for the safetensor directories and files.
package safetensors
import (
"encoding/json"
"fmt"
"io"
"io/fs"
"iter"
"slices"
"strconv"
"strings"
)
// Tensor represents a single tensor in a safetensors file.
//
// It's zero value is not valid. Use [Model.Tensors] to get valid tensors.
//
// It is not safe for use across multiple goroutines.
type Tensor struct {
name string
dataType string
shape []int64
fsys fs.FS
fname string // entry name in fsys
offset int64
size int64
}
type Model struct {
fsys fs.FS
}
func Read(fsys fs.FS) (*Model, error) {
return &Model{fsys: fsys}, nil
}
func (m *Model) Tensors() iter.Seq2[*Tensor, error] {
return func(yield func(*Tensor, error) bool) {
entries, err := fs.Glob(m.fsys, "*.safetensors")
if err != nil {
yield(nil, err)
return
}
for _, e := range entries {
tt, err := m.readTensors(e)
if err != nil {
yield(nil, err)
return
}
for _, t := range tt {
if !yield(t, nil) {
return
}
}
}
}
}
func (m *Model) readTensors(fname string) ([]*Tensor, error) {
f, err := m.fsys.Open(fname)
if err != nil {
return nil, err
}
defer f.Close()
finfo, err := f.Stat()
if err != nil {
return nil, err
}
headerSize, err := readInt64(f)
if err != nil {
return nil, err
}
data := make([]byte, headerSize)
_, err = io.ReadFull(f, data)
if err != nil {
return nil, err
}
var raws map[string]json.RawMessage
if err := json.Unmarshal(data, &raws); err != nil {
return nil, err
}
// TODO(bmizerany): do something with metadata? This could be another
// header read if needed. We also need to figure out if the metadata is
// present in only one .safetensors file or if each file may have their
// own and if it needs to follow each tensor. Currently, I (bmizerany)
// am only seeing them show up with one entry for file type which is
// always "pt".
tt := make([]*Tensor, 0, len(raws))
for name, raw := range raws {
if !strings.HasPrefix(name, "model.layer") {
continue
}
var v struct {
DataType string `json:"dtype"`
Shape []int64 `json:"shape"`
Offsets []int64 `json:"data_offsets"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err)
}
if len(v.Offsets) != 2 {
return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets)
}
// TODO(bmizerany): after collecting, validate all offests make
// tensors contiguous?
begin, end := v.Offsets[0], v.Offsets[1]
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
return nil, err
}
// TODO(bmizerany): just yield.. don't be silly and make a slice :)
tt = append(tt, &Tensor{
name: name,
dataType: v.DataType,
shape: v.Shape,
fsys: m.fsys,
fname: fname,
offset: begin,
size: end - begin,
})
}
return tt, nil
}
func checkBeginEnd(size, begin, end int64) error {
if begin < 0 {
return fmt.Errorf("begin must not be negative: %d", begin)
}
if end < 0 {
return fmt.Errorf("end must not be negative: %d", end)
}
if end < begin {
return fmt.Errorf("end must be >= begin: %d < %d", end, begin)
}
if end > size {
return fmt.Errorf("end must be <= size: %d > %d", end, size)
}
return nil
}
func readInt64(r io.Reader) (int64, error) {
var v uint64
var buf [8]byte
if _, err := io.ReadFull(r, buf[:]); err != nil {
return 0, err
}
for i := range buf {
v |= uint64(buf[i]) << (8 * i)
}
return int64(v), nil
}
type Shape []int64
func (s Shape) String() string {
var b strings.Builder
b.WriteByte('[')
for i, v := range s {
if i > 0 {
b.WriteByte(',')
}
b.WriteString(strconv.FormatInt(v, 10))
}
b.WriteByte(']')
return b.String()
}
func (t *Tensor) Name() string { return t.name }
func (t *Tensor) DataType() string { return t.dataType }
func (t *Tensor) Size() int64 { return t.size }
func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) }
func (t *Tensor) Reader() (io.ReadCloser, error) {
f, err := t.fsys.Open(t.fname)
if err != nil {
return nil, err
}
r := newSectionReader(f, t.offset, t.size)
rc := struct {
io.Reader
io.Closer
}{r, f}
return rc, nil
}
// newSectionReader returns a new io.Reader that reads from r starting at
// offset. It is a convenience function for creating a io.SectionReader when r
// may not be an io.ReaderAt.
//
// If r is already a ReaderAt, it is returned directly, otherwise if r is an
// io.Seeker, a new io.ReaderAt is returned that wraps r after seeking to the
// beginning of the file.
//
// If r is an io.Seeker,
// or slow path. The slow path is used when r does not implement io.ReaderAt,
// in which case it must discard the data it reads.
func newSectionReader(r io.Reader, offset, n int64) io.Reader {
if r, ok := r.(io.ReaderAt); ok {
return io.NewSectionReader(r, offset, n)
}
if r, ok := r.(io.ReadSeeker); ok {
r.Seek(offset, io.SeekStart)
return io.LimitReader(r, n)
}
// Discard to offset and return a limited reader.
_, err := io.CopyN(io.Discard, r, offset)
if err != nil {
return nil
}
return io.LimitReader(r, n)
}

View File

@@ -0,0 +1,366 @@
package main
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"mime"
"net/http"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
"golang.org/x/sync/errgroup"
)
var stdout io.Writer = os.Stdout
const usage = `Opp is a tool for pushing and pulling Ollama models.
Usage:
opp [flags] <push|pull|import>
Commands:
push Upload a model to the Ollama server.
pull Download a model from the Ollama server.
import Import a model from a local safetensor directory.
Examples:
# Pull a model from the Ollama server.
opp pull library/llama3.2:latest
# Push a model to the Ollama server.
opp push username/my_model:8b
# Import a model from a local safetensor directory.
opp import /path/to/safetensor
Envionment Variables:
OLLAMA_MODELS
The directory where models are pushed and pulled from
(default ~/.ollama/models).
`
func main() {
flag.Usage = func() {
fmt.Fprint(os.Stderr, usage)
}
flag.Parse()
c, err := ollama.DefaultCache()
if err != nil {
log.Fatal(err)
}
rc, err := ollama.RegistryFromEnv()
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
err = func() error {
switch cmd := flag.Arg(0); cmd {
case "pull":
return cmdPull(ctx, rc, c)
case "push":
return cmdPush(ctx, rc, c)
case "import":
return cmdImport(ctx, c)
default:
if cmd == "" {
flag.Usage()
} else {
fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
}
os.Exit(2)
return errors.New("unreachable")
}
}()
if err != nil {
fmt.Fprintf(os.Stderr, "opp: %v\n", err)
os.Exit(1)
}
}
func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
model := flag.Arg(1)
if model == "" {
flag.Usage()
os.Exit(1)
}
tr := http.DefaultTransport.(*http.Transport).Clone()
// TODO(bmizerany): configure transport?
rc.HTTPClient = &http.Client{Transport: tr}
var mu sync.Mutex
p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]
var pb bytes.Buffer
printProgress := func() {
pb.Reset()
mu.Lock()
for d, s := range p {
// Write progress to a buffer first to avoid blocking
// on stdout while holding the lock.
stamp := time.Now().Format("2006/01/02 15:04:05")
fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
if s[0] == s[1] {
delete(p, d)
}
}
mu.Unlock()
io.Copy(stdout, &pb)
}
ctx = ollama.WithTrace(ctx, &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if err != nil && !errors.Is(err, ollama.ErrCached) {
fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
return
}
mu.Lock()
p[l.Digest] = [2]int64{l.Size, n}
mu.Unlock()
},
})
errc := make(chan error)
go func() {
errc <- rc.Pull(ctx, c, model)
}()
t := time.NewTicker(time.Second)
defer t.Stop()
for {
select {
case <-t.C:
printProgress()
case err := <-errc:
printProgress()
return err
}
}
}
func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("push", flag.ExitOnError)
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
flag.PrintDefaults()
}
flag.Parse(args)
model := flag.Arg(0)
if model == "" {
return fmt.Errorf("missing model argument")
}
from := cmp.Or(*flagFrom, model)
m, err := ollama.ResolveLocal(c, from)
if err != nil {
return err
}
ctx = ollama.WithTrace(ctx, &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
switch {
case errors.Is(err, ollama.ErrCached):
fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n)
case err != nil:
fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
case n == 0:
l := m.Layer(l.Digest)
mt, p, _ := mime.ParseMediaType(l.MediaType)
mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
switch mt {
case "tensor":
fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
default:
fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
}
}
},
})
return rc.Push(ctx, c, model, &ollama.PushParams{
From: from,
})
}
type trackingReader struct {
io.Reader
n *atomic.Int64
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.Reader.Read(p)
r.n.Add(int64(n))
return n, err
}
func cmdImport(ctx context.Context, c *blob.DiskCache) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("import", flag.ExitOnError)
flagAs := flag.String("as", "", "Import using the provided name.")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
flag.PrintDefaults()
}
flag.Parse(args)
dir := cmp.Or(flag.Arg(0), ".")
fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
m, err := safetensors.Read(os.DirFS(dir))
if err != nil {
return err
}
var total int64
var tt []*safetensors.Tensor
for t, err := range m.Tensors() {
if err != nil {
return err
}
tt = append(tt, t)
total += t.Size()
}
var n atomic.Int64
done := make(chan error)
go func() {
layers := make([]*ollama.Layer, len(tt))
var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
var ctxErr error
for i, t := range tt {
if ctx.Err() != nil {
// The context may cancel AFTER we exit the
// loop, and so if we use ctx.Err() after the
// loop we may report it as the error that
// broke the loop, when it was not. This can
// manifest as a false-negative, leading the
// user to think their import failed when it
// did not, so capture it if and only if we
// exit the loop because of a ctx.Err() and
// report it.
ctxErr = ctx.Err()
break
}
g.Go(func() (err error) {
rc, err := t.Reader()
if err != nil {
return err
}
defer rc.Close()
tr := &trackingReader{rc, &n}
d, err := c.Import(tr, t.Size())
if err != nil {
return err
}
if err := rc.Close(); err != nil {
return err
}
layers[i] = &ollama.Layer{
Digest: d,
Size: t.Size(),
MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
"name": t.Name(),
"dtype": t.DataType(),
"shape": t.Shape().String(),
}),
}
return nil
})
}
done <- func() error {
if err := errors.Join(g.Wait(), ctxErr); err != nil {
return err
}
m := &ollama.Manifest{Layers: layers}
data, err := json.MarshalIndent(m, "", " ")
if err != nil {
return err
}
d := blob.DigestFromBytes(data)
err = blob.PutBytes(c, d, data)
if err != nil {
return err
}
return c.Link(*flagAs, d)
}()
}()
fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)
csiHideCursor(stdout)
defer csiShowCursor(stdout)
csiSavePos(stdout)
writeProgress := func() {
csiRestorePos(stdout)
nn := n.Load()
fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
formatNatural(nn),
formatNatural(total),
nn*100/total,
ansiClearToEndOfLine,
)
}
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
writeProgress()
case err := <-done:
writeProgress()
return err
}
}
}
func formatNatural(n int64) string {
switch {
case n < 1024:
return fmt.Sprintf("%d B", n)
case n < 1024*1024:
return fmt.Sprintf("%.1f KB", float64(n)/1024)
case n < 1024*1024*1024:
return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024))
default:
return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024))
}
}
const ansiClearToEndOfLine = "\033[K"
func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") }
func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }
func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }
func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }

View File

@@ -0,0 +1,11 @@
package main
import (
"fmt"
"os"
)
func main() {
fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
os.Exit(1)
}

View File

@@ -0,0 +1,107 @@
package main
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/ollama/ollama/server/internal/chunks"
"golang.org/x/sync/errgroup"
)
func BenchmarkDownload(b *testing.B) {
run := func(fileSize, chunkSize int64) {
name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
}
run(100<<20, 8<<20)
run(100<<20, 16<<20)
run(100<<20, 32<<20)
run(100<<20, 64<<20)
run(100<<20, 128<<20) // 1 chunk
}
func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := c.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
return err
}
var sleepTime atomic.Int64
func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
client := &http.Client{
Transport: func() http.RoundTripper {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DisableKeepAlives = true
return tr
}(),
}
defer client.CloseIdleConnections()
// warm up the client
run(context.Background(), client, chunks.New(0, 1<<20))
b.SetBytes(fileSize)
b.ReportAllocs()
// Give our CDN a min to breathe between benchmarks.
time.Sleep(time.Duration(sleepTime.Swap(3)))
for b.Loop() {
g, ctx := errgroup.WithContext(b.Context())
g.SetLimit(runtime.GOMAXPROCS(0))
for chunk := range chunks.Of(fileSize, chunkSize) {
g.Go(func() error { return run(ctx, client, chunk) })
}
if err := g.Wait(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkWrite(b *testing.B) {
b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
}
func benchmarkWrite(b *testing.B, chunkSize int) {
b.ReportAllocs()
dir := b.TempDir()
f, err := os.Create(filepath.Join(dir, "write-single"))
if err != nil {
b.Fatal(err)
}
defer f.Close()
data := make([]byte, chunkSize)
b.SetBytes(int64(chunkSize))
r := bytes.NewReader(data)
for b.Loop() {
r.Reset(data)
_, err := io.Copy(f, r)
if err != nil {
b.Fatal(err)
}
}
}

View File

@@ -0,0 +1,48 @@
package backoff
import (
"context"
"iter"
"math/rand/v2"
"time"
)
func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
var n int
return func(yield func(int, error) bool) {
var t *time.Timer
for {
if ctx.Err() != nil {
yield(n, ctx.Err())
return
}
if !yield(n, nil) {
return
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := time.Duration(n*n) * 10 * time.Millisecond
if d > maxBackoff {
d = maxBackoff
}
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
if t == nil {
t = time.NewTimer(d)
} else {
t.Reset(d)
}
select {
case <-ctx.Done():
t.Stop()
case <-t.C:
}
}
}
}

View File

@@ -0,0 +1,40 @@
//go:build goexperiment.synctest
package backoff
import (
"context"
"errors"
"testing"
"testing/synctest"
"time"
)
func TestLoop(t *testing.T) {
synctest.Run(func() {
last := -1
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
for n, err := range Loop(ctx, 100*time.Millisecond) {
if !errors.Is(err, ctx.Err()) {
t.Errorf("err = %v, want nil", err)
}
if err != nil {
break
}
if n != last+1 {
t.Errorf("n = %d, want %d", n, last+1)
}
last = n
if n > 5 {
cancel()
}
}
if last != 6 {
t.Errorf("last = %d, want 6", last)
}
})
}

View File

@@ -0,0 +1,38 @@
package backoff
import (
"context"
"testing"
"testing/synctest"
"time"
)
func TestLoopAllocs(t *testing.T) {
for i := range 3 {
got := testing.AllocsPerRun(1000, func() {
for tick := range Loop(t.Context(), 1) {
if tick >= i {
break
}
}
})
want := float64(0)
if i > 0 {
want = 3 // due to time.NewTimer
}
if got > want {
t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
}
}
}
func BenchmarkLoop(b *testing.B) {
ctx := context.Background()
synctest.Run(func() {
for n := range Loop(ctx, 100*time.Millisecond) {
if n == b.N {
break
}
}
})
}

View File

@@ -0,0 +1,229 @@
package names
import (
"cmp"
"fmt"
"strings"
"github.com/ollama/ollama/server/internal/internal/stringsx"
)
const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag>
type Name struct {
// Make incomparable to enfoce use of Compare / Equal for
// case-insensitive comparisons.
_ [0]func()
h string
n string
m string
t string
}
// Parse parses and assembles a Name from a name string. The
// format of a valid name string is:
//
// s:
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
// { host } "/" { namespace } "/" { model } ":" { tag }
// { host } "/" { namespace } "/" { model } "@" { digest }
// { host } "/" { namespace } "/" { model }
// { namespace } "/" { model } ":" { tag } "@" { digest }
// { namespace } "/" { model } ":" { tag }
// { namespace } "/" { model } "@" { digest }
// { namespace } "/" { model }
// { model } ":" { tag } "@" { digest }
// { model } ":" { tag }
// { model } "@" { digest }
// { model }
// "@" { digest }
// host:
// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
// length: [1, 350]
// namespace:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
// length: [1, 80]
// model:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// tag:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// digest:
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
// length: [1, 80]
//
// The name returned is not guaranteed to be valid. If it is not valid, the
// field values are left in an undefined state. Use [Name.IsValid] to check
// if the name is valid.
func Parse(s string) Name {
if len(s) > MaxNameLength {
return Name{}
}
var n Name
var tail string
var c byte
for {
s, tail, c = cutLastAny(s, "/:")
switch c {
case ':':
n.t = tail
continue // look for model
case '/':
n.h, n.n, _ = cutLastAny(s, "/")
n.m = tail
return n
case 0:
n.m = tail
return n
}
}
}
// ParseExtended parses and returns any scheme, Name, and digest from from s in
// the the form [scheme://][name][@digest]. All parts are optional.
//
// If the scheme is present, it must be followed by "://". The digest is
// prefixed by "@" and comes after the name. The name is parsed using [Parse].
//
// The scheme and digest are stripped before the name is parsed by [Parse].
//
// For convience, the scheme is never empty. If the scheme is not present, the
// returned scheme is "https".
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
func ParseExtended(s string) (scheme string, _ Name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
s = s[i+3:]
}
i = strings.LastIndex(s, "@")
if i >= 0 {
digest = s[i+1:]
s = s[:i]
}
return scheme, Parse(s), digest
}
func FormatExtended(scheme string, n Name, digest string) string {
var b strings.Builder
if scheme != "" {
b.WriteString(scheme)
b.WriteString("://")
}
b.WriteString(n.String())
if digest != "" {
b.WriteByte('@')
b.WriteString(digest)
}
return b.String()
}
// Merge merges two names into a single name. Non-empty host, namespace, and
// tag parts of a take precedence over fields in b. The model field is left as
// is.
//
// The returned name is not guaranteed to be valid. Use [Name.IsValid] to check
// if the name is valid.
func Merge(a, b Name) Name {
a.h = cmp.Or(a.h, b.h)
a.n = cmp.Or(a.n, b.n)
a.t = cmp.Or(a.t, b.t)
return a
}
// IsValid returns true if the name is valid.
func (n Name) IsValid() bool {
if n.h != "" && !isValidHost(n.h) {
return false
}
if n.n != "" && !isValidNamespace(n.n) {
return false
}
if n.m != "" && !isValidModel(n.m) {
return false
}
if n.t != "" && !isValidTag(n.t) {
return false
}
return true
}
func (n Name) IsFullyQualified() bool {
return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
}
func isValidHost(_ string) bool {
return true // TODO: implement
}
func isValidNamespace(_ string) bool {
return true // TODO: implement
}
func isValidModel(_ string) bool {
return true // TODO: implement
}
func isValidTag(_ string) bool {
return true // TODO: implement
}
func (n Name) Host() string { return n.h }
func (n Name) Namespace() string { return n.n }
func (n Name) Model() string { return n.m }
func (n Name) Tag() string { return n.t }
// Compare compares n and o case-insensitively. It returns 0 if n and o are
// equal, -1 if n sorts before o, and 1 if n sorts after o.
func (n Name) Compare(o Name) int {
return cmp.Or(
stringsx.CompareFold(n.h, o.h),
stringsx.CompareFold(n.n, o.n),
stringsx.CompareFold(n.m, o.m),
stringsx.CompareFold(n.t, o.t),
)
}
// String returns the fully qualified name in the format
// <namespace>/<model>:<tag>.
func (n Name) String() string {
var b strings.Builder
if n.h != "" {
b.WriteString(n.h)
b.WriteByte('/')
}
if n.n != "" {
b.WriteString(n.n)
b.WriteByte('/')
}
b.WriteString(n.m)
if n.t != "" {
b.WriteByte(':')
b.WriteString(n.t)
}
return b.String()
}
func (n Name) GoString() string {
return fmt.Sprintf("<Name %q %q %q %q>", n.h, n.n, n.m, n.t)
}
// cutLastAny is like strings.Cut but scans in reverse for the last character
// in chars. If no character is found, before is the empty string and after is
// s. The returned sep is the byte value of the character in chars if one was
// found; otherwise it is 0.
func cutLastAny(s, chars string) (before, after string, sep byte) {
i := strings.LastIndexAny(s, chars)
if i >= 0 {
return s[:i], s[i+1:], s[i]
}
return "", s, 0
}

View File

@@ -0,0 +1,152 @@
package names
import (
"strings"
"testing"
)
func TestParseName(t *testing.T) {
cases := []struct {
in string
want Name
}{
{"", Name{}},
{"m:t", Name{m: "m", t: "t"}},
{"m", Name{m: "m"}},
{"/m", Name{m: "m"}},
{"/n/m:t", Name{n: "n", m: "m", t: "t"}},
{"n/m", Name{n: "n", m: "m"}},
{"n/m:t", Name{n: "n", m: "m", t: "t"}},
{"n/m", Name{n: "n", m: "m"}},
{"n/m", Name{n: "n", m: "m"}},
{strings.Repeat("m", MaxNameLength+1), Name{}},
{"h/n/m:t", Name{h: "h", n: "n", m: "m", t: "t"}},
{"ollama.com/library/_:latest", Name{h: "ollama.com", n: "library", m: "_", t: "latest"}},
// Invalids
// TODO: {"n:t/m:t", Name{}},
// TODO: {"/h/n/m:t", Name{}},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
got := Parse(tt.in)
if got.Compare(tt.want) != 0 {
t.Errorf("parseName(%q) = %#v, want %q", tt.in, got, tt.want)
}
})
}
}
func TestString(t *testing.T) {
cases := []string{
"",
"m:t",
"m:t",
"m",
"n/m",
"n/m:t",
"n/m",
"n/m",
"h/n/m:t",
"ollama.com/library/_:latest",
// Special cased to "round trip" without the leading slash.
"/m",
"/n/m:t",
}
for _, s := range cases {
t.Run(s, func(t *testing.T) {
s = strings.TrimPrefix(s, "/")
if g := Parse(s).String(); g != s {
t.Errorf("parse(%q).String() = %q", s, g)
}
})
}
}
func TestParseExtended(t *testing.T) {
cases := []struct {
in string
wantScheme string
wantName Name
wantDigest string
}{
{"", "", Name{}, ""},
{"m", "", Name{m: "m"}, ""},
{"http://m", "http", Name{m: "m"}, ""},
{"http+insecure://m", "http+insecure", Name{m: "m"}, ""},
{"http://m@sha256:deadbeef", "http", Name{m: "m"}, "sha256:deadbeef"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
scheme, name, digest := ParseExtended(tt.in)
if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
}
// Round trip
if got := FormatExtended(scheme, name, digest); got != tt.in {
t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
}
})
}
}
func TestMerge(t *testing.T) {
cases := []struct {
a, b string
want string
}{
{"", "", ""},
{"m", "", "m"},
{"", "m", ""},
{"x", "y", "x"},
{"o.com/n/m:t", "o.com/n/m:t", "o.com/n/m:t"},
{"o.com/n/m:t", "o.com/n/_:t", "o.com/n/m:t"},
{"bmizerany/smol", "ollama.com/library/_:latest", "ollama.com/bmizerany/smol:latest"},
{"localhost:8080/bmizerany/smol", "ollama.com/library/_:latest", "localhost:8080/bmizerany/smol:latest"},
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
a, b := Parse(tt.a), Parse(tt.b)
got := Merge(a, b)
if got.Compare(Parse(tt.want)) != 0 {
t.Errorf("merge(%q, %q) = %#v, want %q", tt.a, tt.b, got, tt.want)
}
})
}
}
func TestParseStringRoundTrip(t *testing.T) {
cases := []string{
"",
"m",
"m:t",
"n/m",
"n/m:t",
"n/m:t",
"n/m",
"n/m",
"h/n/m:t",
"ollama.com/library/_:latest",
}
for _, s := range cases {
t.Run(s, func(t *testing.T) {
if got := Parse(s).String(); got != s {
t.Errorf("parse(%q).String() = %q", s, got)
}
})
}
}
var junkName Name
func BenchmarkParseName(b *testing.B) {
b.ReportAllocs()
for range b.N {
junkName = Parse("h/n/m:t")
}
}

View File

@@ -0,0 +1,52 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package stringsx provides additional string manipulation functions
// that aren't in the standard library's strings package or go4.org/mem.
package stringsx
import (
"unicode"
"unicode/utf8"
)
// CompareFold returns -1, 0, or 1 depending on whether a < b, a == b, or a > b,
// like cmp.Compare, but case insensitively.
func CompareFold(a, b string) int {
// Track our position in both strings
ia, ib := 0, 0
for ia < len(a) && ib < len(b) {
ra, wa := nextRuneLower(a[ia:])
rb, wb := nextRuneLower(b[ib:])
if ra < rb {
return -1
}
if ra > rb {
return 1
}
ia += wa
ib += wb
if wa == 0 || wb == 0 {
break
}
}
// If we've reached here, one or both strings are exhausted
// The shorter string is "less than" if they match up to this point
switch {
case ia == len(a) && ib == len(b):
return 0
case ia == len(a):
return -1
default:
return 1
}
}
// nextRuneLower returns the next rune in the string, lowercased, along with its
// original (consumed) width in bytes. If the string is empty, it returns
// (utf8.RuneError, 0)
func nextRuneLower(s string) (r rune, width int) {
r, width = utf8.DecodeRuneInString(s)
return unicode.ToLower(r), width
}

View File

@@ -0,0 +1,78 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package stringsx
import (
"cmp"
"strings"
"testing"
)
func TestCompareFold(t *testing.T) {
tests := []struct {
a, b string
}{
// Basic ASCII cases
{"", ""},
{"a", "a"},
{"a", "A"},
{"A", "a"},
{"a", "b"},
{"b", "a"},
{"abc", "ABC"},
{"ABC", "abc"},
{"abc", "abd"},
{"abd", "abc"},
// Length differences
{"abc", "ab"},
{"ab", "abc"},
// Unicode cases
{"世界", "世界"},
{"Hello世界", "hello世界"},
{"世界Hello", "世界hello"},
{"世界", "世界x"},
{"世界x", "世界"},
// Special case folding examples
{"ß", "ss"}, // German sharp s
{"fi", "fi"}, // fi ligature
{"Σ", "σ"}, // Greek sigma
{"İ", "i\u0307"}, // Turkish dotted I
// Mixed cases
{"HelloWorld", "helloworld"},
{"HELLOWORLD", "helloworld"},
{"helloworld", "HELLOWORLD"},
{"HelloWorld", "helloworld"},
{"helloworld", "HelloWorld"},
// Edge cases
{" ", " "},
{"1", "1"},
{"123", "123"},
{"!@#", "!@#"},
}
wants := []int{}
for _, tt := range tests {
got := CompareFold(tt.a, tt.b)
want := cmp.Compare(strings.ToLower(tt.a), strings.ToLower(tt.b))
if got != want {
t.Errorf("CompareFold(%q, %q) = %v, want %v", tt.a, tt.b, got, want)
}
wants = append(wants, want)
}
if n := testing.AllocsPerRun(1000, func() {
for i, tt := range tests {
if CompareFold(tt.a, tt.b) != wants[i] {
panic("unexpected")
}
}
}); n > 0 {
t.Errorf("allocs = %v; want 0", int(n))
}
}

View File

@@ -0,0 +1,201 @@
// Package syncs provides synchronization primitives.
package syncs
import (
"cmp"
"io"
"sync"
)
var closedChan = func() chan struct{} {
ch := make(chan struct{})
close(ch)
return ch
}()
// Ticket represents a ticket in a sequence of tickets. The zero value is
// invalid. Use [Line.Take] to get a valid ticket.
//
// A Ticket is not safe for concurrent use.
type Ticket struct {
ahead chan struct{} // ticket ahead of this one
ch chan struct{}
}
// Ready returns a channel that is closed when the ticket before this one is
// done.
//
// It is incorrect to wait on Ready after the ticket is done.
func (t *Ticket) Ready() chan struct{} {
return cmp.Or(t.ahead, closedChan)
}
// Done signals that this ticket is done and that the next ticket in line can
// proceed.
//
// The first call to [Done] unblocks the ticket after it, if any. Subsequent
// calls are no-ops.
func (t *Ticket) Done() {
if t.ch != nil {
close(t.ch)
}
t.ch = nil
}
// Line is an ordered sequence of tickets waiting for their turn to proceed.
//
// To get a ticket use [Line.Take].
// To signal that a ticket is done use [Ticket.Done].
// To wait your turn use [Ticket.Ready].
//
// A Line is not safe for concurrent use.
type Line struct {
last chan struct{} // last ticket in line
}
func (q *Line) Take() *Ticket {
t := &Ticket{
ahead: q.last,
ch: make(chan struct{}),
}
q.last = t.ch
return t
}
// RelayReader implements an [io.WriterTo] that yields the passed
// writer to its [WriteTo] method each [io.WriteCloser] taken from [Take], in
// the order they are taken. Each [io.WriteCloser] blocks until the previous
// one is closed, or a call to [RelayReader.CloseWithError] is made.
//
// The zero value is invalid. Use [NewWriteToLine] to get a valid RelayReader.
//
// It is not safe for concurrent use.
type RelayReader struct {
line Line
t *Ticket
w io.Writer
n int64
mu sync.Mutex
err error // set by CloseWithError
closedCh chan struct{} // closed if err is set
}
var (
_ io.Closer = (*RelayReader)(nil)
_ io.WriterTo = (*RelayReader)(nil)
_ io.Reader = (*RelayReader)(nil)
)
func NewRelayReader() *RelayReader {
var q RelayReader
q.closedCh = make(chan struct{})
q.t = q.line.Take()
return &q
}
// CloseWithError terminates the line, unblocking any writer waiting for its
// turn with the error, or [io.EOF] if err is nil. It is safe to call
// [CloseWithError] multiple times and across multiple goroutines.
//
// If the line is already closed, [CloseWithError] is a no-op.
//
// It never returns an error.
func (q *RelayReader) CloseWithError(err error) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.err == nil {
q.err = cmp.Or(q.err, err, io.EOF)
close(q.closedCh)
}
return nil
}
// Close closes the line. Any writer waiting for its turn will be unblocked
// with an [io.ErrClosedPipe] error.
//
// It never returns an error.
func (q *RelayReader) Close() error {
return q.CloseWithError(nil)
}
func (q *RelayReader) closed() <-chan struct{} {
q.mu.Lock()
defer q.mu.Unlock()
return q.closedCh
}
func (q *RelayReader) Read(p []byte) (int, error) {
panic("RelayReader.Read is for show only; use WriteTo")
}
// WriteTo yields the writer w to the first writer in line and blocks until the
// first call to [Close].
//
// It is safe to call [Take] concurrently with [WriteTo].
func (q *RelayReader) WriteTo(dst io.Writer) (int64, error) {
select {
case <-q.closed():
return 0, io.ErrClosedPipe
default:
}
// We have a destination writer; let the relay begin.
q.w = dst
q.t.Done()
<-q.closed()
return q.n, nil
}
// Take returns a writer that will be passed to the next writer in line.
//
// It is not safe for use across multiple goroutines.
func (q *RelayReader) Take() io.WriteCloser {
return &relayWriter{q: q, t: q.line.Take()}
}
type relayWriter struct {
q *RelayReader
t *Ticket
ready bool
}
var _ io.StringWriter = (*relayWriter)(nil)
// Write writes to the writer passed to [RelayReader.WriteTo] as soon as the
// writer is ready. It returns io.ErrClosedPipe if the line is closed before
// the writer is ready.
func (w *relayWriter) Write(p []byte) (int, error) {
if !w.awaitTurn() {
return 0, w.q.err
}
n, err := w.q.w.Write(p)
w.q.n += int64(n)
return n, err
}
func (w *relayWriter) WriteString(s string) (int, error) {
if !w.awaitTurn() {
return 0, w.q.err
}
return io.WriteString(w.q.w, s)
}
// Close signals that the writer is done, unblocking the next writer in line.
func (w *relayWriter) Close() error {
w.t.Done()
return nil
}
func (t *relayWriter) awaitTurn() (ok bool) {
if t.ready {
return true
}
select {
case <-t.t.Ready():
t.ready = true
return true
case <-t.q.closed():
return false
}
}

View File

@@ -0,0 +1,65 @@
package syncs
import (
"bytes"
"io"
"math/rand/v2"
"testing"
"testing/synctest"
)
func TestPipelineReadWriterTo(t *testing.T) {
for range 10 {
synctest.Run(func() {
q := NewRelayReader()
tickets := []struct {
io.WriteCloser
s string
}{
{q.Take(), "you"},
{q.Take(), " say hi,"},
{q.Take(), " and "},
{q.Take(), "I say "},
{q.Take(), "hello"},
}
rand.Shuffle(len(tickets), func(i, j int) {
tickets[i], tickets[j] = tickets[j], tickets[i]
})
var g Group
for i, t := range tickets {
g.Go(func() {
defer t.Close()
if i%2 == 0 {
// Use [relayWriter.WriteString]
io.WriteString(t.WriteCloser, t.s)
} else {
t.Write([]byte(t.s))
}
})
}
var got bytes.Buffer
var copyErr error // checked at end
g.Go(func() {
_, copyErr = io.Copy(&got, q)
})
synctest.Wait()
q.Close()
g.Wait()
if copyErr != nil {
t.Fatal(copyErr)
}
want := "you say hi, and I say hello"
if got.String() != want {
t.Fatalf("got %q, want %q", got.String(), want)
}
})
}
}

View File

@@ -0,0 +1,41 @@
package syncs
import (
"sync"
"sync/atomic"
)
// Group is a [sync.WaitGroup] with a Go method.
type Group struct {
wg sync.WaitGroup
n atomic.Int64
}
func (g *Group) Go(f func()) {
g.wg.Add(1)
go func() {
g.n.Add(1) // Now we are running
defer func() {
g.wg.Done()
g.n.Add(-1) // Now we are done
}()
f()
}()
}
// Running returns the number of goroutines that are currently running.
//
// If a call to [Running] returns zero, and a call to [Wait] is made without
// any calls to [Go], then [Wait] will return immediately. This is true even if
// a goroutine is started and finishes between the two calls.
//
// It is possible for [Running] to return non-zero and for [Wait] to return
// immediately. This can happen if the all running goroutines finish between
// the two calls.
func (g *Group) Running() int64 {
return g.n.Load()
}
func (g *Group) Wait() {
g.wg.Wait()
}

View File

@@ -0,0 +1,74 @@
package testutil
import (
"os"
"path/filepath"
"testing"
"time"
)
// Check calls t.Fatal(err) if err is not nil.
func Check(t *testing.T, err error) {
if err != nil {
t.Helper()
t.Fatal(err)
}
}
// CheckFunc exists so other packages do not need to invent their own type for
// taking a Check function.
type CheckFunc func(err error)
// Checker returns a check function that
// calls t.Fatal if err is not nil.
func Checker(t *testing.T) (check func(err error)) {
return func(err error) {
if err != nil {
t.Helper()
t.Fatal(err)
}
}
}
// StopPanic runs f but silently recovers from any panic f causes.
// The normal usage is:
//
// testutil.StopPanic(func() {
// callThatShouldPanic()
// t.Errorf("callThatShouldPanic did not panic")
// })
func StopPanic(f func()) {
defer func() { recover() }()
f()
}
// CheckTime calls t.Fatalf if got != want. Included in the error message is
// want.Sub(got) to help diagnose the difference, along with their values in
// UTC.
func CheckTime(t *testing.T, got, want time.Time) {
t.Helper()
if !got.Equal(want) {
t.Fatalf("got %v, want %v (%v)", got.UTC(), want.UTC(), want.Sub(got))
}
}
// WriteFile writes data to a file named name. It makes the directory if it
// doesn't exist and sets the file mode to perm.
//
// The name must be a relative path and must not contain .. or start with a /;
// otherwise WriteFile will panic.
func WriteFile[S []byte | string](t testing.TB, name string, data S) {
t.Helper()
if filepath.IsAbs(name) {
t.Fatalf("WriteFile: name must be a relative path, got %q", name)
}
name = filepath.Clean(name)
dir := filepath.Dir(name)
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(name, []byte(data), 0o644); err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,118 @@
// Package manifest provides documentation for the Ollama manifest format.
// This package contains no code.
//
// # Manifests
//
// A manifest is a JSON object that describes a model. The JSON object has a
// single field "layers" which is a list of layers that make up the model. Each
// layer has the following fields:
//
// A layer is a single, logical unit of a model. Layers are stored in the cache
// as files with the name of the digest of the layer. Layers are pushed and
// pulled from the registry as blobs.
//
// A layer is represented as a JSON object with the following fields:
//
// - "digest": The digest of the layer.
// - "mediaType": The media type of the layer.
// - "size": The size of the layer in bytes.
//
// Layers are typically stored in a blob store, such as a registry, and are
// referenced by their digest. This package does not define how layers are
// stored or retrieved.
//
// # Configuration Layer
//
// The configuration of a model is represented as a layer with the media type:
//
// application/vnd.ollama.image.config; type=<type>
//
// The "type" parameter in the media type specifies the format of the
// configuration (e.g., "safetensor" or "gguf").
//
// There may be only one configuration layer in a model.
//
// # Template Layer
//
// The model template is a layer with the media type:
//
// application/vnd.ollama.image.template; [name=<name>]
//
// The "name" parameter in the media type specifies the name of the template as
// for lookup at runtime. The name is optional and may be omitted. If omitted,
// the template is the default template for the model.
//
// # Tensor Layers
//
// The tensors of a model are represented as layers with the media type:
//
// application/vnd.ollama.image.tensor; name=<name>; dtype=<dtype>; shape=<shape>
//
// The "name" parameter in the media type specifies the name of the tensor as
// defined in the model's configuration and are bound only by the rules for
// names as defined in the configuration format, as represented by the
// configuration's "type".
//
// The "dtype" parameter in the media type specifies the data type of the tensor
// as a string.
//
// TODO: Define more specifically how to represent data types as strings.
//
// The "shape" parameter in the media type specifies the shape of the tensor as
// a comma-separated list of integers; one per dimension.
//
// # Tokenization Layers
//
// The tokenization of a model is represented as a layer with the media type:
//
// application/vnd.ollama.image.tokenizer
//
// The configuration of the tokenizer is represented as a layer with the media type:
//
// application/vnd.ollama.image.tokenizer.config
//
// # Miscellaneous Layers
//
// These extra layer mime types are reserved:
//
// application/vnd.ollama.image.license
//
// This layer contains one of the many licenses for the model in plain text.
//
// # Example Manifest
//
// The following is an example manifest containing a configuration, a model
// template, and two tensors (digests shortened for brevity):
//
// {
// "layers": [{
// "digest": "sha256:a...",
// "mediaType": "application/vnd.ollama.image.config; type=safetensors",
// "size": 1234
// },{
// "digest": "sha256:b...",
// "mediaType": "application/vnd.ollama.image.template",
// "size": 5678
// },{
// "digest": "sha256:c...",
// "mediaType": "application/vnd.ollama.image.tensor; name=input; dtype=F32; shape=1,2,3",
// "size": 9012
// },{
// "digest": "sha256:d...",
// "mediaType": "application/vnd.ollama.image.tensor; name=output; dtype=I32; shape=4,5,6",
// "size": 3456
// }]
// }
//
// # Legacy Media Types
//
// The appliaction/vnd.ollama.image.model media type is deprecated, but will
// remain supported for backwards compatibility, for some undefined amount of
// time. New models should use the media types defined above.
//
// # Reserved media types
//
// The media type prefix "application/vnd.ollama.image." is reserved for
// defining new media types for layers known to Ollama. Currently, all other
// prefixes are ignored by official Ollama registry clients.
package manifest

View File

@@ -1127,54 +1127,72 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}
func (s *Server) GenerateRoutes() http.Handler {
config := cors.DefaultConfig()
config.AllowWildcard = true
config.AllowBrowserExtensions = true
config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval"}
for _, prop := range openAIProperties {
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
corsConfig := cors.DefaultConfig()
corsConfig.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true
corsConfig.AllowHeaders = []string{
"Authorization",
"Content-Type",
"User-Agent",
"Accept",
"X-Requested-With",
// OpenAI compatibility headers
"x-stainless-lang",
"x-stainless-package-version",
"x-stainless-os",
"x-stainless-arch",
"x-stainless-retry-count",
"x-stainless-runtime",
"x-stainless-runtime-version",
"x-stainless-async",
"x-stainless-helper-method",
"x-stainless-poll-helper",
"x-stainless-custom-poll-interval",
"x-stainless-timeout",
}
config.AllowOrigins = envconfig.Origins()
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
r := gin.Default()
r.Use(
cors.New(config),
cors.New(corsConfig),
allowedHostsMiddleware(s.addr),
)
// General
r.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
// Local model cache management
r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.HEAD("/api/tags", s.ListHandler)
r.GET("/api/tags", s.ListHandler)
r.POST("/api/show", s.ShowHandler)
// Create
r.POST("/api/create", s.CreateHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateHandler)
r.POST("/api/push", s.PushHandler)
r.POST("/api/copy", s.CopyHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.POST("/api/show", s.ShowHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.GET("/api/ps", s.PsHandler)
// Compatibility endpoints
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.Handle(method, "/api/tags", s.ListHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
}
return r
}

View File

@@ -179,7 +179,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
if allReliable {
// HACK
os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus)))
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners(), "gpu_count", len(gpus))
} else {
// HACK
os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus)))