From 5f8c03189e73758339c3652d7eb0e6c3aed9761f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 18 Feb 2025 14:47:49 -0800 Subject: [PATCH 01/31] build: remove backend build for sapphirerapids sapphire rapids has amx support but it ends up having a negative performance impact. emerald rapids also has amx support with a positive performance impact however there's no reasonable way in ggml to differentiate between the two. the impact is small (~6%) so disable amx entirely for simplicity --- llama/patches/0018-remove-amx.patch | 24 ++++++++++++++++++++++++ ml/backend/ggml/ggml/src/CMakeLists.txt | 4 ---- 2 files changed, 24 insertions(+), 4 deletions(-) create mode 100644 llama/patches/0018-remove-amx.patch diff --git a/llama/patches/0018-remove-amx.patch b/llama/patches/0018-remove-amx.patch new file mode 100644 index 00000000..5428ee64 --- /dev/null +++ b/llama/patches/0018-remove-amx.patch @@ -0,0 +1,24 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Michael Yang +Date: Tue, 18 Feb 2025 14:47:21 -0800 +Subject: [PATCH] remove amx + +--- + ggml/src/CMakeLists.txt | 4 ---- + 1 file changed, 4 deletions(-) + +diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt +index 72b488dd..50828717 100644 +--- a/ggml/src/CMakeLists.txt ++++ b/ggml/src/CMakeLists.txt +@@ -293,10 +293,6 @@ if (GGML_CPU_ALL_VARIANTS) + ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512) + ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI) + ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI) +- if (NOT MSVC) +- # MSVC doesn't support AMX +- ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) +- endif() + else () + ggml_add_cpu_backend_variant_impl("") + endif() diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt index 72b488dd..50828717 100644 --- a/ml/backend/ggml/ggml/src/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/CMakeLists.txt @@ -293,10 +293,6 @@ if (GGML_CPU_ALL_VARIANTS) ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512) ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI) ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI) - if (NOT MSVC) - # MSVC doesn't support AMX - ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) - endif() else () ggml_add_cpu_backend_variant_impl("") endif() From 3c874df46e261039c60b6b47e4386621a0a06777 Mon Sep 17 00:00:00 2001 From: maninhill <41712985+maninhill@users.noreply.github.com> Date: Thu, 20 Feb 2025 05:20:09 +0800 Subject: [PATCH 02/31] docs: Add MaxKB to Community Integrations (#9212) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 9d186725..14b0190b 100644 --- a/README.md +++ b/README.md @@ -382,6 +382,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI) - [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models) - [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally) +- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot) ### Cloud From 778603a8182df94cfea961e4dbaff89780b725fa Mon Sep 17 00:00:00 2001 From: zyxucp <286513187@qq.com> Date: Thu, 20 Feb 2025 05:22:48 +0800 Subject: [PATCH 03/31] docs: Add AntSK to Community Integrations (#9214) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 14b0190b..bae51848 100644 --- a/README.md +++ b/README.md @@ -382,6 +382,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI) - [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models) - [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally) +- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot) - [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot) ### Cloud From d721a02e7daba5a4c25b75e67c6413adb387e606 Mon Sep 17 00:00:00 2001 From: yuiseki Date: Thu, 20 Feb 2025 06:24:27 +0900 Subject: [PATCH 04/31] test: add test cases for ListHandler (#9146) --- cmd/cmd_test.go | 91 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index c8963280..e70ffbea 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -10,6 +10,7 @@ import ( "os" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/spf13/cobra" @@ -490,6 +491,96 @@ func TestPushHandler(t *testing.T) { } } +func TestListHandler(t *testing.T) { + tests := []struct { + name string + args []string + serverResponse []api.ListModelResponse + expectedError string + expectedOutput string + }{ + { + name: "list all models", + args: []string{}, + serverResponse: []api.ListModelResponse{ + {Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)}, + {Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)}, + }, + expectedOutput: "NAME ID SIZE MODIFIED \n" + + "model1 sha256:abc12 1.0 KB 24 hours ago \n" + + "model2 sha256:def45 2.0 KB 2 days ago \n", + }, + { + name: "filter models by prefix", + args: []string{"model1"}, + serverResponse: []api.ListModelResponse{ + {Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)}, + {Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)}, + }, + expectedOutput: "NAME ID SIZE MODIFIED \n" + + "model1 sha256:abc12 1.0 KB 24 hours ago \n", + }, + { + name: "server error", + args: []string{}, + expectedError: "server error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/tags" || r.Method != http.MethodGet { + t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path) + http.Error(w, "not found", http.StatusNotFound) + return + } + + if tt.expectedError != "" { + http.Error(w, tt.expectedError, http.StatusInternalServerError) + return + } + + response := api.ListResponse{Models: tt.serverResponse} + if err := json.NewEncoder(w).Encode(response); err != nil { + t.Fatal(err) + } + })) + defer mockServer.Close() + + t.Setenv("OLLAMA_HOST", mockServer.URL) + + cmd := &cobra.Command{} + cmd.SetContext(context.TODO()) + + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + err := ListHandler(cmd, tt.args) + + // Restore stdout and get output + w.Close() + os.Stdout = oldStdout + output, _ := io.ReadAll(r) + + if tt.expectedError == "" { + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if got := string(output); got != tt.expectedOutput { + t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got) + } + } else { + if err == nil || !strings.Contains(err.Error(), tt.expectedError) { + t.Errorf("expected error containing %q, got %v", tt.expectedError, err) + } + } + }) + } +} + func TestCreateHandler(t *testing.T) { tests := []struct { name string From bda4ef6c568732c57bcc31b4b8d87520f21aaa1a Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 19 Feb 2025 15:03:03 -0800 Subject: [PATCH 05/31] reorder patches --- llama/patches/{0018-remove-amx.patch => 0019-remove-amx.patch} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename llama/patches/{0018-remove-amx.patch => 0019-remove-amx.patch} (100%) diff --git a/llama/patches/0018-remove-amx.patch b/llama/patches/0019-remove-amx.patch similarity index 100% rename from llama/patches/0018-remove-amx.patch rename to llama/patches/0019-remove-amx.patch From 351a85d9ea0db108ca29bba48d0a04e37c6e3607 Mon Sep 17 00:00:00 2001 From: Lucas Hahn <50808857+lucasthahn@users.noreply.github.com> Date: Thu, 20 Feb 2025 00:56:18 -0500 Subject: [PATCH 06/31] openai: add 'timeout' to allowable x-stainless headers (#9237) --- server/routes.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/routes.go b/server/routes.go index 95485cb8..9cefb607 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1131,7 +1131,7 @@ func (s *Server) GenerateRoutes() http.Handler { config.AllowWildcard = true config.AllowBrowserExtensions = true config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"} - openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval"} + openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout"} for _, prop := range openAIProperties { config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop) } From 3d4cc7833c21bf9189cb1d3fa8365997e46fad33 Mon Sep 17 00:00:00 2001 From: danielekp <61015367+danielekp@users.noreply.github.com> Date: Thu, 20 Feb 2025 20:34:24 +0100 Subject: [PATCH 07/31] docs: Add yla to community integrations --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index bae51848..5e439898 100644 --- a/README.md +++ b/README.md @@ -384,6 +384,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally) - [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot) - [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot) +- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models) ### Cloud From 7c168b08c9522f56290478fba0267118d20f7ec4 Mon Sep 17 00:00:00 2001 From: frob Date: Thu, 20 Feb 2025 21:10:15 +0100 Subject: [PATCH 08/31] server: add missing function parens to debug log (#9255) --- server/sched.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/sched.go b/server/sched.go index 563f2aad..b4600dbf 100644 --- a/server/sched.go +++ b/server/sched.go @@ -179,7 +179,7 @@ func (s *Scheduler) processPending(ctx context.Context) { if allReliable { // HACK os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus))) - slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus)) + slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners(), "gpu_count", len(gpus)) } else { // HACK os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus))) From ba9ec3d05ed47b243389f8273d160329e8964362 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 20 Feb 2025 11:52:25 -0800 Subject: [PATCH 09/31] ci: use clang for windows cpu builds clang outputs are faster. we were previously building with clang via gcc wrapper in cgo but this was missed during the build updates so there was a drop in performance --- .github/workflows/release.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 37ac7e45..37d525e9 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -160,6 +160,10 @@ jobs: echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append + - if: matrix.preset == 'CPU' + run: | + echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} uses: actions/cache/save@v4 with: From 14b5a9a150598d724e4ef17616cdb25257ddc155 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Thu, 20 Feb 2025 13:19:58 -0800 Subject: [PATCH 10/31] api: document client stream behavior with a test (#8996) Added unit tests to verify error handling behavior in the Client.stream and Client.do methods. Tests cover various error scenarios including: - Error responses with status codes >= 400 - Error messages with successful status codes - Empty error messages - Successful responses --- api/client.go | 2 +- api/client_test.go | 210 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 211 insertions(+), 1 deletion(-) diff --git a/api/client.go b/api/client.go index 4688d4d1..f87ea0fd 100644 --- a/api/client.go +++ b/api/client.go @@ -132,7 +132,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData const maxBufferSize = 512 * format.KiloByte func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { - var buf *bytes.Buffer + var buf io.Reader if data != nil { bts, err := json.Marshal(data) if err != nil { diff --git a/api/client_test.go b/api/client_test.go index 23fe9334..fe9a1589 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -1,6 +1,13 @@ package api import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" "testing" ) @@ -43,3 +50,206 @@ func TestClientFromEnvironment(t *testing.T) { }) } } + +// testError represents an internal error type with status code and message +// this is used since the error response from the server is not a standard error struct +type testError struct { + message string + statusCode int +} + +func (e testError) Error() string { + return e.message +} + +func TestClientStream(t *testing.T) { + testCases := []struct { + name string + responses []any + wantErr string + }{ + { + name: "immediate error response", + responses: []any{ + testError{ + message: "test error message", + statusCode: http.StatusBadRequest, + }, + }, + wantErr: "test error message", + }, + { + name: "error after successful chunks, ok response", + responses: []any{ + ChatResponse{Message: Message{Content: "partial response 1"}}, + ChatResponse{Message: Message{Content: "partial response 2"}}, + testError{ + message: "mid-stream error", + statusCode: http.StatusOK, + }, + }, + wantErr: "mid-stream error", + }, + { + name: "successful stream completion", + responses: []any{ + ChatResponse{Message: Message{Content: "chunk 1"}}, + ChatResponse{Message: Message{Content: "chunk 2"}}, + ChatResponse{ + Message: Message{Content: "final chunk"}, + Done: true, + DoneReason: "stop", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("expected http.Flusher") + } + + w.Header().Set("Content-Type", "application/x-ndjson") + + for _, resp := range tc.responses { + if errResp, ok := resp.(testError); ok { + w.WriteHeader(errResp.statusCode) + err := json.NewEncoder(w).Encode(map[string]string{ + "error": errResp.message, + }) + if err != nil { + t.Fatal("failed to encode error response:", err) + } + return + } + + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("failed to encode response: %v", err) + } + flusher.Flush() + } + })) + defer ts.Close() + + client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient) + + var receivedChunks []ChatResponse + err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error { + var resp ChatResponse + if err := json.Unmarshal(chunk, &resp); err != nil { + return fmt.Errorf("failed to unmarshal chunk: %w", err) + } + receivedChunks = append(receivedChunks, resp) + return nil + }) + + if tc.wantErr != "" { + if err == nil { + t.Fatal("expected error but got nil") + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Errorf("expected error containing %q, got %v", tc.wantErr, err) + } + return + } + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestClientDo(t *testing.T) { + testCases := []struct { + name string + response any + wantErr string + }{ + { + name: "immediate error response", + response: testError{ + message: "test error message", + statusCode: http.StatusBadRequest, + }, + wantErr: "test error message", + }, + { + name: "server error response", + response: testError{ + message: "internal error", + statusCode: http.StatusInternalServerError, + }, + wantErr: "internal error", + }, + { + name: "successful response", + response: struct { + ID string `json:"id"` + Success bool `json:"success"` + }{ + ID: "msg_123", + Success: true, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if errResp, ok := tc.response.(testError); ok { + w.WriteHeader(errResp.statusCode) + err := json.NewEncoder(w).Encode(map[string]string{ + "error": errResp.message, + }) + if err != nil { + t.Fatal("failed to encode error response:", err) + } + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(tc.response); err != nil { + t.Fatalf("failed to encode response: %v", err) + } + })) + defer ts.Close() + + client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient) + + var resp struct { + ID string `json:"id"` + Success bool `json:"success"` + } + err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp) + + if tc.wantErr != "" { + if err == nil { + t.Fatalf("got nil, want error %q", tc.wantErr) + } + if err.Error() != tc.wantErr { + t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr) + } + return + } + + if err != nil { + t.Fatalf("got error %q, want nil", err) + } + + if expectedResp, ok := tc.response.(struct { + ID string `json:"id"` + Success bool `json:"success"` + }); ok { + if resp.ID != expectedResp.ID { + t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID) + } + if resp.Success != expectedResp.Success { + t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success) + } + } + }) + } +} From bd6a7d5e6416c4c2aeba07233303385254395b6c Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 20 Feb 2025 11:18:01 -0800 Subject: [PATCH 11/31] ollamarunner: Pass runner performance parameters to backends Currently the following parameters are in the runner but not used: - numGPULayers - mainGPU - threads - tensorSplit This passes them through to the backend, which is where they would actually get used. However, the GGML backend does not yet do anything with them. --- ml/backend.go | 23 +++++++++++++++++++---- ml/backend/ggml/ggml.go | 2 +- model/model.go | 4 ++-- runner/ollamarunner/runner.go | 31 ++++++++++++++++++------------- 4 files changed, 40 insertions(+), 20 deletions(-) diff --git a/ml/backend.go b/ml/backend.go index aebf86f7..3cc18f2b 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -26,9 +26,24 @@ type Backend interface { SystemInfo() string } -var backends = make(map[string]func(*os.File) (Backend, error)) +// BackendParams controls how the backend loads and executes models +type BackendParams struct { + // NumThreads sets the number of threads to use if running on the CPU + NumThreads int -func RegisterBackend(name string, f func(*os.File) (Backend, error)) { + // MainGPU is the index of the primary GPU to use + MainGPU int + + // NumGPULayers is the number of layers to offload to GPUs + NumGPULayers int + + // TensorSplit is the fraction of the model to offload to each GPU + TensorSplit []float32 +} + +var backends = make(map[string]func(*os.File, BackendParams) (Backend, error)) + +func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) { if _, ok := backends[name]; ok { panic("backend: backend already registered") } @@ -36,9 +51,9 @@ func RegisterBackend(name string, f func(*os.File) (Backend, error)) { backends[name] = f } -func NewBackend(f *os.File) (Backend, error) { +func NewBackend(f *os.File, params BackendParams) (Backend, error) { if backend, ok := backends["ggml"]; ok { - return backend(f) + return backend(f, params) } return nil, fmt.Errorf("unsupported backend") diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 5ba36361..492f2d0a 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -84,7 +84,7 @@ type Backend struct { tensors map[string]*Context } -func New(r *os.File) (ml.Backend, error) { +func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { meta, n, err := fs.Decode(r, -1) if err != nil { return nil, err diff --git a/model/model.go b/model/model.go index 5eedc9bd..0b5996d9 100644 --- a/model/model.go +++ b/model/model.go @@ -70,14 +70,14 @@ func Register(name string, f func(ml.Config) (Model, error)) { } // New initializes a new model instance with the provided configuration based on the metadata in the model file -func New(modelPath string) (Model, error) { +func New(modelPath string, params ml.BackendParams) (Model, error) { r, err := os.Open(modelPath) if err != nil { return nil, err } defer r.Close() - b, err := ml.NewBackend(r) + b, err := ml.NewBackend(r, params) if err != nil { return nil, err } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 6d45050c..d11eba82 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -25,6 +25,7 @@ import ( "golang.org/x/sync/semaphore" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" "github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/sample" @@ -801,6 +802,7 @@ func (m *multiLPath) String() string { func (s *Server) loadModel( mpath string, + params ml.BackendParams, lpath multiLPath, parallel int, kvCacheType string, @@ -808,12 +810,12 @@ func (s *Server) loadModel( multiUserCache bool, ) { var err error - s.model, err = model.New(mpath) + s.model, err = model.New(mpath, params) if err != nil { panic(err) } - slog.Info("system", "info", s.model.Backend().SystemInfo() /* "threads", *threads */) + slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads) // TODO(jessegross): LoRA loading if lpath.String() != "" { @@ -843,17 +845,17 @@ func Execute(args []string) error { mpath := fs.String("model", "", "Path to model binary file") parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously") batchSize := fs.Int("batch-size", 512, "Batch size") - _ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") - _ = fs.Int("main-gpu", 0, "Main GPU") + numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") + mainGPU := fs.Int("main-gpu", 0, "Main GPU") _ = fs.Bool("flash-attn", false, "Enable flash attention") kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") port := fs.Int("port", 8080, "Port to expose the server on") - _ = fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation") + threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation") verbose := fs.Bool("verbose", false, "verbose output (default: disabled)") _ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)") _ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing") - _ = fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions") + tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions") multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") var lpaths multiLPath @@ -890,15 +892,11 @@ func Execute(args []string) error { } // TODO(jessegross): Parameters that need to be implemented: - // n-gpu-layers - // main-gpu // flash-attn - // threads // no-mmap // mlock - // tensor-split - /*var tensorSplitFloats []float32 + var tensorSplitFloats []float32 if *tensorSplit != "" { stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1) @@ -907,10 +905,17 @@ func Execute(args []string) error { f, _ := strconv.ParseFloat(s, 32) tensorSplitFloats = append(tensorSplitFloats, float32(f)) } - }*/ + } + + params := ml.BackendParams{ + NumThreads: *threads, + NumGPULayers: *numGPULayers, + MainGPU: *mainGPU, + TensorSplit: tensorSplitFloats, + } server.ready.Add(1) - go server.loadModel(*mpath, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) + go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) server.cond = sync.NewCond(&server.mu) From e5bcc51ae199116a635d74be3a510c5aeeb2894a Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 18 Feb 2025 16:52:29 -0800 Subject: [PATCH 12/31] ggml-backend: Don't recreate the scheduler for each context We don't need to create and destroy the GGML scheduler for every context. This introduces extra CPU overhead for every forward pass and extra memory for contexts that don't actually get scheduled (for example, KV caches). We can instead just have one scheduler for the backend and reset it each time we call Compute. This improves token generation performance by 1-2% and removes scheduler create/destroy from profile traces. --- ml/backend/ggml/ggml.go | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 492f2d0a..0e30c36f 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -82,6 +82,8 @@ type Backend struct { meta *fs.GGML cpus, gpus []Context tensors map[string]*Context + + sched *C.struct_ggml_backend_sched } func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { @@ -182,10 +184,24 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { return nil, err } + backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus)) + bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus)) + for i, c := range append(gpus, cpus...) { + backends[i] = c.backend + bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend) + } + return &Backend{ meta: meta, cpus: cpus, gpus: gpus, + sched: C.ggml_backend_sched_new( + (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])), + (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])), + C.int(len(backends)), + C.size_t(max(8192, len(meta.Tensors().Items())*5)), + true, + ), }, nil } @@ -219,31 +235,23 @@ func (b *Backend) NewContext() ml.Context { }) backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus)) - bufts := make([]*C.struct_ggml_backend_buffer_type, len(b.gpus)+len(b.cpus)) for i, c := range append(b.gpus, b.cpus...) { backends[i] = c.backend - bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend) } return &Context{ + b: b, ctx: c, backend: backends[0], nodes: nodes, - sched: C.ggml_backend_sched_new( - (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])), - (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])), - C.int(len(backends)), - C.size_t(nodes), - true, - ), } } type Context struct { + b *Backend ctx *C.struct_ggml_context backend *C.struct_ggml_backend - sched *C.struct_ggml_backend_sched graph *C.struct_ggml_cgraph nodes int } @@ -257,12 +265,13 @@ func (c *Context) Forward(t ml.Tensor) { } func (c *Context) Compute(tensors ...ml.Tensor) { - C.ggml_backend_sched_graph_compute_async(c.sched, c.graph) + C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph) + C.ggml_backend_sched_reset(c.b.sched) needSync := true sync := func() { if needSync { - C.ggml_backend_sched_synchronize(c.sched) + C.ggml_backend_sched_synchronize(c.b.sched) needSync = false } } @@ -350,7 +359,6 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { func (c *Context) Close() { if c != nil { - C.ggml_backend_sched_free(c.sched) C.ggml_free(c.ctx) } } From 5c5535c0648fb12b174246eb2524e862ae2d2d5b Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 18 Feb 2025 17:16:57 -0800 Subject: [PATCH 13/31] models: Prune unused outputs earlier in the forward pass Currently Rows is called as the last step in a model computation to get the values for the output tokens. However, if we move it earlier in the process then we can trim out computations that never get used. This is similar to how models are defined in llama.cpp. Changing the model definition in this way improves token generation performance by approximately 8%. --- model/models/llama/model.go | 36 ++++++++++++++++++++----------- model/models/mllama/model.go | 6 ++---- model/models/mllama/model_text.go | 27 +++++++++++++++++------ 3 files changed, 46 insertions(+), 23 deletions(-) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index b2c5c2c7..e90631fb 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -120,11 +120,19 @@ type Layer struct { MLP *MLP } -func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { +func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { residual := hiddenState hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) + + // In the final layer (outputs != nil), optimize by pruning to just the token positions + // we need logits for. + if outputs != nil { + hiddenState = hiddenState.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + hiddenState = hiddenState.Add(ctx, residual) residual = hiddenState @@ -144,22 +152,26 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { return nil, err } - hiddenState := m.TokenEmbedding.Forward(ctx, inputs) - - for i, layer := range m.Layers { - m.Cache.SetLayer(i) - hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options) - } - - hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) - hiddenState = m.Output.Forward(ctx, hiddenState) - outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } - return hiddenState.Rows(ctx, outputs), nil + hiddenState := m.TokenEmbedding.Forward(ctx, inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var lastLayerOutputs ml.Tensor + if i == len(m.Layers)-1 { + lastLayerOutputs = outputs + } + + hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) + } + + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + return m.Output.Forward(ctx, hiddenState), nil } func init() { diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index a1460d94..f5521ce5 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -93,15 +93,13 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { return nil, err } - // TODO: attention mask, cross attention mask - hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)) - outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } - return hiddenState.Rows(ctx, outputs), nil + // TODO: attention mask, cross attention mask + return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil } func init() { diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 1e48086a..8ad804cf 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -74,11 +74,19 @@ type TextSelfAttentionDecoderLayer struct { MLP *TextMLP } -func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { +func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { residual := hiddenState hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts) + + // In the final layer (outputs != nil), optimize by pruning to just the token positions + // we need logits for. + if outputs != nil { + hiddenState = hiddenState.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + hiddenState = hiddenState.Add(ctx, residual) residual = hiddenState @@ -145,7 +153,7 @@ type TextCrossAttentionDecoderLayer struct { MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"` } -func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { +func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { residual := hiddenState hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) @@ -161,14 +169,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, } type TextDecoderLayer interface { - Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor + Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor } type TextDecoder struct { Layers []TextDecoderLayer } -func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { +func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { for i, layer := range d.Layers { layerType := selfAttentionLayer if slices.Contains(opts.crossAttentionLayers, uint32(i)) { @@ -179,7 +187,12 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr cache.SetLayerType(layerType) if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() { - hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts) + var lastLayerOutputs ml.Tensor + if i == len(d.Layers)-1 { + lastLayerOutputs = outputs + } + + hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts) } } @@ -205,9 +218,9 @@ type TextModel struct { *TextModelOptions } -func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor { +func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor { hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs) - hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions) + hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) return m.Output.Forward(ctx, hiddenState) } From 5d81c1a1842712e218d0311546037d152502b2c0 Mon Sep 17 00:00:00 2001 From: "Junyan Qin (Chin)" Date: Sat, 22 Feb 2025 01:36:55 +0800 Subject: [PATCH 14/31] docs: add `RockChinQ/LangBot` to integrations list (#9272) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 5e439898..548eb244 100644 --- a/README.md +++ b/README.md @@ -385,6 +385,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot) - [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot) - [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models) +- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms) ### Cloud From 2192a28eedc24398c7f274a15617341389b6c143 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 20 Feb 2025 16:45:05 -0800 Subject: [PATCH 15/31] ml/backend/ggml: fix rms norm --- ml/backend/ggml/ggml.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 0e30c36f..2b7b9189 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -485,7 +485,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso } func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor { - return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) + return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) } func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { From f53f4198c36d0a943de598ad91a20baa9481c5c5 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Fri, 14 Feb 2025 20:51:44 -0800 Subject: [PATCH 16/31] ml: Abstract attention out of model definitions There are two benefits to doing this: - Provide a library function that models can use, reducing code for each model implementation - Enables a single place to drop in optimized implementations of attention based on the backend or other factors. One is provided for GGML. On CUDA this improves token generation rate by about 3%. It does not have a significant effect on Metal. Co-authored-by: Daniel Hiltgen --- ml/backend.go | 20 +++++++++++ ml/backend/ggml/ggml.go | 15 ++++++++ ml/nn/attention.go | 59 +++++++++++++++++++++++++++++++ model/models/llama/model.go | 9 ++--- model/models/mllama/model_text.go | 21 ++++------- 5 files changed, 102 insertions(+), 22 deletions(-) create mode 100644 ml/nn/attention.go diff --git a/ml/backend.go b/ml/backend.go index 3cc18f2b..6e3f0516 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -111,6 +111,26 @@ type Tensor interface { Copy(ctx Context, t2 Tensor) Tensor } +// ScaledDotProductAttention implements a fused attention +// operation equivalent to following code on a tensor named +// query: +// +// kq := key.MulmatFullPrec(ctx, query) +// +// kq = kq.Scale(ctx, scale) +// +// if mask != nil { +// kq = kq.Add(ctx, mask) +// } +// +// kq = kq.Softmax(ctx) +// +// kqv := value.Mulmat(ctx, kq) +// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) +type ScaledDotProductAttention interface { + ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor +} + type number interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 2b7b9189..2d7cf340 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -651,6 +651,21 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int } } +func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor { + var kqMask *C.struct_ggml_tensor + if mask != nil { + kqMask = mask.(*Tensor).t + } + + kq := key.MulmatFullPrec(ctx, t) + kq = &Tensor{ + t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), + } + + kqv := value.Mulmat(ctx, kq) + return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) +} + func (b *Backend) SystemInfo() string { var compiler string switch C.get_compiler() { diff --git a/ml/nn/attention.go b/ml/nn/attention.go new file mode 100644 index 00000000..4f0c9fa1 --- /dev/null +++ b/ml/nn/attention.go @@ -0,0 +1,59 @@ +package nn + +import ( + "fmt" + + "github.com/ollama/ollama/ml" +) + +// Attention implements scaled dot-product attention for transformer models: +// Attention(Q, K, V) = softmax(QK^T/√d_k)V +// +// Parameters: +// - ctx: Context for tensor operations +// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads] +// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads] +// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads] +// - mask: Optional attention mask that is added to the attention score. If +// provided, should broadcast to [seq_len_k, seq_len_q, heads] +// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension +// +// Returns: +// +// Attention output with shape [d_v, heads, seq_len_q] +func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor { + if query.Dim(0) != key.Dim(0) { + panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) + } + + if mask != nil && query.Dim(1) != mask.Dim(1) { + panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1))) + } + + if key.Dim(1) != value.Dim(0) { + panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0))) + } + + if mask != nil && key.Dim(1) != mask.Dim(0) { + panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0))) + } + + if key.Dim(2) != value.Dim(2) { + panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) + } + + if sdpa, ok := query.(ml.ScaledDotProductAttention); ok { + return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale) + } else { + kq := key.MulmatFullPrec(ctx, query) + + kq = kq.Scale(ctx, scale) + if mask != nil { + kq = kq.Add(ctx, mask) + } + kq = kq.Softmax(ctx) + + kqv := value.Mulmat(ctx, kq) + return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + } +} diff --git a/model/models/llama/model.go b/model/models/llama/model.go index e90631fb..4fe02999 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -86,13 +86,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - kq := k.MulmatFullPrec(ctx, q) - kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) - kq = kq.Add(ctx, mask) - kq = kq.Softmax(ctx) - - kqv := v.Mulmat(ctx, kq) - kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + scaleFactor := 1.0 / math.Sqrt(float64(headDim)) + kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor) kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, kqv) diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 8ad804cf..003bf9cb 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -38,13 +38,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scores := key.MulmatFullPrec(ctx, query) - scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) - scores = scores.Add(ctx, mask) - scores = scores.Softmax(ctx) - - attention := value.Mulmat(ctx, scores) - attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + scaleFactor := 1.0 / math.Sqrt(float64(headDim)) + attention := nn.Attention(ctx, query, key, value, mask, scaleFactor) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, attention) @@ -112,7 +107,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = ca.QueryNorm.Forward(ctx, query, opts.eps) - var key, value ml.Tensor + var key, value, mask ml.Tensor if crossAttentionStates != nil { numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) @@ -125,19 +120,15 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio cache.Put(ctx, key, value) } else { - key, value, _ = cache.Get(ctx) + key, value, mask = cache.Get(ctx) } query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scores := key.Mulmat(ctx, query) - scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) - scores = scores.Softmax(ctx) - - attention := value.Mulmat(ctx, scores) - attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + scaleFactor := 1.0 / math.Sqrt(float64(headDim)) + attention := nn.Attention(ctx, query, key, value, mask, scaleFactor) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return ca.Output.Forward(ctx, attention) From 68bac1e0a646e00a215b6bffb6f294f895c32238 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Fri, 21 Feb 2025 21:02:26 -0800 Subject: [PATCH 17/31] server: group routes by category and purpose (#9270) The route assembly in Handler lacked clear organization making it difficult scan for routes and their relationships to each other. This commit aims to fix that by reordering the assembly of routes to group them by category and purpose. Also, be more specific about what "config" refers to (it is about CORS if you were wondering... I was.) --- envconfig/config.go | 6 ++-- envconfig/config_test.go | 2 +- server/routes.go | 76 +++++++++++++++++++++++++--------------- 3 files changed, 51 insertions(+), 33 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index fbd881ba..d867bdac 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -53,8 +53,8 @@ func Host() *url.URL { } } -// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. -func Origins() (origins []string) { +// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable. +func AllowedOrigins() (origins []string) { if s := Var("OLLAMA_ORIGINS"); s != "" { origins = strings.Split(s, ",") } @@ -249,7 +249,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"}, - "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, + "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 735b4540..993ddd9c 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -134,7 +134,7 @@ func TestOrigins(t *testing.T) { t.Run(tt.value, func(t *testing.T) { t.Setenv("OLLAMA_ORIGINS", tt.value) - if diff := cmp.Diff(Origins(), tt.expect); diff != "" { + if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" { t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff) } }) diff --git a/server/routes.go b/server/routes.go index 9cefb607..de72f847 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1127,54 +1127,72 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc { } func (s *Server) GenerateRoutes() http.Handler { - config := cors.DefaultConfig() - config.AllowWildcard = true - config.AllowBrowserExtensions = true - config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"} - openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout"} - for _, prop := range openAIProperties { - config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop) + corsConfig := cors.DefaultConfig() + corsConfig.AllowWildcard = true + corsConfig.AllowBrowserExtensions = true + corsConfig.AllowHeaders = []string{ + "Authorization", + "Content-Type", + "User-Agent", + "Accept", + "X-Requested-With", + + // OpenAI compatibility headers + "x-stainless-lang", + "x-stainless-package-version", + "x-stainless-os", + "x-stainless-arch", + "x-stainless-retry-count", + "x-stainless-runtime", + "x-stainless-runtime-version", + "x-stainless-async", + "x-stainless-helper-method", + "x-stainless-poll-helper", + "x-stainless-custom-poll-interval", + "x-stainless-timeout", } - config.AllowOrigins = envconfig.Origins() + corsConfig.AllowOrigins = envconfig.AllowedOrigins() r := gin.Default() r.Use( - cors.New(config), + cors.New(corsConfig), allowedHostsMiddleware(s.addr), ) + // General + r.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") }) + r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") }) + r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) }) + r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) }) + + // Local model cache management r.POST("/api/pull", s.PullHandler) + r.POST("/api/push", s.PushHandler) + r.DELETE("/api/delete", s.DeleteHandler) + r.HEAD("/api/tags", s.ListHandler) + r.GET("/api/tags", s.ListHandler) + r.POST("/api/show", s.ShowHandler) + + // Create + r.POST("/api/create", s.CreateHandler) + r.POST("/api/blobs/:digest", s.CreateBlobHandler) + r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) + r.POST("/api/copy", s.CopyHandler) + + // Inference + r.GET("/api/ps", s.PsHandler) r.POST("/api/generate", s.GenerateHandler) r.POST("/api/chat", s.ChatHandler) r.POST("/api/embed", s.EmbedHandler) r.POST("/api/embeddings", s.EmbeddingsHandler) - r.POST("/api/create", s.CreateHandler) - r.POST("/api/push", s.PushHandler) - r.POST("/api/copy", s.CopyHandler) - r.DELETE("/api/delete", s.DeleteHandler) - r.POST("/api/show", s.ShowHandler) - r.POST("/api/blobs/:digest", s.CreateBlobHandler) - r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) - r.GET("/api/ps", s.PsHandler) - // Compatibility endpoints + // Inference (OpenAI compatibility) r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler) - for _, method := range []string{http.MethodGet, http.MethodHead} { - r.Handle(method, "/", func(c *gin.Context) { - c.String(http.StatusOK, "Ollama is running") - }) - - r.Handle(method, "/api/tags", s.ListHandler) - r.Handle(method, "/api/version", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"version": version.Version}) - }) - } - return r } From 7cfd4aee4d9956b89dbbb103ee4877194abfe670 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sat, 22 Feb 2025 11:22:59 -0800 Subject: [PATCH 18/31] docs: add additional ROCm docs for building (#9066) --- docs/development.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/development.md b/docs/development.md index 88fec3db..522d106b 100644 --- a/docs/development.md +++ b/docs/development.md @@ -46,15 +46,6 @@ Install prerequisites: - (Optional) NVIDIA GPU support - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network) -> [!IMPORTANT] -> Ensure prerequisites are in `PATH` before running CMake. - -> [!IMPORTANT] -> ROCm is not compatible with Visual Studio CMake generators. Use `-GNinja` when configuring the project. - -> [!IMPORTANT] -> CUDA is only compatible with Visual Studio CMake generators. - Then, configure and build the project: ```shell @@ -62,6 +53,14 @@ cmake -B build cmake --build build --config Release ``` +> [!IMPORTANT] +> Building for ROCm requires additional flags: +> ``` +> cmake -B build -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ +> cmake --build build --config Release +> ``` + + Lastly, run Ollama: ```shell From 8c13cfa4dd35a79c983eb19b5ec2be7ffa220b69 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sun, 23 Feb 2025 19:13:53 -0800 Subject: [PATCH 19/31] ml/backend/ggml: fix crash on windows paths with wide characters (#9305) --- ...d-filesystem-path-instead-of-wstring.patch | 108 +++++++++++------- ml/backend/ggml/ggml/src/ggml-backend-reg.cpp | 56 +++++---- 2 files changed, 101 insertions(+), 63 deletions(-) diff --git a/llama/patches/0018-use-std-filesystem-path-instead-of-wstring.patch b/llama/patches/0018-use-std-filesystem-path-instead-of-wstring.patch index 749cfbba..95144fb4 100644 --- a/llama/patches/0018-use-std-filesystem-path-instead-of-wstring.patch +++ b/llama/patches/0018-use-std-filesystem-path-instead-of-wstring.patch @@ -4,17 +4,23 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500 Subject: [PATCH] use std::filesystem::path instead of wstring --- - ggml/src/ggml-backend-reg.cpp | 116 ++++++++++++---------------------- - 1 file changed, 40 insertions(+), 76 deletions(-) + ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++-------------------- + 1 file changed, 58 insertions(+), 86 deletions(-) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index 84b21dd8..de78feae 100644 +index 84b21dd8..e35a6936 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp -@@ -72,16 +72,6 @@ - # pragma clang diagnostic ignored "-Wdeprecated-declarations" +@@ -66,26 +66,6 @@ + #include "ggml-kompute.h" #endif +-// disable C++17 deprecation warning for std::codecvt_utf8 +-#if defined(__clang__) +-# pragma clang diagnostic push +-# pragma clang diagnostic ignored "-Wdeprecated-declarations" +-#endif +- -static std::wstring utf8_to_utf16(const std::string & str) { - std::wstring_convert> converter; - return converter.from_bytes(str); @@ -25,10 +31,14 @@ index 84b21dd8..de78feae 100644 - return converter.to_bytes(str); -} - - #if defined(__clang__) - # pragma clang diagnostic pop - #endif -@@ -96,12 +86,12 @@ struct dl_handle_deleter { +-#if defined(__clang__) +-# pragma clang diagnostic pop +-#endif +- + #ifdef _WIN32 + + using dl_handle = std::remove_pointer_t; +@@ -96,7 +76,7 @@ struct dl_handle_deleter { } }; @@ -37,24 +47,44 @@ index 84b21dd8..de78feae 100644 // suppress error dialogs for missing DLLs DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); - -- HMODULE handle = LoadLibraryW(path.c_str()); -+ HMODULE handle = LoadLibraryW(path.wstring().c_str()); - - SetErrorMode(old_mode); - -@@ -129,8 +119,8 @@ struct dl_handle_deleter { +@@ -129,8 +109,8 @@ struct dl_handle_deleter { } }; -static void * dl_load_library(const std::wstring & path) { - dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL); +static void * dl_load_library(const std::filesystem::path & path) { -+ dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); ++ dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); return handle; } -@@ -222,11 +212,11 @@ struct ggml_backend_registry { +@@ -141,6 +121,25 @@ static void * dl_get_sym(dl_handle * handle, const char * name) { + + #endif + ++static std::string path_to_string(const std::filesystem::path & path) ++{ ++#ifdef _WIN32 ++ const std::wstring wstr = path.wstring(); ++ const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr); ++ if (size_needed <= 0) { ++ return std::string(); ++ } ++ ++ // size_needed includes the null terminator ++ std::string str(size_needed - 1, '\0'); ++ WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr); ++ return str; ++#else ++ return path.string(); ++#endif ++} ++ ++ + using dl_handle_ptr = std::unique_ptr; + + struct ggml_backend_reg_entry { +@@ -222,11 +221,11 @@ struct ggml_backend_registry { ); } @@ -64,49 +94,49 @@ index 84b21dd8..de78feae 100644 if (!handle) { if (!silent) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str()); -+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path.string().c_str()); ++ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str()); } return nullptr; } -@@ -234,7 +224,7 @@ struct ggml_backend_registry { +@@ -234,7 +233,7 @@ struct ggml_backend_registry { auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); if (score_fn && score_fn() == 0) { if (!silent) { - GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str()); -+ GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path.string().c_str()); ++ GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str()); } return nullptr; } -@@ -242,7 +232,7 @@ struct ggml_backend_registry { +@@ -242,7 +241,7 @@ struct ggml_backend_registry { auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init"); if (!backend_init_fn) { if (!silent) { - GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str()); -+ GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path.string().c_str()); ++ GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str()); } return nullptr; } -@@ -251,16 +241,16 @@ struct ggml_backend_registry { +@@ -251,16 +250,16 @@ struct ggml_backend_registry { if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) { if (!silent) { if (!reg) { - GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str()); -+ GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path.string().c_str()); ++ GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str()); } else { GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n", - __func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION); -+ __func__, path.string().c_str(), reg->api_version, GGML_BACKEND_API_VERSION); ++ __func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION); } } return nullptr; } - GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str()); -+ GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path.string().c_str()); ++ GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str()); register_backend(reg, score_fn ? score_fn() : -1, std::move(handle)); -@@ -396,14 +386,14 @@ ggml_backend_t ggml_backend_init_best(void) { +@@ -396,14 +395,14 @@ ggml_backend_t ggml_backend_init_best(void) { // Dynamic loading ggml_backend_reg_t ggml_backend_load(const char * path) { @@ -123,7 +153,7 @@ index 84b21dd8..de78feae 100644 #if defined(__APPLE__) // get executable path std::vector path; -@@ -415,15 +405,9 @@ static std::wstring get_executable_path() { +@@ -415,15 +414,9 @@ static std::wstring get_executable_path() { } path.resize(size); } @@ -141,7 +171,7 @@ index 84b21dd8..de78feae 100644 std::vector path(1024); while (true) { // get executable path -@@ -436,76 +420,56 @@ static std::wstring get_executable_path() { +@@ -436,76 +429,55 @@ static std::wstring get_executable_path() { break; } if (len < (ssize_t) path.size()) { @@ -179,11 +209,11 @@ index 84b21dd8..de78feae 100644 -static std::wstring backend_filename_prefix() { -#ifdef _WIN32 - return L"ggml-"; -+ return std::filesystem::path(path.data()).parent_path(); - #else +-#else - return L"libggml-"; -+ return {}; ++ return std::filesystem::path(path.data()).parent_path(); #endif ++ return {}; } -static std::wstring backend_filename_suffix() { @@ -234,7 +264,7 @@ index 84b21dd8..de78feae 100644 for (const auto & search_path : search_paths) { if (!fs::exists(search_path)) { continue; -@@ -514,31 +478,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, +@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, for (const auto & entry : dir_it) { try { if (entry.is_regular_file()) { @@ -247,20 +277,20 @@ index 84b21dd8..de78feae 100644 + dl_handle_ptr handle { dl_load_library(entry.path()) }; if (!handle) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, entry.path().string().c_str()); ++ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); continue; } auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); if (!score_fn) { - GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, entry.path().string().c_str()); ++ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); continue; } int s = score_fn(); - GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); -+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, entry.path().string().c_str(), s); ++ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); if (s > best_score) { best_score = s; - best_path = entry.path().wstring(); @@ -270,11 +300,11 @@ index 84b21dd8..de78feae 100644 } } catch (const std::exception & e) { - GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what()); -+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, entry.path().string().c_str(), e.what()); ++ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what()); } } } -@@ -546,7 +510,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, +@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, if (best_score == 0) { // try to load the base backend for (const auto & search_path : search_paths) { diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp index de78feae..e35a6936 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp @@ -66,16 +66,6 @@ #include "ggml-kompute.h" #endif -// disable C++17 deprecation warning for std::codecvt_utf8 -#if defined(__clang__) -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#endif - -#if defined(__clang__) -# pragma clang diagnostic pop -#endif - #ifdef _WIN32 using dl_handle = std::remove_pointer_t; @@ -91,7 +81,7 @@ static dl_handle * dl_load_library(const std::filesystem::path & path) { DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); - HMODULE handle = LoadLibraryW(path.wstring().c_str()); + HMODULE handle = LoadLibraryW(path.c_str()); SetErrorMode(old_mode); @@ -120,7 +110,7 @@ struct dl_handle_deleter { }; static void * dl_load_library(const std::filesystem::path & path) { - dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); + dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); return handle; } @@ -131,6 +121,25 @@ static void * dl_get_sym(dl_handle * handle, const char * name) { #endif +static std::string path_to_string(const std::filesystem::path & path) +{ +#ifdef _WIN32 + const std::wstring wstr = path.wstring(); + const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr); + if (size_needed <= 0) { + return std::string(); + } + + // size_needed includes the null terminator + std::string str(size_needed - 1, '\0'); + WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr); + return str; +#else + return path.string(); +#endif +} + + using dl_handle_ptr = std::unique_ptr; struct ggml_backend_reg_entry { @@ -216,7 +225,7 @@ struct ggml_backend_registry { dl_handle_ptr handle { dl_load_library(path) }; if (!handle) { if (!silent) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path.string().c_str()); + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str()); } return nullptr; } @@ -224,7 +233,7 @@ struct ggml_backend_registry { auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); if (score_fn && score_fn() == 0) { if (!silent) { - GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path.string().c_str()); + GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str()); } return nullptr; } @@ -232,7 +241,7 @@ struct ggml_backend_registry { auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init"); if (!backend_init_fn) { if (!silent) { - GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path.string().c_str()); + GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str()); } return nullptr; } @@ -241,16 +250,16 @@ struct ggml_backend_registry { if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) { if (!silent) { if (!reg) { - GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path.string().c_str()); + GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str()); } else { GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n", - __func__, path.string().c_str(), reg->api_version, GGML_BACKEND_API_VERSION); + __func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION); } } return nullptr; } - GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path.string().c_str()); + GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str()); register_backend(reg, score_fn ? score_fn() : -1, std::move(handle)); @@ -432,9 +441,8 @@ static std::filesystem::path get_executable_path() { } return std::filesystem::path(path.data()).parent_path(); -#else - return {}; #endif + return {}; } static std::string backend_filename_prefix() { @@ -483,18 +491,18 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { dl_handle_ptr handle { dl_load_library(entry.path()) }; if (!handle) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, entry.path().string().c_str()); + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); continue; } auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); if (!score_fn) { - GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, entry.path().string().c_str()); + GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); continue; } int s = score_fn(); - GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, entry.path().string().c_str(), s); + GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); if (s > best_score) { best_score = s; best_path = entry.path(); @@ -502,7 +510,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, } } } catch (const std::exception & e) { - GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, entry.path().string().c_str(), e.what()); + GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what()); } } } From 4604b103060a14f36aeb05fbf47e37a404ae98aa Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 24 Feb 2025 13:11:46 -0800 Subject: [PATCH 20/31] go.mod: bump to go1.24 (#9242) --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 1c99c094..a6107a62 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ollama/ollama -go 1.23.4 +go 1.24 require ( github.com/containerd/console v1.0.3 From 314573bfe8afd6e93389ec519699da20285a38dc Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Mon, 24 Feb 2025 13:26:35 -0800 Subject: [PATCH 21/31] config: allow setting context length through env var (#8938) * envconfig: allow setting context length through env var --- api/types.go | 4 +++- envconfig/config.go | 3 +++ envconfig/config_test.go | 16 ++++++++++++++++ llm/memory_test.go | 1 + 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/api/types.go b/api/types.go index f4c5b105..637ca204 100644 --- a/api/types.go +++ b/api/types.go @@ -10,6 +10,8 @@ import ( "strconv" "strings" "time" + + "github.com/ollama/ollama/envconfig" ) // StatusError is an error with an HTTP status code and message. @@ -609,7 +611,7 @@ func DefaultOptions() Options { Runner: Runner{ // options set when the model is loaded - NumCtx: 2048, + NumCtx: int(envconfig.ContextLength()), NumBatch: 512, NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically NumThread: 0, // let the runtime decide diff --git a/envconfig/config.go b/envconfig/config.go index d867bdac..6117aa26 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -167,6 +167,8 @@ var ( MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE") // Enable the new Ollama engine NewEngine = Bool("OLLAMA_NEW_ENGINE") + // ContextLength sets the default context length + ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 2048) ) func String(s string) func() string { @@ -252,6 +254,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, + "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 2048)"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, // Informational diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 993ddd9c..385dab5f 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -272,3 +272,19 @@ func TestVar(t *testing.T) { }) } } + +func TestContextLength(t *testing.T) { + cases := map[string]uint{ + "": 2048, + "4096": 4096, + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_CONTEXT_LENGTH", k) + if i := ContextLength(); i != v { + t.Errorf("%s: expected %d, got %d", k, v, i) + } + }) + } +} diff --git a/llm/memory_test.go b/llm/memory_test.go index e49d2541..40cc01df 100644 --- a/llm/memory_test.go +++ b/llm/memory_test.go @@ -17,6 +17,7 @@ import ( func TestEstimateGPULayers(t *testing.T) { t.Setenv("OLLAMA_DEBUG", "1") t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16 + t.Setenv("OLLAMA_CONTEXT_LENGTH", "2048") modelName := "dummy" f, err := os.CreateTemp(t.TempDir(), modelName) From 0b7e1676eb2d3214c9a0d3ea4de932071956cf43 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Mon, 24 Feb 2025 17:19:01 -0800 Subject: [PATCH 22/31] sample: add sampling package for new engine (#8410) --- runner/ollamarunner/runner.go | 63 ++++----- sample/greedy.go | 13 -- sample/sample.go | 74 ----------- sample/samplers.go | 139 ++++++++++++++++++++ sample/samplers_test.go | 238 ++++++++++++++++++++++++++++++++++ sample/transforms.go | 120 +++++++++++++++++ sample/transforms_test.go | 80 ++++++++++++ 7 files changed, 600 insertions(+), 127 deletions(-) delete mode 100644 sample/greedy.go delete mode 100644 sample/sample.go create mode 100644 sample/samplers.go create mode 100644 sample/samplers_test.go create mode 100644 sample/transforms.go create mode 100644 sample/transforms_test.go diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index d11eba82..d3998120 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -65,8 +65,8 @@ type Sequence struct { // number of tokens to predict numPredict int - // set of samplers to run on generated logits - samplers []sample.Sampler + // sampler with transforms to run on generated logits + sampler sample.Sampler // channel to send back the embedding if embedding only embedding chan []float32 @@ -93,7 +93,7 @@ type NewSequenceParams struct { numPredict int stop []string numKeep int32 - samplers []sample.Sampler + sampler sample.Sampler embedding bool } @@ -136,7 +136,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen responses: make(chan string, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), - samplers: params.samplers, + sampler: params.sampler, embeddingOnly: params.embedding, stop: params.stop, numKeep: params.numKeep, @@ -393,13 +393,7 @@ func (s *Server) processBatch() error { return fmt.Errorf("failed to decode batch: %w", err) } - f32s := modelOutput.Floats() - - // TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s - logits := make([]float64, len(f32s)) - for i, f32 := range f32s { - logits[i] = float64(f32) - } + logits := modelOutput.Floats() for i, seq := range s.seqs { if seq == nil { @@ -433,14 +427,12 @@ func (s *Server) processBatch() error { } // sample a token - vocabSize := len(f32s) / len(options.Outputs) - tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...) - if err != nil { - return err - } + vocabSize := len(logits) / len(options.Outputs) - // TODO(jessegross): Sampler will output a single int32 in the future - token := int32(tokens[0]) + token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]) + if err != nil { + return fmt.Errorf("failed to sample token: %w", err) + } // if it's an end of sequence token, break if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) { @@ -565,27 +557,6 @@ type CompletionResponse struct { Timings Timings `json:"timings"` } -func getSamplers(_ CompletionRequest) []sample.Sampler { - // TODO(jessegross): Waiting for sampling code - - /*samplingParams.TopK = req.TopK - samplingParams.TopP = req.TopP - samplingParams.MinP = req.MinP - samplingParams.TypicalP = req.TypicalP - samplingParams.Temp = req.Temperature - samplingParams.RepeatLastN = req.RepeatLastN - samplingParams.PenaltyRepeat = req.RepeatPenalty - samplingParams.PenaltyFreq = req.FrequencyPenalty - samplingParams.PenaltyPresent = req.PresencePenalty - samplingParams.Mirostat = req.Mirostat - samplingParams.MirostatTau = req.MirostatTau - samplingParams.MirostatEta = req.MirostatEta - samplingParams.Seed = uint32(req.Seed) - samplingParams.Grammar = req.Grammar*/ - - return []sample.Sampler{sample.Greedy()} -} - func (s *Server) completion(w http.ResponseWriter, r *http.Request) { var req CompletionRequest req.Options = Options(api.DefaultOptions()) @@ -604,11 +575,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + sampler, err := sample.NewSampler( + req.Temperature, + req.TopK, + req.TopP, + req.MinP, + req.Seed, + ) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError) + return + } + seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ numPredict: req.NumPredict, stop: req.Stop, numKeep: int32(req.NumKeep), - samplers: getSamplers(req), + sampler: sampler, embedding: false, }) if err != nil { diff --git a/sample/greedy.go b/sample/greedy.go deleted file mode 100644 index 206f5544..00000000 --- a/sample/greedy.go +++ /dev/null @@ -1,13 +0,0 @@ -package sample - -import "gonum.org/v1/gonum/floats" - -type greedy struct{} - -func Greedy() Sampler { - return greedy{} -} - -func (s greedy) Sample(t []float64) ([]float64, error) { - return []float64{float64(floats.MaxIdx(t))}, nil -} diff --git a/sample/sample.go b/sample/sample.go deleted file mode 100644 index 44c08cae..00000000 --- a/sample/sample.go +++ /dev/null @@ -1,74 +0,0 @@ -package sample - -import ( - "slices" - - "gonum.org/v1/gonum/floats" - "gonum.org/v1/gonum/stat/sampleuv" -) - -type Sampler interface { - Sample([]float64) ([]float64, error) -} - -type Temperature float64 - -func (s Temperature) Sample(t []float64) ([]float64, error) { - floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t))) - return t, nil -} - -type softmax struct{} - -func Softmax() Sampler { - return softmax{} -} - -func (softmax) Sample(t []float64) ([]float64, error) { - return t, nil -} - -type TopK int - -func (s TopK) Sample(t []float64) ([]float64, error) { - return t, nil -} - -type TopP float32 - -func (s TopP) Sample(t []float64) ([]float64, error) { - return t, nil -} - -type MinP float32 - -func (s MinP) Sample(t []float64) ([]float64, error) { - return t, nil -} - -type weighed struct{} - -func Weighed() Sampler { - return weighed{} -} - -func (s weighed) Sample(t []float64) ([]float64, error) { - w := sampleuv.NewWeighted(t, nil) - if v, ok := w.Take(); ok { - return []float64{float64(v)}, nil - } - - return t, nil -} - -func Sample(floats []float64, samplers ...Sampler) ([]float64, error) { - var err error - for _, sampler := range samplers { - floats, err = sampler.Sample(floats) - if err != nil { - return nil, err - } - } - - return floats, nil -} diff --git a/sample/samplers.go b/sample/samplers.go new file mode 100644 index 00000000..836c6e4d --- /dev/null +++ b/sample/samplers.go @@ -0,0 +1,139 @@ +package sample + +import ( + "errors" + "math" + + "golang.org/x/exp/rand" + "gonum.org/v1/gonum/stat/sampleuv" +) + +type Sampler interface { + Sample([]float32) (int32, error) +} + +type weighted struct { + src rand.Source + transforms []Transform +} + +// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279 +func Weighted(seed *uint64, transforms ...Transform) Sampler { + var src rand.Source + if seed != nil { + src = rand.NewSource(*seed) + } + return weighted{src: src, transforms: transforms} +} + +func (s weighted) Sample(logits []float32) (int32, error) { + logits64 := make([]float64, len(logits)) + for i, v := range logits { + logits64[i] = float64(v) + } + + for _, t := range s.transforms { + logits64 = t.Apply(logits64) + } + + logitsCopy := make([]float64, 0, len(logits)) + indices := make([]int, 0, len(logits)) + for i, logit := range logits64 { + if !math.IsInf(logit, -1) { + logitsCopy = append(logitsCopy, logit) + indices = append(indices, i) + } + } + + if len(logitsCopy) == 0 { + return -1, errors.New("no valid logits found for weighed sampling") + } + + probs := softmax(logitsCopy) + w := sampleuv.NewWeighted(probs, s.src) + if idx, ok := w.Take(); ok { + return int32(indices[idx]), nil + } + return -1, errors.New("weighed sampler failed, no valid token found") +} + +type greedy struct { + transforms []Transform +} + +func Greedy(transforms ...Transform) Sampler { + return greedy{transforms: transforms} +} + +func (s greedy) Sample(logits []float32) (int32, error) { + logits64 := make([]float64, len(logits)) + for i, v := range logits { + logits64[i] = float64(v) + } + + for _, t := range s.transforms { + logits64 = t.Apply(logits64) + } + + var maxIdx int + var maxLogit float64 + for i, logit := range logits64 { + if logit > maxLogit { + maxLogit = logit + maxIdx = i + } + } + + if maxLogit == math.Inf(-1) { + return -1, errors.New("no valid logits found for greedy sampling") + } + + return int32(maxIdx), nil +} + +// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 +func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) { + transforms := []Transform{} + if temperature < 0 || temperature > 2 { + return nil, errors.New("temperature must be between 0 and 2") + } + + if temperature != 0 { + transforms = append(transforms, Temperature(temperature)) + } + + if topK != 0 { + if topK <= 0 { + return nil, errors.New("topK must be greater than 0") + } + transforms = append(transforms, TopK(topK)) + } + + if topP != 0 { + if topP < 0 || topP >= 1 { + return nil, errors.New("topP must be between 0 and 1") + } + transforms = append(transforms, TopP(topP)) + } + + if minP != 0 { + if minP < 0 || minP >= 1 { + return nil, errors.New("minP must be between 0 and 1") + } + transforms = append(transforms, MinP(minP)) + } + + if len(transforms) == 0 { + return nil, errors.New("at least one transform is required") + } + + if temperature == 0 { + return Greedy(transforms...), nil + } + + if seed != 0 { + seed64 := uint64(seed) + return Weighted(&seed64, transforms...), nil + } + return Weighted(nil, transforms...), nil +} diff --git a/sample/samplers_test.go b/sample/samplers_test.go new file mode 100644 index 00000000..aaa8d99c --- /dev/null +++ b/sample/samplers_test.go @@ -0,0 +1,238 @@ +package sample + +import ( + "math" + "math/rand/v2" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestWeighted(t *testing.T) { + got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))}) + if err != nil { + t.Error(err) + return + } + want := int32(1) + if want != got { + t.Errorf("index mismatch: want %d, got %d", want, got) + } + + got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))}) + if err == nil { + t.Error("expected error for no valid tokens, got index", got) + } + + seed := uint64(42) + got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4}) + if err != nil { + t.Error(err) + return + } + // With seed 42, we expect a consistent sample + want = int32(3) // This will be deterministic due to the seed + if want != got { + t.Errorf("index mismatch: want %d, got %d", want, got) + } +} + +type testTransform struct { + id int + callOrder *[]int +} + +func (ts *testTransform) Apply(logits []float64) []float64 { + if ts.callOrder != nil { + *ts.callOrder = append(*ts.callOrder, ts.id) + } + return logits +} + +func TestSample(t *testing.T) { + input := []float32{1, 2, 3, 4} + + var callOrder []int + mock1 := &testTransform{ + id: 1, + callOrder: &callOrder, + } + mock2 := &testTransform{ + id: 2, + callOrder: &callOrder, + } + mock3 := &testTransform{ + id: 3, + callOrder: &callOrder, + } + + got, err := Greedy(mock1, mock2, mock3).Sample(input) + if err != nil { + t.Error(err) + return + } + + want := int32(3) // Greedy sampler should pick highest logit + if want != got { + t.Errorf("index mismatch: want %d, got %d", want, got) + } + wantOrder := []int{1, 2, 3} + if diff := cmp.Diff(wantOrder, callOrder); diff != "" { + t.Errorf("call order mismatch (-want +got):\n%s", diff) + } + + callOrder = nil + + _, err = Weighted(nil, mock1, mock2, mock3).Sample(input) + if err != nil { + t.Error(err) + return + } + wantOrder = []int{1, 2, 3} + if diff := cmp.Diff(wantOrder, callOrder); diff != "" { + t.Errorf("call order mismatch (-want +got):\n%s", diff) + } +} + +func TestNewSampler(t *testing.T) { + tests := []struct { + name string + temperature float32 + topK int + topP float32 + minP float32 + seed int + wantErr bool + }{ + { + name: "no transforms", + wantErr: true, + }, + { + name: "temperature", + temperature: 0.5, + wantErr: false, + }, + { + name: "invalid temperature negative", + temperature: -1, + wantErr: true, + }, + { + name: "invalid temperature too high", + temperature: 2.1, + wantErr: true, + }, + { + name: "top k", + topK: 10, + wantErr: false, + }, + { + name: "invalid top k negative", + topK: -1, + wantErr: true, + }, + { + name: "top p", + topP: 0.9, + wantErr: false, + }, + { + name: "invalid top p negative", + topP: -0.1, + wantErr: true, + }, + { + name: "invalid top p one", + topP: 1.0, + wantErr: true, + }, + { + name: "min p", + minP: 0.2, + wantErr: false, + }, + { + name: "invalid min p negative", + minP: -0.1, + wantErr: true, + }, + { + name: "invalid min p one", + minP: 1.0, + wantErr: true, + }, + { + name: "seed", + seed: 42, + wantErr: true, // seed alone is not valid without other transforms + }, + { + name: "default values", + temperature: 0.8, + topK: 40, + topP: 0.9, + minP: 0.0, + seed: 0, + wantErr: false, + }, + { + name: "all zeroes", + temperature: 0.0, + topK: 0, + topP: 0.0, + minP: 0.0, + seed: 0, + wantErr: true, // all zeroes means no transforms + }, + { + name: "all transforms", + temperature: 0.8, + topK: 50, + topP: 0.95, + minP: 0.1, + seed: 42, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed) + if (err != nil) != tt.wantErr { + t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func BenchmarkSample(b *testing.B) { + transforms := []Transform{ + Temperature(0.5), + TopK(10), + TopP(0.9), + MinP(0.2), + } + + samplers := map[string]Sampler{ + "Greedy": Greedy(transforms...), + "Weighted": Weighted(nil, transforms...), + } + + logits := make([]float32, 1<<16) + for i := range logits { + logits[i] = rand.Float32() + } + + for name, s := range samplers { + b.Run(name, func(b *testing.B) { + b.ResetTimer() + for range b.N { + if _, err := s.Sample(logits); err != nil { + b.Error(err) + } + } + }) + } +} diff --git a/sample/transforms.go b/sample/transforms.go new file mode 100644 index 00000000..2dc6ebae --- /dev/null +++ b/sample/transforms.go @@ -0,0 +1,120 @@ +package sample + +import ( + "cmp" + "math" + "slices" + + pq "github.com/emirpasic/gods/v2/queues/priorityqueue" +) + +type Transform interface { + Apply([]float64) []float64 +} + +// TODO(parthsareen): potentially cache softmax values +func softmax(logits []float64) []float64 { + var sum float64 + probs := make([]float64, len(logits)) + for i, v := range logits { + probs[i] = math.Exp(v) + sum += probs[i] + } + + for i := range probs { + probs[i] /= sum + } + + return probs +} + +type Temperature float64 + +func (t Temperature) Apply(logits []float64) []float64 { + temp := math.Max(float64(t), 1e-7) + + // subtracting max logit to avoid under/overflow + maxLogit := slices.Max(logits) + for i := range logits { + logits[i] = (logits[i] - maxLogit) / temp + } + + return logits +} + +type logitMap struct { + index int + logit float64 +} + +type TopK int + +// TODO(parthsareen): avoid having to check all logits after this transform +func (k TopK) Apply(logits []float64) []float64 { + if int(k) >= len(logits) { + return logits + } + q := pq.NewWith(func(a, b logitMap) int { + return -cmp.Compare(a.logit, b.logit) + }) + + for i, logit := range logits { + q.Enqueue(logitMap{index: i, logit: logit}) + } + + validLogits := make(map[int]float64) + for range k { + logitMap, _ := q.Dequeue() + validLogits[logitMap.index] = logitMap.logit + } + + for i := range logits { + if _, ok := validLogits[i]; !ok { + logits[i] = math.Inf(-1) + } + } + + return logits +} + +type TopP float64 + +func (p TopP) Apply(logits []float64) []float64 { + probs := softmax(logits) + indices := make([]int, len(probs)) + for i := range indices { + indices[i] = i + } + + // sort in descending order + slices.SortFunc(indices, func(i, j int) int { + return cmp.Compare(probs[j], probs[i]) + }) + + var sum float64 + for i, idx := range indices { + sum += probs[idx] + if sum > float64(p) { + for _, idx := range indices[i+1:] { + logits[idx] = math.Inf(-1) + } + break + } + } + return logits +} + +type MinP float64 + +func (p MinP) Apply(logits []float64) []float64 { + probs := softmax(logits) + threshold := slices.Max(probs) * float64(p) + + for i, prob := range probs { + if prob < threshold { + logits[i] = math.Inf(-1) + } + } + + return logits +} diff --git a/sample/transforms_test.go b/sample/transforms_test.go new file mode 100644 index 00000000..05f76a27 --- /dev/null +++ b/sample/transforms_test.go @@ -0,0 +1,80 @@ +package sample + +import ( + "math" + "math/rand/v2" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestTemperature(t *testing.T) { + got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0}) + want := []float64{-4, -10, 0, -14, -6, -12, -8} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) + } +} + +func TestSoftmax(t *testing.T) { + got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4}) + + want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("probs mismatch (-want +got):\n%s", diff) + } +} + +func TestTopK(t *testing.T) { + got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) + } + + got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + + want = []float64{-3, -2, -1, 0, 1, 2, 4} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) + } +} + +func TestTopP(t *testing.T) { + got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) + } +} + +func TestMinP(t *testing.T) { + got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3}) + want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) + } +} + +func BenchmarkTransform(b *testing.B) { + transforms := map[string]Transform{ + "Temperature": Temperature(0.5), + "TopK": TopK(10), + "TopP": TopP(0.9), + "MinP": MinP(0.2), + } + + logits := make([]float64, 1<<16) + for i := range logits { + logits[i] = rand.Float64() + } + + for name, transform := range transforms { + b.Run(name, func(b *testing.B) { + b.ResetTimer() + for range b.N { + transform.Apply(logits) + } + }) + } +} From 348b3e0983c76263008a8dfbbc23e2449107f6d1 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 24 Feb 2025 22:39:44 -0800 Subject: [PATCH 23/31] server/internal: copy bmizerany/ollama-go to internal package (#9294) This commit copies (without history) the bmizerany/ollama-go repository with the intention of integrating it into the ollama as a replacement for the pushing, and pulling of models, and management of the cache they are pushed and pulled from. New homes for these packages will be determined as they are integrated and we have a better understanding of proper package boundaries. --- .github/workflows/test.yaml | 1 + .gitignore | 2 +- .golangci.yaml | 8 +- server/internal/cache/blob/cache.go | 544 ++++++++++++ server/internal/cache/blob/cache_test.go | 685 +++++++++++++++ server/internal/cache/blob/casecheck_test.go | 93 ++ server/internal/cache/blob/digest.go | 95 +++ server/internal/cache/blob/digest_test.go | 63 ++ server/internal/chunks/chunks.go | 78 ++ server/internal/chunks/chunks_test.go | 65 ++ server/internal/client/ollama/registry.go | 802 ++++++++++++++++++ .../internal/client/ollama/registry_test.go | 656 ++++++++++++++ server/internal/client/ollama/trace.go | 48 ++ .../opp/internal/safetensors/safetensors.go | 220 +++++ server/internal/cmd/opp/opp.go | 366 ++++++++ server/internal/cmd/oppbench/oppbench.go | 11 + server/internal/cmd/oppbench/oppbench_test.go | 107 +++ server/internal/internal/backoff/backoff.go | 48 ++ .../internal/backoff/backoff_synctest_test.go | 40 + .../internal/internal/backoff/backoff_test.go | 38 + server/internal/internal/names/name.go | 229 +++++ server/internal/internal/names/name_test.go | 152 ++++ server/internal/internal/stringsx/stringsx.go | 52 ++ .../internal/stringsx/stringsx_test.go | 78 ++ server/internal/internal/syncs/line.go | 201 +++++ server/internal/internal/syncs/line_test.go | 65 ++ server/internal/internal/syncs/syncs.go | 41 + server/internal/internal/testutil/testutil.go | 74 ++ server/internal/manifest/manifest.go | 118 +++ 29 files changed, 4974 insertions(+), 6 deletions(-) create mode 100644 server/internal/cache/blob/cache.go create mode 100644 server/internal/cache/blob/cache_test.go create mode 100644 server/internal/cache/blob/casecheck_test.go create mode 100644 server/internal/cache/blob/digest.go create mode 100644 server/internal/cache/blob/digest_test.go create mode 100644 server/internal/chunks/chunks.go create mode 100644 server/internal/chunks/chunks_test.go create mode 100644 server/internal/client/ollama/registry.go create mode 100644 server/internal/client/ollama/registry_test.go create mode 100644 server/internal/client/ollama/trace.go create mode 100644 server/internal/cmd/opp/internal/safetensors/safetensors.go create mode 100644 server/internal/cmd/opp/opp.go create mode 100644 server/internal/cmd/oppbench/oppbench.go create mode 100644 server/internal/cmd/oppbench/oppbench_test.go create mode 100644 server/internal/internal/backoff/backoff.go create mode 100644 server/internal/internal/backoff/backoff_synctest_test.go create mode 100644 server/internal/internal/backoff/backoff_test.go create mode 100644 server/internal/internal/names/name.go create mode 100644 server/internal/internal/names/name_test.go create mode 100644 server/internal/internal/stringsx/stringsx.go create mode 100644 server/internal/internal/stringsx/stringsx_test.go create mode 100644 server/internal/internal/syncs/line.go create mode 100644 server/internal/internal/syncs/line_test.go create mode 100644 server/internal/internal/syncs/syncs.go create mode 100644 server/internal/internal/testutil/testutil.go create mode 100644 server/internal/manifest/manifest.go diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 8af8812f..56a2cc4f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -147,6 +147,7 @@ jobs: runs-on: ${{ matrix.os }} env: CGO_ENABLED: '1' + GOEXPERIMENT: 'synctest' steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 diff --git a/.gitignore b/.gitignore index 551abec8..3a2af0bd 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,6 @@ .swp dist build -ollama .cache *.exe .idea @@ -14,3 +13,4 @@ test_data __debug_bin* llama/build llama/vendor +/ollama diff --git a/.golangci.yaml b/.golangci.yaml index 9d59fd6c..9bb9786a 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -6,8 +6,6 @@ linters: - bidichk - bodyclose - containedctx - - contextcheck - - errcheck - gocheckcompilerdirectives - gofmt - gofumpt @@ -23,10 +21,11 @@ linters: - staticcheck - tenv - unconvert - - unused - - usestdlibvars - wastedassign - whitespace + disable: + - usestdlibvars + - errcheck linters-settings: staticcheck: checks: @@ -39,5 +38,4 @@ severity: - gofmt - goimports - intrange - - usestdlibvars severity: info diff --git a/server/internal/cache/blob/cache.go b/server/internal/cache/blob/cache.go new file mode 100644 index 00000000..f0b0760f --- /dev/null +++ b/server/internal/cache/blob/cache.go @@ -0,0 +1,544 @@ +// Package blob implements a content-addressable disk cache for blobs and +// manifests. +package blob + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "hash" + "io" + "io/fs" + "iter" + "os" + "path/filepath" + "strings" + "time" + + "github.com/ollama/ollama/server/internal/internal/names" +) + +// Entry contains metadata about a blob in the cache. +type Entry struct { + Digest Digest + Size int64 + Time time.Time // when added to the cache +} + +// DiskCache caches blobs and manifests on disk. +// +// The cache is rooted at a directory, which is created if it does not exist. +// +// Blobs are stored in the "blobs" subdirectory, and manifests are stored in the +// "manifests" subdirectory. A example directory structure might look like: +// +// / +// blobs/ +// sha256- - +// manifests/ +// / +// / +// / +// - +// +// The cache is safe for concurrent use. +// +// Name casing is preserved in the cache, but is not significant when resolving +// names. For example, "Foo" and "foo" are considered the same name. +// +// The cache is not safe for concurrent use. It guards concurrent writes, but +// does not prevent duplicated effort. Because blobs are immutable, duplicate +// writes should result in the same file being written to disk. +type DiskCache struct { + // Dir specifies the top-level directory where blobs and manifest + // pointers are stored. + dir string + now func() time.Time + + testHookBeforeFinalWrite func(f *os.File) +} + +// PutString is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))). +func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error { + return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data))) +} + +// Open opens a cache rooted at the given directory. If the directory does not +// exist, it is created. If the directory is not a directory, an error is +// returned. +func Open(dir string) (*DiskCache, error) { + if dir == "" { + return nil, errors.New("blob: empty directory name") + } + + info, err := os.Stat(dir) + if err == nil && !info.IsDir() { + return nil, fmt.Errorf("%q is not a directory", dir) + } + if err := os.MkdirAll(dir, 0o777); err != nil { + return nil, err + } + + subdirs := []string{"blobs", "manifests"} + for _, subdir := range subdirs { + if err := os.MkdirAll(filepath.Join(dir, subdir), 0o777); err != nil { + return nil, err + } + } + + // TODO(bmizerany): support shards + c := &DiskCache{ + dir: dir, + now: time.Now, + } + return c, nil +} + +func readAndSum(filename string, limit int64) (data []byte, _ Digest, err error) { + f, err := os.Open(filename) + if err != nil { + return nil, Digest{}, err + } + defer f.Close() + + h := sha256.New() + r := io.TeeReader(f, h) + data, err = io.ReadAll(io.LimitReader(r, limit)) + if err != nil { + return nil, Digest{}, err + } + var d Digest + h.Sum(d.sum[:0]) + return data, d, nil +} + +//lint:ignore U1000 used for debugging purposes as needed in tests +var debug = false + +// debugger returns a function that can be used to add a step to the error message. +// The error message will be a list of steps that were taken before the error occurred. +// The steps are added in the order they are called. +// +// To set the error message, call the returned function with an empty string. +// +//lint:ignore U1000 used for debugging purposes as needed in tests +func debugger(err *error) func(step string) { + if !debug { + return func(string) {} + } + var steps []string + return func(step string) { + if step == "" && *err != nil { + *err = fmt.Errorf("%q: %w", steps, *err) + return + } + steps = append(steps, step) + if len(steps) > 100 { + // shift hints in case of a bug that causes a lot of hints + copy(steps, steps[1:]) + steps = steps[:100] + } + } +} + +// Resolve resolves a name to a digest. The name is expected to +// be in either of the following forms: +// +// @ +// +// +// +// If a digest is provided, it is returned as is and nothing else happens. +// +// If a name is provided for a manifest that exists in the cache, the digest +// of the manifest is returned. If there is no manifest in the cache, it +// returns [fs.ErrNotExist]. +// +// To cover the case where a manifest may change without the cache knowing +// (e.g. it was reformatted or modified by hand), the manifest data read and +// hashed is passed to a PutBytes call to ensure that the manifest is in the +// blob store. This is done to ensure that future calls to [Get] succeed in +// these cases. +// +// TODO(bmizerany): Move Links/Resolve/etc. out of this package. +func (c *DiskCache) Resolve(name string) (Digest, error) { + name, digest := splitNameDigest(name) + if digest != "" { + return ParseDigest(digest) + } + + // We want to address manifests files by digest using Get. That requires + // them to be blobs. This cannot be directly accomplished by looking in + // the blob store because manifests can change without Ollama knowing + // (e.g. a user modifies a manifests by hand then pushes it to update + // their model). We also need to support the blob caches inherited from + // older versions of Ollama, which do not store manifests in the blob + // store, so for these cases, we need to handle adding the manifests to + // the blob store, just in time. + // + // So now we read the manifests file, hash it, and copy it to the blob + // store if it's not already there. + // + // This should be cheap because manifests are small, and accessed + // infrequently. + file, err := c.manifestPath(name) + if err != nil { + return Digest{}, err + } + + data, d, err := readAndSum(file, 1<<20) + if err != nil { + return Digest{}, err + } + + // Ideally we'd read the "manifest" file as a manifest to the blob file, + // but we are not changing this yet, so copy the manifest to the blob + // store so it can be addressed by digest subsequent calls to Get. + if err := PutBytes(c, d, data); err != nil { + return Digest{}, err + } + return d, nil +} + +// Put writes a new blob to the cache, identified by its digest. The operation +// reads content from r, which must precisely match both the specified size and +// digest. +// +// Concurrent write safety is achieved through file locking. The implementation +// guarantees write integrity by enforcing size limits and content validation +// before allowing the file to reach its final state. +func (c *DiskCache) Put(d Digest, r io.Reader, size int64) error { + return c.copyNamedFile(c.GetFile(d), r, d, size) +} + +// Import imports a blob from the provided reader into the cache. It reads the +// entire content of the reader, calculates its digest, and stores it in the +// cache. +// +// Import should be considered unsafe for use with untrusted data, such as data +// read from a network. The caller is responsible for ensuring the integrity of +// the data being imported. +func (c *DiskCache) Import(r io.Reader, size int64) (Digest, error) { + // users that want to change the temp dir can set TEMPDIR. + f, err := os.CreateTemp("", "blob-") + if err != nil { + return Digest{}, err + } + defer os.Remove(f.Name()) + + // Copy the blob to a temporary file. + h := sha256.New() + r = io.TeeReader(r, h) + n, err := io.Copy(f, r) + if err != nil { + return Digest{}, err + } + if n != size { + return Digest{}, fmt.Errorf("blob: expected %d bytes, got %d", size, n) + } + + // Check the digest. + var d Digest + h.Sum(d.sum[:0]) + if err := f.Close(); err != nil { + return Digest{}, err + } + name := c.GetFile(d) + // Rename the temporary file to the final file. + if err := os.Rename(f.Name(), name); err != nil { + return Digest{}, err + } + os.Chtimes(name, c.now(), c.now()) // mainly for tests + return d, nil +} + +// Get retrieves a blob from the cache using the provided digest. The operation +// fails if the digest is malformed or if any errors occur during blob +// retrieval. +func (c *DiskCache) Get(d Digest) (Entry, error) { + name := c.GetFile(d) + info, err := os.Stat(name) + if err != nil { + return Entry{}, err + } + if info.Size() == 0 { + return Entry{}, fs.ErrNotExist + } + return Entry{ + Digest: d, + Size: info.Size(), + Time: info.ModTime(), + }, nil +} + +// Link creates a symbolic reference in the cache that maps the provided name +// to a blob identified by its digest, making it retrievable by name using +// [Resolve]. +// +// It returns an error if either the name or digest is invalid, or if link +// creation encounters any issues. +func (c *DiskCache) Link(name string, d Digest) error { + manifest, err := c.manifestPath(name) + if err != nil { + return err + } + f, err := os.OpenFile(c.GetFile(d), os.O_RDONLY, 0) + if err != nil { + return err + } + defer f.Close() + + // TODO(bmizerany): test this happens only if the blob was found to + // avoid leaving debris + if err := os.MkdirAll(filepath.Dir(manifest), 0o777); err != nil { + return err + } + + info, err := f.Stat() + if err != nil { + return err + } + + // Copy manifest to cache directory. + return c.copyNamedFile(manifest, f, d, info.Size()) +} + +// Unlink removes the any link for name. If the link does not exist, nothing +// happens, and no error is returned. +// +// It returns an error if the name is invalid or if the link removal encounters +// any issues. +func (c *DiskCache) Unlink(name string) error { + manifest, err := c.manifestPath(name) + if err != nil { + return err + } + err = os.Remove(manifest) + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return err +} + +// GetFile returns the absolute path to the file, in the cache, for the given +// digest. It does not check if the file exists. +// +// The returned path should not be stored, used outside the lifetime of the +// cache, or interpreted in any way. +func (c *DiskCache) GetFile(d Digest) string { + filename := fmt.Sprintf("sha256-%x", d.sum) + return absJoin(c.dir, "blobs", filename) +} + +// Links returns a sequence of links in the cache in lexical order. +func (c *DiskCache) Links() iter.Seq2[string, error] { + return func(yield func(string, error) bool) { + for path, err := range c.links() { + if err != nil { + yield("", err) + return + } + if !yield(pathToName(path), nil) { + return + } + } + } +} + +// pathToName converts a path to a name. It is the inverse of nameToPath. The +// path is assumed to be in filepath.ToSlash format. +func pathToName(s string) string { + s = strings.TrimPrefix(s, "manifests/") + rr := []rune(s) + for i := len(rr) - 1; i > 0; i-- { + if rr[i] == '/' { + rr[i] = ':' + return string(rr) + } + } + return s +} + +// manifestPath finds the first manifest file on disk that matches the given +// name using a case-insensitive comparison. If no manifest file is found, it +// returns the path where the manifest file would be if it existed. +// +// If two manifest files exists on disk that match the given name using a +// case-insensitive comparison, the one that sorts first, lexically, is +// returned. +func (c *DiskCache) manifestPath(name string) (string, error) { + np, err := nameToPath(name) + if err != nil { + return "", err + } + + maybe := filepath.Join("manifests", np) + for l, err := range c.links() { + if err != nil { + return "", err + } + if strings.EqualFold(maybe, l) { + return filepath.Join(c.dir, l), nil + } + } + return filepath.Join(c.dir, maybe), nil +} + +// links returns a sequence of links in the cache in lexical order. +func (c *DiskCache) links() iter.Seq2[string, error] { + // TODO(bmizerany): reuse empty dirnames if exist + return func(yield func(string, error) bool) { + fsys := os.DirFS(c.dir) + manifests, err := fs.Glob(fsys, "manifests/*/*/*/*") + if err != nil { + yield("", err) + return + } + for _, manifest := range manifests { + if !yield(manifest, nil) { + return + } + } + } +} + +type checkWriter struct { + d Digest + size int64 + n int64 + h hash.Hash + f *os.File + err error + + testHookBeforeFinalWrite func(*os.File) +} + +func (w *checkWriter) seterr(err error) error { + if w.err == nil { + w.err = err + } + return err +} + +// Write writes p to the underlying hash and writer. The last write to the +// underlying writer is guaranteed to be the last byte of p as verified by the +// hash. +func (w *checkWriter) Write(p []byte) (int, error) { + _, err := w.h.Write(p) + if err != nil { + return 0, w.seterr(err) + } + nextSize := w.n + int64(len(p)) + if nextSize == w.size { + // last write. check hash. + sum := w.h.Sum(nil) + if !bytes.Equal(sum, w.d.sum[:]) { + return 0, w.seterr(fmt.Errorf("file content changed underfoot")) + } + if w.testHookBeforeFinalWrite != nil { + w.testHookBeforeFinalWrite(w.f) + } + } + if nextSize > w.size { + return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size)) + } + n, err := w.f.Write(p) + w.n += int64(n) + return n, w.seterr(err) +} + +// copyNamedFile copies file into name, expecting it to have the given Digest +// and size, if that file is not present already. +func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size int64) error { + info, err := os.Stat(name) + if err == nil && info.Size() == size { + // File already exists with correct size. This is good enough. + // We can skip expensive hash checks. + // + // TODO: Do the hash check, but give caller a way to skip it. + return nil + } + + // Copy file to cache directory. + mode := os.O_RDWR | os.O_CREATE + if err == nil && info.Size() > size { // shouldn't happen but fix in case + mode |= os.O_TRUNC + } + f, err := os.OpenFile(name, mode, 0o666) + if err != nil { + return err + } + defer f.Close() + if size == 0 { + // File now exists with correct size. + // Only one possible zero-length file, so contents are OK too. + // Early return here makes sure there's a "last byte" for code below. + return nil + } + + // From here on, if any of the I/O writing the file fails, + // we make a best-effort attempt to truncate the file f + // before returning, to avoid leaving bad bytes in the file. + + // Copy file to f, but also into h to double-check hash. + cw := &checkWriter{ + d: out, + size: size, + h: sha256.New(), + f: f, + testHookBeforeFinalWrite: c.testHookBeforeFinalWrite, + } + n, err := io.Copy(cw, file) + if err != nil { + f.Truncate(0) + return err + } + if n < size { + f.Truncate(0) + return io.ErrUnexpectedEOF + } + + if err := f.Close(); err != nil { + // Data might not have been written, + // but file may look like it is the right size. + // To be extra careful, remove cached file. + os.Remove(name) + return err + } + os.Chtimes(name, c.now(), c.now()) // mainly for tests + + return nil +} + +func splitNameDigest(s string) (name, digest string) { + i := strings.LastIndexByte(s, '@') + if i < 0 { + return s, "" + } + return s[:i], s[i+1:] +} + +var errInvalidName = errors.New("invalid name") + +func nameToPath(name string) (_ string, err error) { + if strings.Contains(name, "@") { + // TODO(bmizerany): HACK: Fix names.Parse to validate. + // TODO(bmizerany): merge with default parts (maybe names.Merge(a, b)) + return "", errInvalidName + } + n := names.Parse(name) + if !n.IsFullyQualified() { + return "", errInvalidName + } + return filepath.Join(n.Host(), n.Namespace(), n.Model(), n.Tag()), nil +} + +func absJoin(pp ...string) string { + abs, err := filepath.Abs(filepath.Join(pp...)) + if err != nil { + // Likely a bug bug or a bad OS problem. Just panic. + panic(err) + } + return abs +} diff --git a/server/internal/cache/blob/cache_test.go b/server/internal/cache/blob/cache_test.go new file mode 100644 index 00000000..704542ea --- /dev/null +++ b/server/internal/cache/blob/cache_test.go @@ -0,0 +1,685 @@ +package blob + +import ( + "crypto/sha256" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "slices" + "strings" + "testing" + "time" + + "github.com/ollama/ollama/server/internal/internal/testutil" +) + +func init() { + debug = true +} + +var epoch = func() time.Time { + d := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) + if d.IsZero() { + panic("time zero") + } + return d +}() + +func TestOpenErrors(t *testing.T) { + exe, err := os.Executable() + if err != nil { + panic(err) + } + + cases := []struct { + dir string + err string + }{ + {t.TempDir(), ""}, + {"", "empty directory name"}, + {exe, "not a directory"}, + } + + for _, tt := range cases { + t.Run(tt.dir, func(t *testing.T) { + _, err := Open(tt.dir) + if tt.err == "" { + if err != nil { + t.Fatal(err) + } + return + } + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), tt.err) { + t.Fatalf("err = %v, want %q", err, tt.err) + } + }) + } +} + +func TestGetFile(t *testing.T) { + t.Chdir(t.TempDir()) + + c, err := Open(".") + if err != nil { + t.Fatal(err) + } + + d := mkdigest("1") + got := c.GetFile(d) + cleaned := filepath.Clean(got) + if cleaned != got { + t.Fatalf("got is unclean: %q", got) + } + if !filepath.IsAbs(got) { + t.Fatal("got is not absolute") + } + abs, _ := filepath.Abs(c.dir) + if !strings.HasPrefix(got, abs) { + t.Fatalf("got is not local to %q", c.dir) + } +} + +func TestBasic(t *testing.T) { + c, err := Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } + now := epoch + c.now = func() time.Time { return now } + + checkEntry := entryChecker(t, c) + checkFailed := func(err error) { + if err == nil { + t.Helper() + t.Fatal("expected error") + } + } + + _, err = c.Resolve("invalid") + checkFailed(err) + + _, err = c.Resolve("h/n/m:t") + checkFailed(err) + + dx := mkdigest("x") + + d, err := c.Resolve(fmt.Sprintf("h/n/m:t@%s", dx)) + if err != nil { + t.Fatal(err) + } + if d != dx { + t.Fatalf("d = %v, want %v", d, dx) + } + + _, err = c.Get(Digest{}) + checkFailed(err) + + // not committed yet + _, err = c.Get(dx) + checkFailed(err) + + err = PutBytes(c, dx, "!") + checkFailed(err) + + err = PutBytes(c, dx, "x") + if err != nil { + t.Fatal(err) + } + checkEntry(dx, 1, now) + + t0 := now + now = now.Add(1*time.Hour + 1*time.Minute) + err = PutBytes(c, dx, "x") + if err != nil { + t.Fatal(err) + } + + // check not updated + checkEntry(dx, 1, t0) +} + +type sleepFunc func(d time.Duration) time.Time + +func openTester(t *testing.T) (*DiskCache, sleepFunc) { + t.Helper() + c, err := Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } + now := epoch + c.now = func() time.Time { return now } + return c, func(d time.Duration) time.Time { + now = now.Add(d) + return now + } +} + +func TestManifestPath(t *testing.T) { + check := testutil.Checker(t) + + c, sleep := openTester(t) + + d1 := mkdigest("1") + err := PutBytes(c, d1, "1") + check(err) + + err = c.Link("h/n/m:t", d1) + check(err) + + t0 := sleep(0) + sleep(1 * time.Hour) + err = c.Link("h/n/m:t", d1) // nop expected + check(err) + + file := must(c.manifestPath("h/n/m:t")) + info, err := os.Stat(file) + check(err) + testutil.CheckTime(t, info.ModTime(), t0) +} + +func TestManifestExistsWithoutBlob(t *testing.T) { + t.Chdir(t.TempDir()) + + check := testutil.Checker(t) + + c, err := Open(".") + check(err) + + checkEntry := entryChecker(t, c) + + man := must(c.manifestPath("h/n/m:t")) + os.MkdirAll(filepath.Dir(man), 0o777) + testutil.WriteFile(t, man, "1") + + got, err := c.Resolve("h/n/m:t") + check(err) + + want := mkdigest("1") + if got != want { + t.Fatalf("got = %v, want %v", got, want) + } + + e, err := c.Get(got) + check(err) + checkEntry(got, 1, e.Time) +} + +func TestPut(t *testing.T) { + c, sleep := openTester(t) + + check := testutil.Checker(t) + checkEntry := entryChecker(t, c) + + d := mkdigest("hello, world") + err := PutBytes(c, d, "hello") + if err == nil { + t.Fatal("expected error") + } + + got, err := c.Get(d) + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("expected error, got %v", got) + } + + // Put a valid blob + err = PutBytes(c, d, "hello, world") + check(err) + checkEntry(d, 12, sleep(0)) + + // Put a blob with content that does not hash to the digest + err = PutBytes(c, d, "hello") + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + + // Put the valid blob back and check it + err = PutBytes(c, d, "hello, world") + check(err) + checkEntry(d, 12, sleep(0)) + + // Put a blob that errors during Read + err = c.Put(d, &errOnBangReader{s: "!"}, 1) + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + + // Put valid blob back and check it + err = PutBytes(c, d, "hello, world") + check(err) + checkEntry(d, 12, sleep(0)) + + // Put a blob with mismatched size + err = c.Put(d, strings.NewReader("hello, world"), 11) + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + + // Final byte does not match the digest (testing commit phase) + err = PutBytes(c, d, "hello, world$") + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + + reset := c.setTestHookBeforeFinalWrite(func(f *os.File) { + // change mode to read-only + f.Truncate(0) + f.Chmod(0o400) + f.Close() + f1, err := os.OpenFile(f.Name(), os.O_RDONLY, 0) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { f1.Close() }) + *f = *f1 + }) + defer reset() + + err = PutBytes(c, d, "hello, world") + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + reset() +} + +func TestImport(t *testing.T) { + c, _ := openTester(t) + + checkEntry := entryChecker(t, c) + + want := mkdigest("x") + got, err := c.Import(strings.NewReader("x"), 1) + if err != nil { + t.Fatal(err) + } + if want != got { + t.Fatalf("digest = %v, want %v", got, want) + } + checkEntry(want, 1, epoch) + + got, err = c.Import(strings.NewReader("x"), 1) + if err != nil { + t.Fatal(err) + } + if want != got { + t.Fatalf("digest = %v, want %v", got, want) + } + checkEntry(want, 1, epoch) +} + +func (c *DiskCache) setTestHookBeforeFinalWrite(h func(*os.File)) (reset func()) { + old := c.testHookBeforeFinalWrite + c.testHookBeforeFinalWrite = h + return func() { c.testHookBeforeFinalWrite = old } +} + +func TestPutGetZero(t *testing.T) { + c, sleep := openTester(t) + + check := testutil.Checker(t) + checkEntry := entryChecker(t, c) + + d := mkdigest("x") + err := PutBytes(c, d, "x") + check(err) + checkEntry(d, 1, sleep(0)) + + err = os.Truncate(c.GetFile(d), 0) + check(err) + + _, err = c.Get(d) + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("err = %v, want fs.ErrNotExist", err) + } +} + +func TestPutZero(t *testing.T) { + c, _ := openTester(t) + d := mkdigest("x") + err := c.Put(d, strings.NewReader("x"), 0) // size == 0 (not size of content) + testutil.Check(t, err) + checkNotExists(t, c, d) +} + +func TestCommit(t *testing.T) { + check := testutil.Checker(t) + + c, err := Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } + checkEntry := entryChecker(t, c) + + now := epoch + c.now = func() time.Time { return now } + + d1 := mkdigest("1") + err = c.Link("h/n/m:t", d1) + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("err = %v, want fs.ErrNotExist", err) + } + + err = PutBytes(c, d1, "1") + check(err) + + err = c.Link("h/n/m:t", d1) + check(err) + + got, err := c.Resolve("h/n/m:t") + check(err) + if got != d1 { + t.Fatalf("d = %v, want %v", got, d1) + } + + // commit again, more than 1 byte + d2 := mkdigest("22") + err = PutBytes(c, d2, "22") + check(err) + err = c.Link("h/n/m:t", d2) + check(err) + checkEntry(d2, 2, now) + + filename := must(c.manifestPath("h/n/m:t")) + data, err := os.ReadFile(filename) + check(err) + if string(data) != "22" { + t.Fatalf("data = %q, want %q", data, "22") + } + + t0 := now + now = now.Add(1 * time.Hour) + err = c.Link("h/n/m:t", d2) // same contents; nop + check(err) + info, err := os.Stat(filename) + check(err) + testutil.CheckTime(t, info.ModTime(), t0) +} + +func TestManifestInvalidBlob(t *testing.T) { + c, _ := openTester(t) + d := mkdigest("1") + err := c.Link("h/n/m:t", d) + if err == nil { + t.Fatal("expected error") + } + checkNotExists(t, c, d) + + err = PutBytes(c, d, "1") + testutil.Check(t, err) + err = os.WriteFile(c.GetFile(d), []byte("invalid"), 0o666) + if err != nil { + t.Fatal(err) + } + + err = c.Link("h/n/m:t", d) + if !strings.Contains(err.Error(), "underfoot") { + t.Fatalf("err = %v, want error to contain %q", err, "underfoot") + } +} + +func TestManifestNameReuse(t *testing.T) { + t.Run("case-insensitive", func(t *testing.T) { + // This should run on all file system types. + testManifestNameReuse(t) + }) + t.Run("case-sensitive", func(t *testing.T) { + useCaseInsensitiveTempDir(t) + testManifestNameReuse(t) + }) +} + +func testManifestNameReuse(t *testing.T) { + check := testutil.Checker(t) + + c, _ := openTester(t) + + d1 := mkdigest("1") + err := PutBytes(c, d1, "1") + check(err) + err = c.Link("h/n/m:t", d1) + check(err) + + d2 := mkdigest("22") + err = PutBytes(c, d2, "22") + check(err) + err = c.Link("H/N/M:T", d2) + check(err) + + var g [2]Digest + g[0], err = c.Resolve("h/n/m:t") + check(err) + g[1], err = c.Resolve("H/N/M:T") + check(err) + + w := [2]Digest{d2, d2} + if g != w { + t.Fatalf("g = %v, want %v", g, w) + } + + var got []string + for l, err := range c.links() { + if err != nil { + t.Fatal(err) + } + got = append(got, l) + } + want := []string{"manifests/h/n/m/t"} + if !slices.Equal(got, want) { + t.Fatalf("got = %v, want %v", got, want) + } + + // relink with different case + err = c.Unlink("h/n/m:t") + check(err) + err = c.Link("h/n/m:T", d1) + check(err) + + got = got[:0] + for l, err := range c.links() { + if err != nil { + t.Fatal(err) + } + got = append(got, l) + } + + // we should have only one link that is same case as the last link + want = []string{"manifests/h/n/m/T"} + if !slices.Equal(got, want) { + t.Fatalf("got = %v, want %v", got, want) + } +} + +func TestManifestFile(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"", ""}, + + // valid names + {"h/n/m:t", "/manifests/h/n/m/t"}, + {"hh/nn/mm:tt", "/manifests/hh/nn/mm/tt"}, + + {"%/%/%/%", ""}, + + // already a path + {"h/n/m/t", ""}, + + // refs are not names + {"h/n/m:t@sha256-1", ""}, + {"m@sha256-1", ""}, + {"n/m:t@sha256-1", ""}, + } + + c, _ := openTester(t) + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + got, err := c.manifestPath(tt.in) + if err != nil && tt.want != "" { + t.Fatalf("unexpected error: %v", err) + } + if err == nil && tt.want == "" { + t.Fatalf("expected error") + } + dir := filepath.ToSlash(c.dir) + got = filepath.ToSlash(got) + got = strings.TrimPrefix(got, dir) + if got != tt.want { + t.Fatalf("got = %q, want %q", got, tt.want) + } + }) + } +} + +func TestNames(t *testing.T) { + c, _ := openTester(t) + check := testutil.Checker(t) + + check(PutBytes(c, mkdigest("1"), "1")) + check(PutBytes(c, mkdigest("2"), "2")) + + check(c.Link("h/n/m:t", mkdigest("1"))) + check(c.Link("h/n/m:u", mkdigest("2"))) + + var got []string + for l, err := range c.Links() { + if err != nil { + t.Fatal(err) + } + got = append(got, l) + } + want := []string{"h/n/m:t", "h/n/m:u"} + if !slices.Equal(got, want) { + t.Fatalf("got = %v, want %v", got, want) + } +} + +func mkdigest(s string) Digest { + return Digest{sha256.Sum256([]byte(s))} +} + +func checkNotExists(t *testing.T, c *DiskCache, d Digest) { + t.Helper() + _, err := c.Get(d) + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("err = %v, want fs.ErrNotExist", err) + } +} + +func entryChecker(t *testing.T, c *DiskCache) func(Digest, int64, time.Time) { + t.Helper() + return func(d Digest, size int64, mod time.Time) { + t.Helper() + t.Run("checkEntry:"+d.String(), func(t *testing.T) { + t.Helper() + + defer func() { + if t.Failed() { + dumpCacheContents(t, c) + } + }() + + e, err := c.Get(d) + if size == 0 && errors.Is(err, fs.ErrNotExist) { + err = nil + } + if err != nil { + t.Fatal(err) + } + if e.Digest != d { + t.Errorf("e.Digest = %v, want %v", e.Digest, d) + } + if e.Size != size { + t.Fatalf("e.Size = %v, want %v", e.Size, size) + } + + testutil.CheckTime(t, e.Time, mod) + info, err := os.Stat(c.GetFile(d)) + if err != nil { + t.Fatal(err) + } + if info.Size() != size { + t.Fatalf("info.Size = %v, want %v", info.Size(), size) + } + testutil.CheckTime(t, info.ModTime(), mod) + }) + } +} + +func must[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v +} + +func TestNameToPath(t *testing.T) { + _, err := nameToPath("h/n/m:t") + if err != nil { + t.Fatal(err) + } +} + +type errOnBangReader struct { + s string + n int +} + +func (e *errOnBangReader) Read(p []byte) (int, error) { + if len(p) < 1 { + return 0, io.ErrShortBuffer + } + if e.n >= len(p) { + return 0, io.EOF + } + if e.s[e.n] == '!' { + return 0, errors.New("bang") + } + p[0] = e.s[e.n] + e.n++ + return 1, nil +} + +func dumpCacheContents(t *testing.T, c *DiskCache) { + t.Helper() + + var b strings.Builder + fsys := os.DirFS(c.dir) + fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { + t.Helper() + + if err != nil { + return err + } + info, err := d.Info() + if err != nil { + return err + } + + // Format like ls: + // + // ; ls -la + // drwxr-xr-x 224 Jan 13 14:22 blob/sha256-123 + // drwxr-xr-x 224 Jan 13 14:22 manifest/h/n/m + + fmt.Fprintf(&b, " %s % 4d %s %s\n", + info.Mode(), + info.Size(), + info.ModTime().Format("Jan 2 15:04"), + path, + ) + return nil + }) + t.Log() + t.Logf("cache contents:\n%s", b.String()) +} diff --git a/server/internal/cache/blob/casecheck_test.go b/server/internal/cache/blob/casecheck_test.go new file mode 100644 index 00000000..f0842ef9 --- /dev/null +++ b/server/internal/cache/blob/casecheck_test.go @@ -0,0 +1,93 @@ +package blob + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func isCaseSensitive(dir string) bool { + defer func() { + os.Remove(filepath.Join(dir, "_casecheck")) + }() + + exists := func(file string) bool { + _, err := os.Stat(file) + return err == nil + } + + file := filepath.Join(dir, "_casecheck") + FILE := filepath.Join(dir, "_CASECHECK") + if exists(file) || exists(FILE) { + panic(fmt.Sprintf("_casecheck already exists in %q; remove and try again.", dir)) + } + + err := os.WriteFile(file, nil, 0o666) + if err != nil { + panic(err) + } + + return !exists(FILE) +} + +func isCI() bool { + return os.Getenv("CI") != "" +} + +const volumeHint = ` + + Unable to locate case-insensitive TMPDIR on darwin. + + To run tests, create the case-insensitive volume /Volumes/data: + + $ sudo diskutil apfs addVolume disk1 APFSX data -mountpoint /Volumes/data + + or run with: + + CI=1 go test ./... + +` + +// useCaseInsensitiveTempDir sets TMPDIR to a case-insensitive directory +// can find one, otherwise it skips the test if the CI environment variable is +// set, or GOOS is not darwin. +func useCaseInsensitiveTempDir(t *testing.T) bool { + if isCaseSensitive(os.TempDir()) { + // Use the default temp dir if it is already case-sensitive. + return true + } + if runtime.GOOS == "darwin" { + // If darwin, check for the special case-sensitive volume and + // use it if available. + const volume = "/Volumes/data" + _, err := os.Stat(volume) + if err == nil { + tmpdir := filepath.Join(volume, "tmp") + os.MkdirAll(tmpdir, 0o700) + t.Setenv("TMPDIR", tmpdir) + return true + } + if isCI() { + // Special case darwin in CI; it is not case-sensitive + // by default, and we will be testing other platforms + // that are case-sensitive, so we'll have the test + // being skipped covered there. + t.Skip("Skipping test in CI for darwin; TMPDIR is not case-insensitive.") + } + } + + if !isCI() { + // Require devs to always tests with a case-insensitive TMPDIR. + + // TODO(bmizerany): Print platform-specific instructions or + // link to docs on that topic. + lines := strings.Split(volumeHint, "\n") + for _, line := range lines { + t.Log(line) + } + } + return false +} diff --git a/server/internal/cache/blob/digest.go b/server/internal/cache/blob/digest.go new file mode 100644 index 00000000..723ba222 --- /dev/null +++ b/server/internal/cache/blob/digest.go @@ -0,0 +1,95 @@ +package blob + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "slices" + "strings" +) + +var ErrInvalidDigest = errors.New("invalid digest") + +// Digest is a blob identifier that is the SHA-256 hash of a blob's content. +// +// It is comparable and can be used as a map key. +type Digest struct { + sum [32]byte +} + +// ParseDigest parses a digest from a string. If the string is not a valid +// digest, a call to the returned digest's IsValid method will return false. +// +// The input string may be in one of two forms: +// +// - ("sha256-"), where is a 64-character hexadecimal string. +// - ("sha256:"), where is a 64-character hexadecimal string. +// +// The [Digest.String] method will return the canonical form of the +// digest, "sha256:". +func ParseDigest[S ~[]byte | ~string](v S) (Digest, error) { + s := string(v) + i := strings.IndexAny(s, ":-") + var zero Digest + if i < 0 { + return zero, ErrInvalidDigest + } + + prefix, sum := s[:i], s[i+1:] + if prefix != "sha256" || len(sum) != 64 { + return zero, ErrInvalidDigest + } + + var d Digest + _, err := hex.Decode(d.sum[:], []byte(sum)) + if err != nil { + return zero, ErrInvalidDigest + } + return d, nil +} + +func DigestFromBytes[S ~[]byte | ~string](v S) Digest { + return Digest{sha256.Sum256([]byte(v))} +} + +// String returns the string representation of the digest in the conventional +// form "sha256:". +func (d Digest) String() string { + return fmt.Sprintf("sha256:%x", d.sum[:]) +} + +func (d Digest) Short() string { + return fmt.Sprintf("%x", d.sum[:4]) +} + +func (d Digest) Compare(other Digest) int { + return slices.Compare(d.sum[:], other.sum[:]) +} + +// IsValid returns true if the digest is valid, i.e. if it is the SHA-256 hash +// of some content. +func (d Digest) IsValid() bool { + return d != (Digest{}) +} + +// MarshalText implements the encoding.TextMarshaler interface. It returns an +// error if [Digest.IsValid] returns false. +func (d Digest) MarshalText() ([]byte, error) { + return []byte(d.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface, and only +// works for a zero digest. If [Digest.IsValid] returns true, it returns an +// error. +func (d *Digest) UnmarshalText(text []byte) error { + if *d != (Digest{}) { + return errors.New("digest: illegal UnmarshalText on valid digest") + } + v, err := ParseDigest(string(text)) + if err != nil { + return err + } + *d = v + return nil +} diff --git a/server/internal/cache/blob/digest_test.go b/server/internal/cache/blob/digest_test.go new file mode 100644 index 00000000..c96ad383 --- /dev/null +++ b/server/internal/cache/blob/digest_test.go @@ -0,0 +1,63 @@ +package blob + +import ( + "encoding/json" + "testing" +) + +func TestParseDigest(t *testing.T) { + cases := []struct { + in string + valid bool + }{ + {"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true}, + {"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true}, + + // too short + {"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false}, + {"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false}, + + // too long + {"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false}, + {"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false}, + + // invalid prefix + {"sha255-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false}, + {"sha255:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false}, + {"sha256!0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false}, + + // invalid hex + {"sha256-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false}, + {"sha256:XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false}, + } + + for _, tt := range cases { + got, err := ParseDigest(tt.in) + if tt.valid && err != nil { + t.Errorf("ParseDigest(%q) = %v, %v; want valid", tt.in, got, err) + } + want := "sha256:" + tt.in[7:] + if tt.valid && got.String() != want { + t.Errorf("ParseDigest(%q).String() = %q, want %q", tt.in, got.String(), want) + } + } +} + +func TestDigestMarshalText(t *testing.T) { + const s = `"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"` + var d Digest + if err := json.Unmarshal([]byte(s), &d); err != nil { + t.Errorf("json.Unmarshal: %v", err) + } + out, err := json.Marshal(d) + if err != nil { + t.Errorf("json.Marshal: %v", err) + } + want := `"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"` + if string(out) != want { + t.Errorf("json.Marshal: got %s, want %s", out, want) + } + if err := json.Unmarshal([]byte(`"invalid"`), &Digest{}); err == nil { + t.Errorf("json.Unmarshal: expected error") + } +} diff --git a/server/internal/chunks/chunks.go b/server/internal/chunks/chunks.go new file mode 100644 index 00000000..7eb7a6c1 --- /dev/null +++ b/server/internal/chunks/chunks.go @@ -0,0 +1,78 @@ +package chunks + +import ( + "fmt" + "iter" + "strconv" + "strings" +) + +type Chunk struct { + Start, End int64 +} + +func New(start, end int64) Chunk { + return Chunk{start, end} +} + +// ParseRange parses a string in the form "unit=range" where unit is a string +// and range is a string in the form "start-end". It returns the unit and the +// range as a Chunk. +func ParseRange(s string) (unit string, _ Chunk, _ error) { + unit, r, _ := strings.Cut(s, "=") + if r == "" { + return unit, Chunk{}, nil + } + c, err := Parse(r) + if err != nil { + return "", Chunk{}, err + } + return unit, c, err +} + +// Parse parses a string in the form "start-end" and returns the Chunk. +func Parse(s string) (Chunk, error) { + startStr, endStr, _ := strings.Cut(s, "-") + start, err := strconv.ParseInt(startStr, 10, 64) + if err != nil { + return Chunk{}, fmt.Errorf("invalid start: %v", err) + } + end, err := strconv.ParseInt(endStr, 10, 64) + if err != nil { + return Chunk{}, fmt.Errorf("invalid end: %v", err) + } + if start > end { + return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end) + } + return Chunk{start, end}, nil +} + +// Of returns a sequence of contiguous Chunks of size chunkSize that cover +// the range [0, size), in order. +func Of(size, chunkSize int64) iter.Seq[Chunk] { + return func(yield func(Chunk) bool) { + for start := int64(0); start < size; start += chunkSize { + end := min(start+chunkSize-1, size-1) + if !yield(Chunk{start, end}) { + break + } + } + } +} + +// Count returns the number of Chunks of size chunkSize needed to cover the +// range [0, size). +func Count(size, chunkSize int64) int64 { + return (size + chunkSize - 1) / chunkSize +} + +// Size returns end minus start plus one. +func (c Chunk) Size() int64 { + return c.End - c.Start + 1 +} + +// String returns the string representation of the Chunk in the form +// "{start}-{end}". +func (c Chunk) String() string { + return fmt.Sprintf("%d-%d", c.Start, c.End) +} diff --git a/server/internal/chunks/chunks_test.go b/server/internal/chunks/chunks_test.go new file mode 100644 index 00000000..c23e0de8 --- /dev/null +++ b/server/internal/chunks/chunks_test.go @@ -0,0 +1,65 @@ +package chunks + +import ( + "slices" + "testing" +) + +func TestOf(t *testing.T) { + cases := []struct { + total int64 + chunkSize int64 + want []Chunk + }{ + {0, 1, nil}, + {1, 1, []Chunk{{0, 0}}}, + {1, 2, []Chunk{{0, 0}}}, + {2, 1, []Chunk{{0, 0}, {1, 1}}}, + {10, 9, []Chunk{{0, 8}, {9, 9}}}, + } + + for _, tt := range cases { + got := slices.Collect(Of(tt.total, tt.chunkSize)) + if !slices.Equal(got, tt.want) { + t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want) + } + } +} + +func TestSize(t *testing.T) { + cases := []struct { + c Chunk + want int64 + }{ + {Chunk{0, 0}, 1}, + {Chunk{0, 1}, 2}, + {Chunk{3, 4}, 2}, + } + + for _, tt := range cases { + got := tt.c.Size() + if got != tt.want { + t.Errorf("%v: got %d; want %d", tt.c, got, tt.want) + } + } +} + +func TestCount(t *testing.T) { + cases := []struct { + total int64 + chunkSize int64 + want int64 + }{ + {0, 1, 0}, + {1, 1, 1}, + {1, 2, 1}, + {2, 1, 2}, + {10, 9, 2}, + } + for _, tt := range cases { + got := Count(tt.total, tt.chunkSize) + if got != tt.want { + t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want) + } + } +} diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go new file mode 100644 index 00000000..13612272 --- /dev/null +++ b/server/internal/client/ollama/registry.go @@ -0,0 +1,802 @@ +// Package ollama provides a client for interacting with an Ollama registry +// which pushes and pulls model manifests and layers as defined by the +// [ollama.com/manifest]. +package ollama + +import ( + "bufio" + "bytes" + "cmp" + "context" + "crypto" + "crypto/ed25519" + "crypto/sha256" + "crypto/tls" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "net/http" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync/atomic" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/sync/errgroup" + + "github.com/ollama/ollama/server/internal/cache/blob" + "github.com/ollama/ollama/server/internal/chunks" + "github.com/ollama/ollama/server/internal/internal/backoff" + "github.com/ollama/ollama/server/internal/internal/names" + "github.com/ollama/ollama/server/internal/internal/syncs" + + _ "embed" +) + +// Errors +var ( + // ErrManifestNotFound is returned when a manifest is not found in the + // cache or registry. + ErrManifestNotFound = errors.New("manifest not found") + + // ErrManifestInvalid is returned when a manifest found in a local or + // remote cache is invalid. + ErrManifestInvalid = errors.New("invalid manifest") + + // ErrMissingModel is returned when the model part of a name is missing + // or invalid. + ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}") + + // ErrCached is passed to [Trace.PushUpdate] when a layer already + // exists. It is a non-fatal error and is never returned by [Registry.Push]. + ErrCached = errors.New("cached") +) + +// Defaults +const ( + // DefaultChunkingThreshold is the threshold at which a layer should be + // split up into chunks when downloading. + DefaultChunkingThreshold = 128 << 20 + + // DefaultMaxChunkSize is the default maximum size of a chunk to + // download. It is configured based on benchmarks and aims to strike a + // balance between download speed and memory usage. + DefaultMaxChunkSize = 8 << 20 +) + +// DefaultCache returns a new disk cache for storing models. If the +// OLLAMA_MODELS environment variable is set, it uses that directory; +// otherwise, it uses $HOME/.ollama/models. +func DefaultCache() (*blob.DiskCache, error) { + dir := os.Getenv("OLLAMA_MODELS") + if dir == "" { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + dir = filepath.Join(home, ".ollama", "models") + } + return blob.Open(dir) +} + +// Error is the standard error returned by Ollama APIs. +type Error struct { + Status int `json:"-"` + Code string `json:"code"` + Message string `json:"message"` +} + +func (e *Error) Error() string { + return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (e *Error) UnmarshalJSON(b []byte) error { + type E Error + var v struct{ Errors []E } + if err := json.Unmarshal(b, &v); err != nil { + return err + } + if len(v.Errors) == 0 { + return fmt.Errorf("no messages in error response: %s", string(b)) + } + *e = Error(v.Errors[0]) // our registry only returns one error. + return nil +} + +// TODO(bmizerany): make configurable on [Registry] +var defaultName = func() names.Name { + n := names.Parse("ollama.com/library/_:latest") + if !n.IsFullyQualified() { + panic("default name is not fully qualified") + } + return n +}() + +// Registry is a client for performing push and pull operations against an +// Ollama registry. +type Registry struct { + // UserAgent is the User-Agent header to send with requests to the + // registry. If empty, the User-Agent is determined by HTTPClient. + UserAgent string + + // Key is the key used to authenticate with the registry. + // + // Currently, only Ed25519 keys are supported. + Key crypto.PrivateKey + + // HTTPClient is the HTTP client used to make requests to the registry. + // + // If nil, [http.DefaultClient] is used. + // + // As a quick note: If a Registry function that makes a call to a URL + // with the "https+insecure" scheme, the client will be cloned and the + // transport will be set to skip TLS verification, unless the client's + // Transport done not have a Clone method with the same signature as + // [http.Transport.Clone], which case, the call will fail. + HTTPClient *http.Client + + // MaxStreams is the maximum number of concurrent streams to use when + // pushing or pulling models. If zero, the number of streams is + // determined by [runtime.GOMAXPROCS]. + // + // Clients that want "unlimited" streams should set this to a large + // number. + MaxStreams int + + // ChunkingThreshold is the maximum size of a layer to download in a single + // request. If zero, [DefaultChunkingThreshold] is used. + ChunkingThreshold int64 + + // MaxChunkSize is the maximum size of a chunk to download. If zero, + // the default is [DefaultMaxChunkSize]. + // + // It is only used when a layer is larger than [MaxChunkingThreshold]. + MaxChunkSize int64 +} + +// RegistryFromEnv returns a new Registry configured from the environment. The +// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the +// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the +// system's temporary directory. +// +// It returns an error if any configuration in the environment is invalid. +func RegistryFromEnv() (*Registry, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519")) + if err != nil { + return nil, err + } + + var rc Registry + rc.Key, err = ssh.ParseRawPrivateKey(keyPEM) + if err != nil { + return nil, err + } + maxStreams := os.Getenv("OLLAMA_REGISTRY_MAXSTREAMS") + if maxStreams != "" { + var err error + rc.MaxStreams, err = strconv.Atoi(maxStreams) + if err != nil { + return nil, fmt.Errorf("invalid OLLAMA_REGISTRY_MAXSTREAMS: %w", err) + } + } + return &rc, nil +} + +type PushParams struct { + // From is an optional destination name for the model. If empty, the + // destination name is the same as the source name. + From string +} + +// parseName parses name using [names.ParseExtended] and then merges the name with the +// default name, and checks that the name is fully qualified. If a digest is +// present, it parse and returns it with the other fields as their zero values. +// +// It returns an error if the name is not fully qualified, or if the digest, if +// any, is invalid. +// +// The scheme is returned as provided by [names.ParseExtended]. +func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) { + scheme, n, ds := names.ParseExtended(s) + n = names.Merge(n, defaultName) + if ds != "" { + // Digest is present. Validate it. + d, err = blob.ParseDigest(ds) + if err != nil { + return "", names.Name{}, blob.Digest{}, err + } + } + + // The name check is deferred until after the digest check because we + // say that digests take precedence over names, and so should there + // errors when being parsed. + if !n.IsFullyQualified() { + return "", names.Name{}, blob.Digest{}, ErrNameInvalid + } + + scheme = cmp.Or(scheme, "https") + return scheme, n, d, nil +} + +func (r *Registry) maxStreams() int { + n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) + + // Large downloads require a writter stream, so ensure we have at least + // two streams to avoid a deadlock. + return max(n, 2) +} + +func (r *Registry) maxChunkingThreshold() int64 { + return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold) +} + +// chunkSizeFor returns the chunk size for a layer of the given size. If the +// size is less than or equal to the max chunking threshold, the size is +// returned; otherwise, the max chunk size is returned. +func (r *Registry) maxChunkSize() int64 { + return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize) +} + +// Push pushes the model with the name in the cache to the remote registry. +func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error { + if p == nil { + p = &PushParams{} + } + + m, err := ResolveLocal(c, cmp.Or(p.From, name)) + if err != nil { + return err + } + + // Before much else happens, check layers at not null, the blobs exist, + // and the sizes match. This prevents long uploads followed by + // disappointment. + for _, l := range m.Layers { + if l == nil { + return fmt.Errorf("%w: null layer", ErrManifestInvalid) + } + info, err := c.Get(l.Digest) + if err != nil { + return fmt.Errorf("error getting %s: %w", l.Digest.Short(), err) + } + if info.Size != l.Size { + return fmt.Errorf("size mismatch for %s: %d != %d", l.Digest.Short(), info.Size, l.Size) + } + } + + t := traceFromContext(ctx) + + scheme, n, _, err := parseName(name) + if err != nil { + // This should never happen since ResolveLocal should have + // already validated the name. + panic(err) + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var g errgroup.Group + g.SetLimit(r.maxStreams()) + for _, l := range m.Layers { + var progress atomic.Int64 + g.Go(func() (err error) { + defer func() { t.update(l, progress.Load(), err) }() + + t.update(l, 0, nil) + + startURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/uploads/?digest=%s", + scheme, + n.Host(), + n.Namespace(), + n.Model(), + l.Digest, + ) + res, err := r.doOK(ctx, "POST", startURL, nil) + if err != nil { + return err + } + res.Body.Close() + + f, err := os.Open(c.GetFile(l.Digest)) + if err != nil { + return err + } + defer f.Close() + + uploadURL := res.Header.Get("Location") + if uploadURL == "" { + t.update(l, l.Size, ErrCached) + return nil + } + + req, err := r.newRequest(ctx, "PUT", uploadURL, f) + if err != nil { + return fmt.Errorf("invalid upload URL returned from registry: %q: %w", uploadURL, err) + } + req.ContentLength = l.Size + + res, err = doOK(r.client(), req) + if err == nil { + res.Body.Close() + } + return err + }) + } + + if err := g.Wait(); err != nil { + return err + } + + // Commit + path := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", + scheme, + n.Host(), + n.Namespace(), + n.Model(), + n.Tag(), + ) + res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data)) + if err == nil { + res.Body.Close() + } + // TODO(bmizerany): add a "commit" trace event + return err +} + +func canRetry(err error) bool { + var re *Error + if !errors.As(err, &re) { + return false + } + return re.Status >= 500 +} + +// Pull pulls the model with the given name from the remote registry into the +// cache. +// +// For layers larger then [Registry.MaxChunkSize], the layer is downloaded in +// chunks of the specified size, and then reassembled and verified. This is +// typically slower than splitting the model up across layers, and is mostly +// utilized for layers of type equal to "application/vnd.ollama.image". +func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error { + scheme, n, _, err := parseName(name) + if err != nil { + return err + } + + m, err := r.Resolve(ctx, name) + if err != nil { + return err + } + if len(m.Layers) == 0 { + return fmt.Errorf("%w: no layers", ErrManifestInvalid) + } + + exists := func(l *Layer) bool { + info, err := c.Get(l.Digest) + return err == nil && info.Size == l.Size + } + + t := traceFromContext(ctx) + + var g errgroup.Group + g.SetLimit(r.maxStreams()) + + for _, l := range m.Layers { + if exists(l) { + t.update(l, l.Size, ErrCached) + continue + } + + blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest) + req, err := r.newRequest(ctx, "GET", blobURL, nil) + if err != nil { + t.update(l, 0, err) + continue + } + + t.update(l, 0, nil) + + if l.Size <= r.maxChunkingThreshold() { + g.Go(func() error { + res, err := doOK(r.client(), req) + if err != nil { + return err + } + defer res.Body.Close() + err = c.Put(l.Digest, res.Body, l.Size) + if err == nil { + t.update(l, l.Size, nil) + } + return err + }) + } else { + q := syncs.NewRelayReader() + + g.Go(func() (err error) { + defer func() { q.CloseWithError(err) }() + return c.Put(l.Digest, q, l.Size) + }) + + var progress atomic.Int64 + + // We want to avoid extra round trips per chunk due to + // redirects from the registry to the blob store, so + // fire an initial request to get the final URL and + // then use that URL for the chunk requests. + req.Header.Set("Range", "bytes=0-0") + res, err := doOK(r.client(), req) + if err != nil { + return err + } + res.Body.Close() + req = res.Request.WithContext(req.Context()) + + streamNo := 0 + tws := make([]*bufio.Writer, r.maxStreams()-1) + for chunk := range chunks.Of(l.Size, r.maxChunkSize()) { + ticket := q.Take() + bufIdx := streamNo % len(tws) + streamNo++ + g.Go(func() (err error) { + defer func() { + if err != nil { + q.CloseWithError(err) + } + ticket.Close() + t.update(l, progress.Load(), err) + }() + + for _, err := range backoff.Loop(ctx, 3*time.Second) { + if err != nil { + return err + } + + err := func() error { + req := req.Clone(req.Context()) + req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk)) + res, err := doOK(r.client(), req) + if err != nil { + return err + } + defer res.Body.Close() + + tw := tws[bufIdx] + if tw == nil { + tw = bufio.NewWriterSize(nil, int(r.maxChunkSize())) + tws[bufIdx] = tw + } + tw.Reset(ticket) + defer tw.Reset(nil) // release ticket + + _, err = io.CopyN(tw, res.Body, chunk.Size()) + if err != nil { + return maybeUnexpectedEOF(err) + } + if err := tw.Flush(); err != nil { + return err + } + + total := progress.Add(chunk.Size()) + if total >= l.Size { + q.Close() + } + return nil + }() + if !canRetry(err) { + return err + } + } + return nil + }) + } + } + } + + if err := g.Wait(); err != nil { + return err + } + + // store the manifest blob + md := blob.DigestFromBytes(m.Data) + if err := blob.PutBytes(c, md, m.Data); err != nil { + return err + } + + // commit the manifest with a link + return c.Link(m.Name, md) +} + +// Manifest represents a [ollama.com/manifest]. +type Manifest struct { + Name string `json:"-"` // the canonical name of the model + Data []byte `json:"-"` // the raw data of the manifest + Layers []*Layer `json:"layers"` +} + +var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000") + +// Layer returns the layer with the given +// digest, or nil if not found. +func (m *Manifest) Layer(d blob.Digest) *Layer { + for _, l := range m.Layers { + if l.Digest == d { + return l + } + } + return nil +} + +// MarshalJSON implements json.Marshaler. +// +// NOTE: It adds an empty config object to the manifest, which is required by +// the registry, but not used by the client. In the future, the config object +// will not be required by the registry and this will should be removed. +func (m Manifest) MarshalJSON() ([]byte, error) { + type M Manifest + v := struct { + M + + // This is ignored, mostly, by the registry But, if not + // present, it will cause an error to be returned during the + // last phase of the commit which expects it, but does nothing + // with it. This will be fixed in a future release of + // ollama.com. + Config *Layer `json:"config"` + }{ + M: M(m), + Config: &Layer{Digest: emptyDigest}, + } + return json.Marshal(v) +} + +// unmarshalManifest unmarshals the data into a manifest, and sets the name +// field to the string representation of the name. +// +// It panics if the name is not fully qualified. Callers should ensure the name +// is fully qualified before calling this function. +func unmarshalManifest(n names.Name, data []byte) (*Manifest, error) { + if !n.IsFullyQualified() { + panic(fmt.Sprintf("unmarshalManifest: name is not fully qualified: %s", n.String())) + } + var m Manifest + if err := json.Unmarshal(data, &m); err != nil { + return nil, err + } + m.Name = n.String() + m.Data = data + return &m, nil +} + +// Layer is a layer in a model. +type Layer struct { + Digest blob.Digest `json:"digest"` + MediaType string `json:"mediaType"` + Size int64 `json:"size"` +} + +// ResolveLocal resolves a name to a Manifest in the local cache. The name is +// parsed using [names.ParseExtended] but the scheme is ignored. +func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) { + _, n, d, err := parseName(name) + if err != nil { + return nil, err + } + if !d.IsValid() { + d, err = c.Resolve(n.String()) + if err != nil { + return nil, err + } + } + data, err := os.ReadFile(c.GetFile(d)) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name) + } + return nil, err + } + m, err := unmarshalManifest(n, data) + if err != nil { + return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err)) + } + return m, nil +} + +// Resolve resolves a name to a Manifest in the remote registry. +func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) { + scheme, n, d, err := parseName(name) + if err != nil { + return nil, err + } + + manifestURL := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", scheme, n.Host(), n.Namespace(), n.Model(), n.Tag()) + if d.IsValid() { + manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d) + } + + res, err := r.doOK(ctx, "GET", manifestURL, nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + data, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + // TODO(bmizerany): return digest here + m, err := unmarshalManifest(n, data) + if err != nil { + return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err)) + } + return m, nil +} + +func (r *Registry) client() *http.Client { + if r.HTTPClient != nil { + return r.HTTPClient + } + return http.DefaultClient +} + +// newRequest constructs a new request, ready to use, with the given method, +// url, and body, presigned with client Key and UserAgent. +func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, err + } + if r.UserAgent != "" { + req.Header.Set("User-Agent", r.UserAgent) + } + if r.Key != nil { + token, err := makeAuthToken(r.Key) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + } + return req, nil +} + +// doOK makes a request with the given client and request, and returns the +// response if the status code is 200. If the status code is not 200, an Error +// is parsed from the response body and returned. If any other error occurs, it +// is returned. +func doOK(c *http.Client, r *http.Request) (*http.Response, error) { + if r.URL.Scheme == "https+insecure" { + // TODO(bmizerany): clone client.Transport, set + // InsecureSkipVerify, etc. + + type cloner interface { + Clone() *http.Transport + } + + // Attempt to configure the transport to skip TLS verification + // if we can clone it, otherwise fall through and let the http + // client complain and the scheme being invalid. + x, ok := cmp.Or(c.Transport, http.DefaultTransport).(cloner) + if ok { + tr := x.Clone() + tr.TLSClientConfig = cmp.Or(tr.TLSClientConfig, &tls.Config{}) + tr.TLSClientConfig.InsecureSkipVerify = true + + cc := *c // shallow copy + cc.Transport = tr + c = &cc + + r = r.Clone(r.Context()) + r.URL.Scheme = "https" + + // fall through + } + } + + res, err := c.Do(r) + if err != nil { + return nil, err + } + if res.StatusCode/100 != 2 { + out, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + var re Error + if err := json.Unmarshal(out, &re); err != nil { + // Use the raw body if we can't parse it as an error object. + re.Message = string(out) + } + re.Status = res.StatusCode + return nil, &re + } + return res, nil +} + +// doOK is a convenience method for making a request with newRequest and +// passing it to doOK with r.client(). +func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { + req, err := r.newRequest(ctx, method, path, body) + if err != nil { + return nil, err + } + return doOK(r.client(), req) +} + +// makeAuthToken creates an Ollama auth token for the given private key. +// +// NOTE: This format is OLD, overly complex, and should be replaced. We're +// inheriting it from the original Ollama client and ollama.com +// implementations, so we need to support it for now. +func makeAuthToken(key crypto.PrivateKey) (string, error) { + privKey, _ := key.(*ed25519.PrivateKey) + if privKey == nil { + return "", fmt.Errorf("unsupported private key type: %T", key) + } + + url := fmt.Sprintf("https://ollama.com?ts=%d", time.Now().Unix()) + // Part 1: the checkData (e.g. the URL with a timestamp) + + // Part 2: the public key + pubKeyShort, err := func() ([]byte, error) { + sshPubKey, err := ssh.NewPublicKey(privKey.Public()) + if err != nil { + return nil, err + } + pubKeyParts := bytes.Fields(ssh.MarshalAuthorizedKey(sshPubKey)) + if len(pubKeyParts) < 2 { + return nil, fmt.Errorf("malformed public key: %q", pubKeyParts) + } + pubKeyShort := pubKeyParts[1] + return pubKeyShort, nil + }() + if err != nil { + return "", err + } + + // Part 3: the signature + sig := ed25519.Sign(*privKey, []byte(checkData(url))) + + // Assemble the token: :: + var b strings.Builder + io.WriteString(&b, base64.StdEncoding.EncodeToString([]byte(url))) + b.WriteByte(':') + b.Write(pubKeyShort) + b.WriteByte(':') + io.WriteString(&b, base64.StdEncoding.EncodeToString(sig)) + + return b.String(), nil +} + +// The original spec for Ollama tokens was to use the SHA256 of the zero +// string as part of the signature. I'm not sure why that was, but we still +// need it to verify the signature. +var zeroSum = func() string { + sha256sum := sha256.Sum256(nil) + x := base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:]))) + return x +}() + +// checkData takes a URL and creates the original string format of the +// data signature that is used by the ollama client to sign requests +func checkData(url string) string { + return fmt.Sprintf("GET,%s,%s", url, zeroSum) +} + +func maybeUnexpectedEOF(err error) error { + if errors.Is(err, io.EOF) { + return io.ErrUnexpectedEOF + } + return err +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go new file mode 100644 index 00000000..d8f2a407 --- /dev/null +++ b/server/internal/client/ollama/registry_test.go @@ -0,0 +1,656 @@ +package ollama + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "math/rand/v2" + "net/http" + "net/http/httptest" + "os" + "path" + "reflect" + "slices" + "strings" + "testing" + "time" + + "github.com/ollama/ollama/server/internal/cache/blob" + "github.com/ollama/ollama/server/internal/chunks" + "github.com/ollama/ollama/server/internal/internal/testutil" +) + +func TestManifestMarshalJSON(t *testing.T) { + // All manifests should contain an "empty" config object. + var m Manifest + data, err := json.Marshal(m) + if err != nil { + t.Fatal(err) + } + if !bytes.Contains(data, []byte(`"config":{"digest":"sha256:`)) { + t.Error("expected manifest to contain empty config") + t.Fatalf("got:\n%s", string(data)) + } +} + +func link(c *blob.DiskCache, name string, manifest string) { + _, n, _, err := parseName(name) + if err != nil { + panic(err) + } + d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest))) + if err != nil { + panic(err) + } + if err := c.Link(n.String(), d); err != nil { + panic(err) + } +} + +var errRoundTrip = errors.New("forced roundtrip error") + +type recordRoundTripper http.HandlerFunc + +func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + w := httptest.NewRecorder() + rr(w, req) + if w.Code == 499 { + return nil, errRoundTrip + } + resp := w.Result() + // For some reason, Response.Request is not set by httptest.NewRecorder, so we + // set it manually. + resp.Request = req + return w.Result(), nil +} + +// newClient constructs a cache with predefined manifests for testing. The manifests are: +// +// empty: no data +// zero: no layers +// single: one layer with the contents "exists" +// multiple: two layers with the contents "exists" and "here" +// notfound: a layer that does not exist in the cache +// null: one null layer (e.g. [null]) +// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size) +// invalid: a layer with invalid JSON data +// +// Tests that want to ensure the client does not communicate with the upstream +// registry should pass a nil handler, which will cause a panic if +// communication is attempted. +// +// To simulate a network error, pass a handler that returns a 499 status code. +func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { + t.Helper() + c, err := blob.Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } + + mklayer := func(data string) *Layer { + return &Layer{ + Digest: importBytes(t, c, data), + Size: int64(len(data)), + } + } + + commit := func(name string, layers ...*Layer) { + t.Helper() + data, err := json.Marshal(&Manifest{Layers: layers}) + if err != nil { + t.Fatal(err) + } + link(c, name, string(data)) + } + + link(c, "empty", "") + commit("zero") + commit("single", mklayer("exists")) + commit("multiple", mklayer("exists"), mklayer("present")) + commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))}) + commit("null", nil) + commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499}) + link(c, "invalid", "!!!!!") + + rc := &Registry{ + HTTPClient: &http.Client{ + Transport: recordRoundTripper(h), + }, + } + return rc, c +} + +func okHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +} + +func checkErrCode(t *testing.T, err error, status int, code string) { + t.Helper() + var e *Error + if !errors.As(err, &e) || e.Status != status || e.Code != code { + t.Errorf("err = %v; want %v %v", err, status, code) + } +} + +func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest { + d, err := c.Import(strings.NewReader(data), int64(len(data))) + if err != nil { + t.Fatal(err) + } + return d +} + +func TestRegistryPushInvalidNames(t *testing.T) { + rc, c := newClient(t, nil) + + cases := []struct { + name string + err error + }{ + {"", ErrNameInvalid}, + {"@", ErrNameInvalid}, + {"@x", blob.ErrInvalidDigest}, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + // Create a new registry and push a new image. + err := rc.Push(t.Context(), c, tt.name, nil) + if !errors.Is(err, tt.err) { + t.Errorf("err = %v; want %v", err, tt.err) + } + }) + } +} + +func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) { + t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }} + return WithTrace(ctx, t), t +} + +func TestPushZero(t *testing.T) { + rc, c := newClient(t, okHandler) + err := rc.Push(t.Context(), c, "empty", nil) + if !errors.Is(err, ErrManifestInvalid) { + t.Errorf("err = %v; want %v", err, ErrManifestInvalid) + } +} + +func TestPushSingle(t *testing.T) { + rc, c := newClient(t, okHandler) + err := rc.Push(t.Context(), c, "single", nil) + testutil.Check(t, err) +} + +func TestPushMultiple(t *testing.T) { + rc, c := newClient(t, okHandler) + err := rc.Push(t.Context(), c, "multiple", nil) + testutil.Check(t, err) +} + +func TestPushNotFound(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + t.Errorf("unexpected request: %v", r) + }) + err := rc.Push(t.Context(), c, "notfound", nil) + if !errors.Is(err, fs.ErrNotExist) { + t.Errorf("err = %v; want %v", err, fs.ErrNotExist) + } +} + +func TestPushNullLayer(t *testing.T) { + rc, c := newClient(t, nil) + err := rc.Push(t.Context(), c, "null", nil) + if err == nil || !strings.Contains(err.Error(), "invalid manifest") { + t.Errorf("err = %v; want invalid manifest", err) + } +} + +func TestPushSizeMismatch(t *testing.T) { + rc, c := newClient(t, nil) + ctx, _ := withTraceUnexpected(t.Context()) + got := rc.Push(ctx, c, "sizemismatch", nil) + if got == nil || !strings.Contains(got.Error(), "size mismatch") { + t.Errorf("err = %v; want size mismatch", got) + } +} + +func TestPushInvalid(t *testing.T) { + rc, c := newClient(t, nil) + err := rc.Push(t.Context(), c, "invalid", nil) + if err == nil || !strings.Contains(err.Error(), "invalid manifest") { + t.Errorf("err = %v; want invalid manifest", err) + } +} + +func TestPushExistsAtRemote(t *testing.T) { + var pushed bool + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/uploads/") { + if !pushed { + // First push. Return an uploadURL. + pushed = true + w.Header().Set("Location", "http://blob.store/blobs/123") + return + } + w.WriteHeader(http.StatusAccepted) + return + } + + io.Copy(io.Discard, r.Body) + w.WriteHeader(http.StatusOK) + }) + + rc.MaxStreams = 1 // prevent concurrent uploads + + var errs []error + ctx := WithTrace(t.Context(), &Trace{ + Update: func(_ *Layer, n int64, err error) { + // uploading one at a time so no need to lock + errs = append(errs, err) + }, + }) + + check := testutil.Checker(t) + + err := rc.Push(ctx, c, "single", nil) + check(err) + + if !errors.Is(errors.Join(errs...), nil) { + t.Errorf("errs = %v; want %v", errs, []error{ErrCached}) + } + + err = rc.Push(ctx, c, "single", nil) + check(err) +} + +func TestPushRemoteError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/blobs/") { + w.WriteHeader(500) + io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`) + return + } + }) + got := rc.Push(t.Context(), c, "single", nil) + checkErrCode(t, got, 500, "blob_error") +} + +func TestPushLocationError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", ":///x") + w.WriteHeader(http.StatusAccepted) + }) + got := rc.Push(t.Context(), c, "single", nil) + wantContains := "invalid upload URL" + if got == nil || !strings.Contains(got.Error(), wantContains) { + t.Errorf("err = %v; want to contain %v", got, wantContains) + } +} + +func TestPushUploadRoundtripError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.Host == "blob.store" { + w.WriteHeader(499) // force RoundTrip error on upload + return + } + w.Header().Set("Location", "http://blob.store/blobs/123") + }) + got := rc.Push(t.Context(), c, "single", nil) + if !errors.Is(got, errRoundTrip) { + t.Errorf("got = %v; want %v", got, errRoundTrip) + } +} + +func TestPushUploadFileOpenError(t *testing.T) { + rc, c := newClient(t, okHandler) + ctx := WithTrace(t.Context(), &Trace{ + Update: func(l *Layer, _ int64, err error) { + // Remove the file just before it is opened for upload, + // but after the initial Stat that happens before the + // upload starts + os.Remove(c.GetFile(l.Digest)) + }, + }) + got := rc.Push(ctx, c, "single", nil) + if !errors.Is(got, fs.ErrNotExist) { + t.Errorf("got = %v; want fs.ErrNotExist", got) + } +} + +func TestPushCommitRoundtripError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/blobs/") { + panic("unexpected") + } + w.WriteHeader(499) // force RoundTrip error + }) + err := rc.Push(t.Context(), c, "zero", nil) + if !errors.Is(err, errRoundTrip) { + t.Errorf("err = %v; want %v", err, errRoundTrip) + } +} + +func checkNotExist(t *testing.T, err error) { + t.Helper() + if !errors.Is(err, fs.ErrNotExist) { + t.Fatalf("err = %v; want fs.ErrNotExist", err) + } +} + +func TestRegistryPullInvalidName(t *testing.T) { + rc, c := newClient(t, nil) + err := rc.Pull(t.Context(), c, "://") + if !errors.Is(err, ErrNameInvalid) { + t.Errorf("err = %v; want %v", err, ErrNameInvalid) + } +} + +func TestRegistryPullInvalidManifest(t *testing.T) { + cases := []string{ + "", + "null", + "!!!", + `{"layers":[]}`, + } + + for _, resp := range cases { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, resp) + }) + err := rc.Pull(t.Context(), c, "x") + if !errors.Is(err, ErrManifestInvalid) { + t.Errorf("err = %v; want invalid manifest", err) + } + } +} + +func TestRegistryPullNotCached(t *testing.T) { + check := testutil.Checker(t) + + var c *blob.DiskCache + var rc *Registry + + d := blob.DigestFromBytes("some data") + rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/blobs/") { + io.WriteString(w, "some data") + return + } + fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d) + }) + + // Confirm that the layer does not exist locally + _, err := ResolveLocal(c, "model") + checkNotExist(t, err) + + _, err = c.Get(d) + checkNotExist(t, err) + + err = rc.Pull(t.Context(), c, "model") + check(err) + + mw, err := rc.Resolve(t.Context(), "model") + check(err) + mg, err := ResolveLocal(c, "model") + check(err) + if !reflect.DeepEqual(mw, mg) { + t.Errorf("mw = %v; mg = %v", mw, mg) + } + + // Confirm successful download + info, err := c.Get(d) + check(err) + if info.Digest != d { + t.Errorf("info.Digest = %v; want %v", info.Digest, d) + } + if info.Size != 9 { + t.Errorf("info.Size = %v; want %v", info.Size, 9) + } + + data, err := os.ReadFile(c.GetFile(d)) + check(err) + if string(data) != "some data" { + t.Errorf("data = %q; want %q", data, "exists") + } +} + +func TestRegistryPullCached(t *testing.T) { + cached := blob.DigestFromBytes("exists") + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/blobs/") { + w.WriteHeader(499) // should not be called + return + } + if strings.Contains(r.URL.Path, "/manifests/") { + fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached) + } + }) + + var errs []error + var reads []int64 + ctx := WithTrace(t.Context(), &Trace{ + Update: func(d *Layer, n int64, err error) { + t.Logf("update %v %d %v", d, n, err) + reads = append(reads, n) + errs = append(errs, err) + }, + }) + + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + err := rc.Pull(ctx, c, "single") + testutil.Check(t, err) + + want := []int64{6} + if !errors.Is(errors.Join(errs...), ErrCached) { + t.Errorf("errs = %v; want %v", errs, ErrCached) + } + if !slices.Equal(reads, want) { + t.Errorf("pairs = %v; want %v", reads, want) + } +} + +func TestRegistryPullManifestNotFound(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }) + err := rc.Pull(t.Context(), c, "notfound") + checkErrCode(t, err, 404, "") +} + +func TestRegistryPullResolveRemoteError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + io.WriteString(w, `{"errors":[{"code":"an_error"}]}`) + }) + err := rc.Pull(t.Context(), c, "single") + checkErrCode(t, err, 500, "an_error") +} + +func TestRegistryPullResolveRoundtripError(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/manifests/") { + w.WriteHeader(499) // force RoundTrip error + return + } + }) + err := rc.Pull(t.Context(), c, "single") + if !errors.Is(err, errRoundTrip) { + t.Errorf("err = %v; want %v", err, errRoundTrip) + } +} + +// TestRegistryPullMixedCachedNotCached tests that cached layers do not +// interfere with pulling layers that are not cached +func TestRegistryPullMixedCachedNotCached(t *testing.T) { + x := blob.DigestFromBytes("xxxxxx") + e := blob.DigestFromBytes("exists") + y := blob.DigestFromBytes("yyyyyy") + + for i := range 10 { + t.Logf("iteration %d", i) + + digests := []blob.Digest{x, e, y} + + rand.Shuffle(len(digests), func(i, j int) { + digests[i], digests[j] = digests[j], digests[i] + }) + + manifest := fmt.Sprintf(`{ + "layers": [ + {"digest":"%s","size":6}, + {"digest":"%s","size":6}, + {"digest":"%s","size":6} + ] + }`, digests[0], digests[1], digests[2]) + + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + switch path.Base(r.URL.Path) { + case "latest": + io.WriteString(w, manifest) + case x.String(): + io.WriteString(w, "xxxxxx") + case e.String(): + io.WriteString(w, "exists") + case y.String(): + io.WriteString(w, "yyyyyy") + default: + panic(fmt.Sprintf("unexpected request: %v", r)) + } + }) + + ctx := WithTrace(t.Context(), &Trace{ + Update: func(l *Layer, n int64, err error) { + t.Logf("update %v %d %v", l, n, err) + }, + }) + + // Check that we pull all layers that we can. + + err := rc.Pull(ctx, c, "mixed") + if err != nil { + t.Fatal(err) + } + + for _, d := range digests { + info, err := c.Get(d) + if err != nil { + t.Fatalf("Get(%v): %v", d, err) + } + if info.Size != 6 { + t.Errorf("info.Size = %v; want %v", info.Size, 6) + } + } + } +} + +func TestRegistryPullChunking(t *testing.T) { + rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range")) + if r.URL.Host != "blob.store" { + // The production registry redirects to the blob store. + http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound) + return + } + if strings.Contains(r.URL.Path, "/blobs/") { + rng := r.Header.Get("Range") + if rng == "" { + http.Error(w, "missing range", http.StatusBadRequest) + return + } + _, c, err := chunks.ParseRange(r.Header.Get("Range")) + if err != nil { + panic(err) + } + io.WriteString(w, "remote"[c.Start:c.End+1]) + return + } + fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote")) + }) + + // Force chunking by setting the threshold to less than the size of the + // layer. + rc.ChunkingThreshold = 3 + rc.MaxChunkSize = 3 + + var reads []int64 + ctx := WithTrace(t.Context(), &Trace{ + Update: func(d *Layer, n int64, err error) { + if err != nil { + t.Errorf("update %v %d %v", d, n, err) + } + reads = append(reads, n) + }, + }) + + err := rc.Pull(ctx, c, "remote") + testutil.Check(t, err) + + want := []int64{0, 3, 6} + if !slices.Equal(reads, want) { + t.Errorf("reads = %v; want %v", reads, want) + } +} + +func TestRegistryResolveByDigest(t *testing.T) { + check := testutil.Checker(t) + + exists := blob.DigestFromBytes("exists") + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() { + w.WriteHeader(499) // should not hit manifest endpoint + } + fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists) + }) + + _, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String()) + check(err) +} + +func TestInsecureSkipVerify(t *testing.T) { + exists := blob.DigestFromBytes("exists") + + s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists) + })) + defer s.Close() + + const name = "ollama.com/library/insecure" + + var rc Registry + url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name) + _, err := rc.Resolve(t.Context(), url) + if err == nil || !strings.Contains(err.Error(), "failed to verify") { + t.Errorf("err = %v; want cert verifiction failure", err) + } + + url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name) + _, err = rc.Resolve(t.Context(), url) + testutil.Check(t, err) +} + +func TestCanRetry(t *testing.T) { + cases := []struct { + err error + want bool + }{ + {nil, false}, + {errors.New("x"), false}, + {ErrCached, false}, + {ErrManifestInvalid, false}, + {ErrNameInvalid, false}, + {&Error{Status: 100}, false}, + {&Error{Status: 500}, true}, + } + for _, tt := range cases { + if got := canRetry(tt.err); got != tt.want { + t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want) + } + } +} diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go new file mode 100644 index 00000000..8e53040a --- /dev/null +++ b/server/internal/client/ollama/trace.go @@ -0,0 +1,48 @@ +package ollama + +import ( + "context" +) + +// Trace is a set of functions that are called to report progress during blob +// downloads and uploads. +type Trace struct { + // Update is called during [Registry.Push] and [Registry.Pull] to + // report the progress of blob uploads and downloads. + // + // It is called once at the beginning of the download with a zero n and + // then once per read operation with the number of bytes read so far, + // and an error if any. + // + // A function assigned must be safe for concurrent use. The function is + // called synchronously and so should not block or take long to run. + Update func(_ *Layer, n int64, _ error) +} + +func (t *Trace) update(l *Layer, n int64, err error) { + if t.Update != nil { + t.Update(l, n, err) + } +} + +type traceKey struct{} + +// WithTrace returns a context derived from ctx that uses t to report trace +// events. +func WithTrace(ctx context.Context, t *Trace) context.Context { + return context.WithValue(ctx, traceKey{}, t) +} + +var emptyTrace = &Trace{} + +// traceFromContext returns the Trace associated with ctx, or an empty Trace if +// none is found. +// +// It never returns nil. +func traceFromContext(ctx context.Context) *Trace { + t, _ := ctx.Value(traceKey{}).(*Trace) + if t == nil { + return emptyTrace + } + return t +} diff --git a/server/internal/cmd/opp/internal/safetensors/safetensors.go b/server/internal/cmd/opp/internal/safetensors/safetensors.go new file mode 100644 index 00000000..7f3e9979 --- /dev/null +++ b/server/internal/cmd/opp/internal/safetensors/safetensors.go @@ -0,0 +1,220 @@ +// safetensors provides a reader for the safetensor directories and files. +package safetensors + +import ( + "encoding/json" + "fmt" + "io" + "io/fs" + "iter" + "slices" + "strconv" + "strings" +) + +// Tensor represents a single tensor in a safetensors file. +// +// It's zero value is not valid. Use [Model.Tensors] to get valid tensors. +// +// It is not safe for use across multiple goroutines. +type Tensor struct { + name string + dataType string + shape []int64 + + fsys fs.FS + fname string // entry name in fsys + offset int64 + size int64 +} + +type Model struct { + fsys fs.FS +} + +func Read(fsys fs.FS) (*Model, error) { + return &Model{fsys: fsys}, nil +} + +func (m *Model) Tensors() iter.Seq2[*Tensor, error] { + return func(yield func(*Tensor, error) bool) { + entries, err := fs.Glob(m.fsys, "*.safetensors") + if err != nil { + yield(nil, err) + return + } + for _, e := range entries { + tt, err := m.readTensors(e) + if err != nil { + yield(nil, err) + return + } + for _, t := range tt { + if !yield(t, nil) { + return + } + } + } + } +} + +func (m *Model) readTensors(fname string) ([]*Tensor, error) { + f, err := m.fsys.Open(fname) + if err != nil { + return nil, err + } + defer f.Close() + + finfo, err := f.Stat() + if err != nil { + return nil, err + } + + headerSize, err := readInt64(f) + if err != nil { + return nil, err + } + + data := make([]byte, headerSize) + _, err = io.ReadFull(f, data) + if err != nil { + return nil, err + } + + var raws map[string]json.RawMessage + if err := json.Unmarshal(data, &raws); err != nil { + return nil, err + } + + // TODO(bmizerany): do something with metadata? This could be another + // header read if needed. We also need to figure out if the metadata is + // present in only one .safetensors file or if each file may have their + // own and if it needs to follow each tensor. Currently, I (bmizerany) + // am only seeing them show up with one entry for file type which is + // always "pt". + + tt := make([]*Tensor, 0, len(raws)) + for name, raw := range raws { + if !strings.HasPrefix(name, "model.layer") { + continue + } + var v struct { + DataType string `json:"dtype"` + Shape []int64 `json:"shape"` + Offsets []int64 `json:"data_offsets"` + } + if err := json.Unmarshal(raw, &v); err != nil { + return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err) + } + if len(v.Offsets) != 2 { + return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets) + } + + // TODO(bmizerany): after collecting, validate all offests make + // tensors contiguous? + begin, end := v.Offsets[0], v.Offsets[1] + if err := checkBeginEnd(finfo.Size(), begin, end); err != nil { + return nil, err + } + + // TODO(bmizerany): just yield.. don't be silly and make a slice :) + tt = append(tt, &Tensor{ + name: name, + dataType: v.DataType, + shape: v.Shape, + fsys: m.fsys, + fname: fname, + offset: begin, + size: end - begin, + }) + } + return tt, nil +} + +func checkBeginEnd(size, begin, end int64) error { + if begin < 0 { + return fmt.Errorf("begin must not be negative: %d", begin) + } + if end < 0 { + return fmt.Errorf("end must not be negative: %d", end) + } + if end < begin { + return fmt.Errorf("end must be >= begin: %d < %d", end, begin) + } + if end > size { + return fmt.Errorf("end must be <= size: %d > %d", end, size) + } + return nil +} + +func readInt64(r io.Reader) (int64, error) { + var v uint64 + var buf [8]byte + if _, err := io.ReadFull(r, buf[:]); err != nil { + return 0, err + } + for i := range buf { + v |= uint64(buf[i]) << (8 * i) + } + return int64(v), nil +} + +type Shape []int64 + +func (s Shape) String() string { + var b strings.Builder + b.WriteByte('[') + for i, v := range s { + if i > 0 { + b.WriteByte(',') + } + b.WriteString(strconv.FormatInt(v, 10)) + } + b.WriteByte(']') + return b.String() +} + +func (t *Tensor) Name() string { return t.name } +func (t *Tensor) DataType() string { return t.dataType } +func (t *Tensor) Size() int64 { return t.size } +func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) } + +func (t *Tensor) Reader() (io.ReadCloser, error) { + f, err := t.fsys.Open(t.fname) + if err != nil { + return nil, err + } + r := newSectionReader(f, t.offset, t.size) + rc := struct { + io.Reader + io.Closer + }{r, f} + return rc, nil +} + +// newSectionReader returns a new io.Reader that reads from r starting at +// offset. It is a convenience function for creating a io.SectionReader when r +// may not be an io.ReaderAt. +// +// If r is already a ReaderAt, it is returned directly, otherwise if r is an +// io.Seeker, a new io.ReaderAt is returned that wraps r after seeking to the +// beginning of the file. +// +// If r is an io.Seeker, +// or slow path. The slow path is used when r does not implement io.ReaderAt, +// in which case it must discard the data it reads. +func newSectionReader(r io.Reader, offset, n int64) io.Reader { + if r, ok := r.(io.ReaderAt); ok { + return io.NewSectionReader(r, offset, n) + } + if r, ok := r.(io.ReadSeeker); ok { + r.Seek(offset, io.SeekStart) + return io.LimitReader(r, n) + } + // Discard to offset and return a limited reader. + _, err := io.CopyN(io.Discard, r, offset) + if err != nil { + return nil + } + return io.LimitReader(r, n) +} diff --git a/server/internal/cmd/opp/opp.go b/server/internal/cmd/opp/opp.go new file mode 100644 index 00000000..12199cf3 --- /dev/null +++ b/server/internal/cmd/opp/opp.go @@ -0,0 +1,366 @@ +package main + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "log" + "mime" + "net/http" + "os" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/ollama/ollama/server/internal/cache/blob" + "github.com/ollama/ollama/server/internal/client/ollama" + "github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors" + "golang.org/x/sync/errgroup" +) + +var stdout io.Writer = os.Stdout + +const usage = `Opp is a tool for pushing and pulling Ollama models. + +Usage: + + opp [flags] + +Commands: + + push Upload a model to the Ollama server. + pull Download a model from the Ollama server. + import Import a model from a local safetensor directory. + +Examples: + + # Pull a model from the Ollama server. + opp pull library/llama3.2:latest + + # Push a model to the Ollama server. + opp push username/my_model:8b + + # Import a model from a local safetensor directory. + opp import /path/to/safetensor + +Envionment Variables: + + OLLAMA_MODELS + The directory where models are pushed and pulled from + (default ~/.ollama/models). +` + +func main() { + flag.Usage = func() { + fmt.Fprint(os.Stderr, usage) + } + flag.Parse() + + c, err := ollama.DefaultCache() + if err != nil { + log.Fatal(err) + } + + rc, err := ollama.RegistryFromEnv() + if err != nil { + log.Fatal(err) + } + + ctx := context.Background() + + err = func() error { + switch cmd := flag.Arg(0); cmd { + case "pull": + return cmdPull(ctx, rc, c) + case "push": + return cmdPush(ctx, rc, c) + case "import": + return cmdImport(ctx, c) + default: + if cmd == "" { + flag.Usage() + } else { + fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd) + } + os.Exit(2) + return errors.New("unreachable") + } + }() + if err != nil { + fmt.Fprintf(os.Stderr, "opp: %v\n", err) + os.Exit(1) + } +} + +func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { + model := flag.Arg(1) + if model == "" { + flag.Usage() + os.Exit(1) + } + + tr := http.DefaultTransport.(*http.Transport).Clone() + // TODO(bmizerany): configure transport? + rc.HTTPClient = &http.Client{Transport: tr} + + var mu sync.Mutex + p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded] + + var pb bytes.Buffer + printProgress := func() { + pb.Reset() + mu.Lock() + for d, s := range p { + // Write progress to a buffer first to avoid blocking + // on stdout while holding the lock. + stamp := time.Now().Format("2006/01/02 15:04:05") + fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0])) + if s[0] == s[1] { + delete(p, d) + } + } + mu.Unlock() + io.Copy(stdout, &pb) + } + + ctx = ollama.WithTrace(ctx, &ollama.Trace{ + Update: func(l *ollama.Layer, n int64, err error) { + if err != nil && !errors.Is(err, ollama.ErrCached) { + fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err) + return + } + + mu.Lock() + p[l.Digest] = [2]int64{l.Size, n} + mu.Unlock() + }, + }) + + errc := make(chan error) + go func() { + errc <- rc.Pull(ctx, c, model) + }() + + t := time.NewTicker(time.Second) + defer t.Stop() + for { + select { + case <-t.C: + printProgress() + case err := <-errc: + printProgress() + return err + } + } +} + +func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { + args := flag.Args()[1:] + flag := flag.NewFlagSet("push", flag.ExitOnError) + flagFrom := flag.String("from", "", "Use the manifest from a model by another name.") + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: opp push \n") + flag.PrintDefaults() + } + flag.Parse(args) + + model := flag.Arg(0) + if model == "" { + return fmt.Errorf("missing model argument") + } + + from := cmp.Or(*flagFrom, model) + m, err := ollama.ResolveLocal(c, from) + if err != nil { + return err + } + + ctx = ollama.WithTrace(ctx, &ollama.Trace{ + Update: func(l *ollama.Layer, n int64, err error) { + switch { + case errors.Is(err, ollama.ErrCached): + fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n) + case err != nil: + fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err) + case n == 0: + l := m.Layer(l.Digest) + mt, p, _ := mime.ParseMediaType(l.MediaType) + mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.") + switch mt { + case "tensor": + fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"]) + default: + fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType) + } + } + }, + }) + + return rc.Push(ctx, c, model, &ollama.PushParams{ + From: from, + }) +} + +type trackingReader struct { + io.Reader + n *atomic.Int64 +} + +func (r *trackingReader) Read(p []byte) (n int, err error) { + n, err = r.Reader.Read(p) + r.n.Add(int64(n)) + return n, err +} + +func cmdImport(ctx context.Context, c *blob.DiskCache) error { + args := flag.Args()[1:] + flag := flag.NewFlagSet("import", flag.ExitOnError) + flagAs := flag.String("as", "", "Import using the provided name.") + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: opp import \n") + flag.PrintDefaults() + } + flag.Parse(args) + + dir := cmp.Or(flag.Arg(0), ".") + fmt.Fprintf(os.Stderr, "Reading %s\n", dir) + + m, err := safetensors.Read(os.DirFS(dir)) + if err != nil { + return err + } + + var total int64 + var tt []*safetensors.Tensor + for t, err := range m.Tensors() { + if err != nil { + return err + } + tt = append(tt, t) + total += t.Size() + } + + var n atomic.Int64 + done := make(chan error) + go func() { + layers := make([]*ollama.Layer, len(tt)) + var g errgroup.Group + g.SetLimit(runtime.GOMAXPROCS(0)) + var ctxErr error + for i, t := range tt { + if ctx.Err() != nil { + // The context may cancel AFTER we exit the + // loop, and so if we use ctx.Err() after the + // loop we may report it as the error that + // broke the loop, when it was not. This can + // manifest as a false-negative, leading the + // user to think their import failed when it + // did not, so capture it if and only if we + // exit the loop because of a ctx.Err() and + // report it. + ctxErr = ctx.Err() + break + } + g.Go(func() (err error) { + rc, err := t.Reader() + if err != nil { + return err + } + defer rc.Close() + tr := &trackingReader{rc, &n} + d, err := c.Import(tr, t.Size()) + if err != nil { + return err + } + if err := rc.Close(); err != nil { + return err + } + + layers[i] = &ollama.Layer{ + Digest: d, + Size: t.Size(), + MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{ + "name": t.Name(), + "dtype": t.DataType(), + "shape": t.Shape().String(), + }), + } + + return nil + }) + } + + done <- func() error { + if err := errors.Join(g.Wait(), ctxErr); err != nil { + return err + } + m := &ollama.Manifest{Layers: layers} + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return err + } + d := blob.DigestFromBytes(data) + err = blob.PutBytes(c, d, data) + if err != nil { + return err + } + return c.Link(*flagAs, d) + }() + }() + + fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir) + + csiHideCursor(stdout) + defer csiShowCursor(stdout) + + csiSavePos(stdout) + writeProgress := func() { + csiRestorePos(stdout) + nn := n.Load() + fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n", + formatNatural(nn), + formatNatural(total), + nn*100/total, + ansiClearToEndOfLine, + ) + } + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + writeProgress() + case err := <-done: + writeProgress() + return err + } + } +} + +func formatNatural(n int64) string { + switch { + case n < 1024: + return fmt.Sprintf("%d B", n) + case n < 1024*1024: + return fmt.Sprintf("%.1f KB", float64(n)/1024) + case n < 1024*1024*1024: + return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024)) + default: + return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024)) + } +} + +const ansiClearToEndOfLine = "\033[K" + +func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") } +func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") } +func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") } +func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") } diff --git a/server/internal/cmd/oppbench/oppbench.go b/server/internal/cmd/oppbench/oppbench.go new file mode 100644 index 00000000..7a530594 --- /dev/null +++ b/server/internal/cmd/oppbench/oppbench.go @@ -0,0 +1,11 @@ +package main + +import ( + "fmt" + "os" +) + +func main() { + fmt.Println("Run as 'go test -bench=.' to run the benchmarks") + os.Exit(1) +} diff --git a/server/internal/cmd/oppbench/oppbench_test.go b/server/internal/cmd/oppbench/oppbench_test.go new file mode 100644 index 00000000..c71d6cde --- /dev/null +++ b/server/internal/cmd/oppbench/oppbench_test.go @@ -0,0 +1,107 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "sync/atomic" + "testing" + "time" + + "github.com/ollama/ollama/server/internal/chunks" + "golang.org/x/sync/errgroup" +) + +func BenchmarkDownload(b *testing.B) { + run := func(fileSize, chunkSize int64) { + name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize) + b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) }) + } + + run(100<<20, 8<<20) + run(100<<20, 16<<20) + run(100<<20, 32<<20) + run(100<<20, 64<<20) + run(100<<20, 128<<20) // 1 chunk +} + +func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error { + const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d" + req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil) + if err != nil { + return err + } + req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk)) + res, err := c.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + + _, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read + return err +} + +var sleepTime atomic.Int64 + +func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) { + client := &http.Client{ + Transport: func() http.RoundTripper { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.DisableKeepAlives = true + return tr + }(), + } + defer client.CloseIdleConnections() + + // warm up the client + run(context.Background(), client, chunks.New(0, 1<<20)) + + b.SetBytes(fileSize) + b.ReportAllocs() + + // Give our CDN a min to breathe between benchmarks. + time.Sleep(time.Duration(sleepTime.Swap(3))) + + for b.Loop() { + g, ctx := errgroup.WithContext(b.Context()) + g.SetLimit(runtime.GOMAXPROCS(0)) + for chunk := range chunks.Of(fileSize, chunkSize) { + g.Go(func() error { return run(ctx, client, chunk) }) + } + if err := g.Wait(); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkWrite(b *testing.B) { + b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) }) +} + +func benchmarkWrite(b *testing.B, chunkSize int) { + b.ReportAllocs() + + dir := b.TempDir() + f, err := os.Create(filepath.Join(dir, "write-single")) + if err != nil { + b.Fatal(err) + } + defer f.Close() + + data := make([]byte, chunkSize) + b.SetBytes(int64(chunkSize)) + r := bytes.NewReader(data) + for b.Loop() { + r.Reset(data) + _, err := io.Copy(f, r) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/server/internal/internal/backoff/backoff.go b/server/internal/internal/backoff/backoff.go new file mode 100644 index 00000000..1f0634f7 --- /dev/null +++ b/server/internal/internal/backoff/backoff.go @@ -0,0 +1,48 @@ +package backoff + +import ( + "context" + "iter" + "math/rand/v2" + "time" +) + +func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] { + var n int + return func(yield func(int, error) bool) { + var t *time.Timer + for { + if ctx.Err() != nil { + yield(n, ctx.Err()) + return + } + + if !yield(n, nil) { + return + } + + n++ + + // n^2 backoff timer is a little smoother than the + // common choice of 2^n. + d := time.Duration(n*n) * 10 * time.Millisecond + if d > maxBackoff { + d = maxBackoff + } + // Randomize the delay between 0.5-1.5 x msec, in order + // to prevent accidental "thundering herd" problems. + d = time.Duration(float64(d) * (rand.Float64() + 0.5)) + + if t == nil { + t = time.NewTimer(d) + } else { + t.Reset(d) + } + select { + case <-ctx.Done(): + t.Stop() + case <-t.C: + } + } + } +} diff --git a/server/internal/internal/backoff/backoff_synctest_test.go b/server/internal/internal/backoff/backoff_synctest_test.go new file mode 100644 index 00000000..cf17ce80 --- /dev/null +++ b/server/internal/internal/backoff/backoff_synctest_test.go @@ -0,0 +1,40 @@ +//go:build goexperiment.synctest + +package backoff + +import ( + "context" + "errors" + "testing" + "testing/synctest" + "time" +) + +func TestLoop(t *testing.T) { + synctest.Run(func() { + last := -1 + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + for n, err := range Loop(ctx, 100*time.Millisecond) { + if !errors.Is(err, ctx.Err()) { + t.Errorf("err = %v, want nil", err) + } + if err != nil { + break + } + if n != last+1 { + t.Errorf("n = %d, want %d", n, last+1) + } + last = n + if n > 5 { + cancel() + } + } + + if last != 6 { + t.Errorf("last = %d, want 6", last) + } + }) +} diff --git a/server/internal/internal/backoff/backoff_test.go b/server/internal/internal/backoff/backoff_test.go new file mode 100644 index 00000000..bb8438a7 --- /dev/null +++ b/server/internal/internal/backoff/backoff_test.go @@ -0,0 +1,38 @@ +package backoff + +import ( + "context" + "testing" + "testing/synctest" + "time" +) + +func TestLoopAllocs(t *testing.T) { + for i := range 3 { + got := testing.AllocsPerRun(1000, func() { + for tick := range Loop(t.Context(), 1) { + if tick >= i { + break + } + } + }) + want := float64(0) + if i > 0 { + want = 3 // due to time.NewTimer + } + if got > want { + t.Errorf("[%d ticks]: allocs = %v, want 0", i, want) + } + } +} + +func BenchmarkLoop(b *testing.B) { + ctx := context.Background() + synctest.Run(func() { + for n := range Loop(ctx, 100*time.Millisecond) { + if n == b.N { + break + } + } + }) +} diff --git a/server/internal/internal/names/name.go b/server/internal/internal/names/name.go new file mode 100644 index 00000000..361cce76 --- /dev/null +++ b/server/internal/internal/names/name.go @@ -0,0 +1,229 @@ +package names + +import ( + "cmp" + "fmt" + "strings" + + "github.com/ollama/ollama/server/internal/internal/stringsx" +) + +const MaxNameLength = 50 + 1 + 50 + 1 + 50 // /: + +type Name struct { + // Make incomparable to enfoce use of Compare / Equal for + // case-insensitive comparisons. + _ [0]func() + + h string + n string + m string + t string +} + +// Parse parses and assembles a Name from a name string. The +// format of a valid name string is: +// +// s: +// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest } +// { host } "/" { namespace } "/" { model } ":" { tag } +// { host } "/" { namespace } "/" { model } "@" { digest } +// { host } "/" { namespace } "/" { model } +// { namespace } "/" { model } ":" { tag } "@" { digest } +// { namespace } "/" { model } ":" { tag } +// { namespace } "/" { model } "@" { digest } +// { namespace } "/" { model } +// { model } ":" { tag } "@" { digest } +// { model } ":" { tag } +// { model } "@" { digest } +// { model } +// "@" { digest } +// host: +// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }* +// length: [1, 350] +// namespace: +// pattern: { alphanum | "_" } { alphanum | "-" | "_" }* +// length: [1, 80] +// model: +// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }* +// length: [1, 80] +// tag: +// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }* +// length: [1, 80] +// digest: +// pattern: { alphanum | "_" } { alphanum | "-" | ":" }* +// length: [1, 80] +// +// The name returned is not guaranteed to be valid. If it is not valid, the +// field values are left in an undefined state. Use [Name.IsValid] to check +// if the name is valid. +func Parse(s string) Name { + if len(s) > MaxNameLength { + return Name{} + } + + var n Name + var tail string + var c byte + for { + s, tail, c = cutLastAny(s, "/:") + switch c { + case ':': + n.t = tail + continue // look for model + case '/': + n.h, n.n, _ = cutLastAny(s, "/") + n.m = tail + return n + case 0: + n.m = tail + return n + } + } +} + +// ParseExtended parses and returns any scheme, Name, and digest from from s in +// the the form [scheme://][name][@digest]. All parts are optional. +// +// If the scheme is present, it must be followed by "://". The digest is +// prefixed by "@" and comes after the name. The name is parsed using [Parse]. +// +// The scheme and digest are stripped before the name is parsed by [Parse]. +// +// For convience, the scheme is never empty. If the scheme is not present, the +// returned scheme is "https". +// +// Examples: +// +// http://ollama.com/bmizerany/smol:latest@digest +// https://ollama.com/bmizerany/smol:latest +// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme. +func ParseExtended(s string) (scheme string, _ Name, digest string) { + i := strings.Index(s, "://") + if i >= 0 { + scheme = s[:i] + s = s[i+3:] + } + i = strings.LastIndex(s, "@") + if i >= 0 { + digest = s[i+1:] + s = s[:i] + } + return scheme, Parse(s), digest +} + +func FormatExtended(scheme string, n Name, digest string) string { + var b strings.Builder + if scheme != "" { + b.WriteString(scheme) + b.WriteString("://") + } + b.WriteString(n.String()) + if digest != "" { + b.WriteByte('@') + b.WriteString(digest) + } + return b.String() +} + +// Merge merges two names into a single name. Non-empty host, namespace, and +// tag parts of a take precedence over fields in b. The model field is left as +// is. +// +// The returned name is not guaranteed to be valid. Use [Name.IsValid] to check +// if the name is valid. +func Merge(a, b Name) Name { + a.h = cmp.Or(a.h, b.h) + a.n = cmp.Or(a.n, b.n) + a.t = cmp.Or(a.t, b.t) + return a +} + +// IsValid returns true if the name is valid. +func (n Name) IsValid() bool { + if n.h != "" && !isValidHost(n.h) { + return false + } + if n.n != "" && !isValidNamespace(n.n) { + return false + } + if n.m != "" && !isValidModel(n.m) { + return false + } + if n.t != "" && !isValidTag(n.t) { + return false + } + return true +} + +func (n Name) IsFullyQualified() bool { + return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != "" +} + +func isValidHost(_ string) bool { + return true // TODO: implement +} + +func isValidNamespace(_ string) bool { + return true // TODO: implement +} + +func isValidModel(_ string) bool { + return true // TODO: implement +} + +func isValidTag(_ string) bool { + return true // TODO: implement +} + +func (n Name) Host() string { return n.h } +func (n Name) Namespace() string { return n.n } +func (n Name) Model() string { return n.m } +func (n Name) Tag() string { return n.t } + +// Compare compares n and o case-insensitively. It returns 0 if n and o are +// equal, -1 if n sorts before o, and 1 if n sorts after o. +func (n Name) Compare(o Name) int { + return cmp.Or( + stringsx.CompareFold(n.h, o.h), + stringsx.CompareFold(n.n, o.n), + stringsx.CompareFold(n.m, o.m), + stringsx.CompareFold(n.t, o.t), + ) +} + +// String returns the fully qualified name in the format +// /:. +func (n Name) String() string { + var b strings.Builder + if n.h != "" { + b.WriteString(n.h) + b.WriteByte('/') + } + if n.n != "" { + b.WriteString(n.n) + b.WriteByte('/') + } + b.WriteString(n.m) + if n.t != "" { + b.WriteByte(':') + b.WriteString(n.t) + } + return b.String() +} + +func (n Name) GoString() string { + return fmt.Sprintf("", n.h, n.n, n.m, n.t) +} + +// cutLastAny is like strings.Cut but scans in reverse for the last character +// in chars. If no character is found, before is the empty string and after is +// s. The returned sep is the byte value of the character in chars if one was +// found; otherwise it is 0. +func cutLastAny(s, chars string) (before, after string, sep byte) { + i := strings.LastIndexAny(s, chars) + if i >= 0 { + return s[:i], s[i+1:], s[i] + } + return "", s, 0 +} diff --git a/server/internal/internal/names/name_test.go b/server/internal/internal/names/name_test.go new file mode 100644 index 00000000..760fec5f --- /dev/null +++ b/server/internal/internal/names/name_test.go @@ -0,0 +1,152 @@ +package names + +import ( + "strings" + "testing" +) + +func TestParseName(t *testing.T) { + cases := []struct { + in string + want Name + }{ + {"", Name{}}, + {"m:t", Name{m: "m", t: "t"}}, + {"m", Name{m: "m"}}, + {"/m", Name{m: "m"}}, + {"/n/m:t", Name{n: "n", m: "m", t: "t"}}, + {"n/m", Name{n: "n", m: "m"}}, + {"n/m:t", Name{n: "n", m: "m", t: "t"}}, + {"n/m", Name{n: "n", m: "m"}}, + {"n/m", Name{n: "n", m: "m"}}, + {strings.Repeat("m", MaxNameLength+1), Name{}}, + {"h/n/m:t", Name{h: "h", n: "n", m: "m", t: "t"}}, + {"ollama.com/library/_:latest", Name{h: "ollama.com", n: "library", m: "_", t: "latest"}}, + + // Invalids + // TODO: {"n:t/m:t", Name{}}, + // TODO: {"/h/n/m:t", Name{}}, + } + + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + got := Parse(tt.in) + if got.Compare(tt.want) != 0 { + t.Errorf("parseName(%q) = %#v, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestString(t *testing.T) { + cases := []string{ + "", + "m:t", + "m:t", + "m", + "n/m", + "n/m:t", + "n/m", + "n/m", + "h/n/m:t", + "ollama.com/library/_:latest", + + // Special cased to "round trip" without the leading slash. + "/m", + "/n/m:t", + } + for _, s := range cases { + t.Run(s, func(t *testing.T) { + s = strings.TrimPrefix(s, "/") + if g := Parse(s).String(); g != s { + t.Errorf("parse(%q).String() = %q", s, g) + } + }) + } +} + +func TestParseExtended(t *testing.T) { + cases := []struct { + in string + + wantScheme string + wantName Name + wantDigest string + }{ + {"", "", Name{}, ""}, + {"m", "", Name{m: "m"}, ""}, + {"http://m", "http", Name{m: "m"}, ""}, + {"http+insecure://m", "http+insecure", Name{m: "m"}, ""}, + {"http://m@sha256:deadbeef", "http", Name{m: "m"}, "sha256:deadbeef"}, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + scheme, name, digest := ParseExtended(tt.in) + if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest { + t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest) + } + + // Round trip + if got := FormatExtended(scheme, name, digest); got != tt.in { + t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got) + } + }) + } +} + +func TestMerge(t *testing.T) { + cases := []struct { + a, b string + want string + }{ + {"", "", ""}, + {"m", "", "m"}, + {"", "m", ""}, + {"x", "y", "x"}, + {"o.com/n/m:t", "o.com/n/m:t", "o.com/n/m:t"}, + {"o.com/n/m:t", "o.com/n/_:t", "o.com/n/m:t"}, + + {"bmizerany/smol", "ollama.com/library/_:latest", "ollama.com/bmizerany/smol:latest"}, + {"localhost:8080/bmizerany/smol", "ollama.com/library/_:latest", "localhost:8080/bmizerany/smol:latest"}, + } + for _, tt := range cases { + t.Run("", func(t *testing.T) { + a, b := Parse(tt.a), Parse(tt.b) + got := Merge(a, b) + if got.Compare(Parse(tt.want)) != 0 { + t.Errorf("merge(%q, %q) = %#v, want %q", tt.a, tt.b, got, tt.want) + } + }) + } +} + +func TestParseStringRoundTrip(t *testing.T) { + cases := []string{ + "", + "m", + "m:t", + "n/m", + "n/m:t", + "n/m:t", + "n/m", + "n/m", + "h/n/m:t", + "ollama.com/library/_:latest", + } + for _, s := range cases { + t.Run(s, func(t *testing.T) { + if got := Parse(s).String(); got != s { + t.Errorf("parse(%q).String() = %q", s, got) + } + }) + } +} + +var junkName Name + +func BenchmarkParseName(b *testing.B) { + b.ReportAllocs() + for range b.N { + junkName = Parse("h/n/m:t") + } +} diff --git a/server/internal/internal/stringsx/stringsx.go b/server/internal/internal/stringsx/stringsx.go new file mode 100644 index 00000000..6c7a8d20 --- /dev/null +++ b/server/internal/internal/stringsx/stringsx.go @@ -0,0 +1,52 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package stringsx provides additional string manipulation functions +// that aren't in the standard library's strings package or go4.org/mem. +package stringsx + +import ( + "unicode" + "unicode/utf8" +) + +// CompareFold returns -1, 0, or 1 depending on whether a < b, a == b, or a > b, +// like cmp.Compare, but case insensitively. +func CompareFold(a, b string) int { + // Track our position in both strings + ia, ib := 0, 0 + for ia < len(a) && ib < len(b) { + ra, wa := nextRuneLower(a[ia:]) + rb, wb := nextRuneLower(b[ib:]) + if ra < rb { + return -1 + } + if ra > rb { + return 1 + } + ia += wa + ib += wb + if wa == 0 || wb == 0 { + break + } + } + + // If we've reached here, one or both strings are exhausted + // The shorter string is "less than" if they match up to this point + switch { + case ia == len(a) && ib == len(b): + return 0 + case ia == len(a): + return -1 + default: + return 1 + } +} + +// nextRuneLower returns the next rune in the string, lowercased, along with its +// original (consumed) width in bytes. If the string is empty, it returns +// (utf8.RuneError, 0) +func nextRuneLower(s string) (r rune, width int) { + r, width = utf8.DecodeRuneInString(s) + return unicode.ToLower(r), width +} diff --git a/server/internal/internal/stringsx/stringsx_test.go b/server/internal/internal/stringsx/stringsx_test.go new file mode 100644 index 00000000..8575c0b2 --- /dev/null +++ b/server/internal/internal/stringsx/stringsx_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package stringsx + +import ( + "cmp" + "strings" + "testing" +) + +func TestCompareFold(t *testing.T) { + tests := []struct { + a, b string + }{ + // Basic ASCII cases + {"", ""}, + {"a", "a"}, + {"a", "A"}, + {"A", "a"}, + {"a", "b"}, + {"b", "a"}, + {"abc", "ABC"}, + {"ABC", "abc"}, + {"abc", "abd"}, + {"abd", "abc"}, + + // Length differences + {"abc", "ab"}, + {"ab", "abc"}, + + // Unicode cases + {"世界", "世界"}, + {"Hello世界", "hello世界"}, + {"世界Hello", "世界hello"}, + {"世界", "世界x"}, + {"世界x", "世界"}, + + // Special case folding examples + {"ß", "ss"}, // German sharp s + {"fi", "fi"}, // fi ligature + {"Σ", "σ"}, // Greek sigma + {"İ", "i\u0307"}, // Turkish dotted I + + // Mixed cases + {"HelloWorld", "helloworld"}, + {"HELLOWORLD", "helloworld"}, + {"helloworld", "HELLOWORLD"}, + {"HelloWorld", "helloworld"}, + {"helloworld", "HelloWorld"}, + + // Edge cases + {" ", " "}, + {"1", "1"}, + {"123", "123"}, + {"!@#", "!@#"}, + } + + wants := []int{} + for _, tt := range tests { + got := CompareFold(tt.a, tt.b) + want := cmp.Compare(strings.ToLower(tt.a), strings.ToLower(tt.b)) + if got != want { + t.Errorf("CompareFold(%q, %q) = %v, want %v", tt.a, tt.b, got, want) + } + wants = append(wants, want) + } + + if n := testing.AllocsPerRun(1000, func() { + for i, tt := range tests { + if CompareFold(tt.a, tt.b) != wants[i] { + panic("unexpected") + } + } + }); n > 0 { + t.Errorf("allocs = %v; want 0", int(n)) + } +} diff --git a/server/internal/internal/syncs/line.go b/server/internal/internal/syncs/line.go new file mode 100644 index 00000000..021cd4c0 --- /dev/null +++ b/server/internal/internal/syncs/line.go @@ -0,0 +1,201 @@ +// Package syncs provides synchronization primitives. +package syncs + +import ( + "cmp" + "io" + "sync" +) + +var closedChan = func() chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +}() + +// Ticket represents a ticket in a sequence of tickets. The zero value is +// invalid. Use [Line.Take] to get a valid ticket. +// +// A Ticket is not safe for concurrent use. +type Ticket struct { + ahead chan struct{} // ticket ahead of this one + ch chan struct{} +} + +// Ready returns a channel that is closed when the ticket before this one is +// done. +// +// It is incorrect to wait on Ready after the ticket is done. +func (t *Ticket) Ready() chan struct{} { + return cmp.Or(t.ahead, closedChan) +} + +// Done signals that this ticket is done and that the next ticket in line can +// proceed. +// +// The first call to [Done] unblocks the ticket after it, if any. Subsequent +// calls are no-ops. +func (t *Ticket) Done() { + if t.ch != nil { + close(t.ch) + } + t.ch = nil +} + +// Line is an ordered sequence of tickets waiting for their turn to proceed. +// +// To get a ticket use [Line.Take]. +// To signal that a ticket is done use [Ticket.Done]. +// To wait your turn use [Ticket.Ready]. +// +// A Line is not safe for concurrent use. +type Line struct { + last chan struct{} // last ticket in line +} + +func (q *Line) Take() *Ticket { + t := &Ticket{ + ahead: q.last, + ch: make(chan struct{}), + } + q.last = t.ch + return t +} + +// RelayReader implements an [io.WriterTo] that yields the passed +// writer to its [WriteTo] method each [io.WriteCloser] taken from [Take], in +// the order they are taken. Each [io.WriteCloser] blocks until the previous +// one is closed, or a call to [RelayReader.CloseWithError] is made. +// +// The zero value is invalid. Use [NewWriteToLine] to get a valid RelayReader. +// +// It is not safe for concurrent use. +type RelayReader struct { + line Line + t *Ticket + w io.Writer + n int64 + + mu sync.Mutex + err error // set by CloseWithError + closedCh chan struct{} // closed if err is set +} + +var ( + _ io.Closer = (*RelayReader)(nil) + _ io.WriterTo = (*RelayReader)(nil) + _ io.Reader = (*RelayReader)(nil) +) + +func NewRelayReader() *RelayReader { + var q RelayReader + q.closedCh = make(chan struct{}) + q.t = q.line.Take() + return &q +} + +// CloseWithError terminates the line, unblocking any writer waiting for its +// turn with the error, or [io.EOF] if err is nil. It is safe to call +// [CloseWithError] multiple times and across multiple goroutines. +// +// If the line is already closed, [CloseWithError] is a no-op. +// +// It never returns an error. +func (q *RelayReader) CloseWithError(err error) error { + q.mu.Lock() + defer q.mu.Unlock() + if q.err == nil { + q.err = cmp.Or(q.err, err, io.EOF) + close(q.closedCh) + } + return nil +} + +// Close closes the line. Any writer waiting for its turn will be unblocked +// with an [io.ErrClosedPipe] error. +// +// It never returns an error. +func (q *RelayReader) Close() error { + return q.CloseWithError(nil) +} + +func (q *RelayReader) closed() <-chan struct{} { + q.mu.Lock() + defer q.mu.Unlock() + return q.closedCh +} + +func (q *RelayReader) Read(p []byte) (int, error) { + panic("RelayReader.Read is for show only; use WriteTo") +} + +// WriteTo yields the writer w to the first writer in line and blocks until the +// first call to [Close]. +// +// It is safe to call [Take] concurrently with [WriteTo]. +func (q *RelayReader) WriteTo(dst io.Writer) (int64, error) { + select { + case <-q.closed(): + return 0, io.ErrClosedPipe + default: + } + + // We have a destination writer; let the relay begin. + q.w = dst + q.t.Done() + <-q.closed() + return q.n, nil +} + +// Take returns a writer that will be passed to the next writer in line. +// +// It is not safe for use across multiple goroutines. +func (q *RelayReader) Take() io.WriteCloser { + return &relayWriter{q: q, t: q.line.Take()} +} + +type relayWriter struct { + q *RelayReader + t *Ticket + ready bool +} + +var _ io.StringWriter = (*relayWriter)(nil) + +// Write writes to the writer passed to [RelayReader.WriteTo] as soon as the +// writer is ready. It returns io.ErrClosedPipe if the line is closed before +// the writer is ready. +func (w *relayWriter) Write(p []byte) (int, error) { + if !w.awaitTurn() { + return 0, w.q.err + } + n, err := w.q.w.Write(p) + w.q.n += int64(n) + return n, err +} + +func (w *relayWriter) WriteString(s string) (int, error) { + if !w.awaitTurn() { + return 0, w.q.err + } + return io.WriteString(w.q.w, s) +} + +// Close signals that the writer is done, unblocking the next writer in line. +func (w *relayWriter) Close() error { + w.t.Done() + return nil +} + +func (t *relayWriter) awaitTurn() (ok bool) { + if t.ready { + return true + } + select { + case <-t.t.Ready(): + t.ready = true + return true + case <-t.q.closed(): + return false + } +} diff --git a/server/internal/internal/syncs/line_test.go b/server/internal/internal/syncs/line_test.go new file mode 100644 index 00000000..d5216026 --- /dev/null +++ b/server/internal/internal/syncs/line_test.go @@ -0,0 +1,65 @@ +package syncs + +import ( + "bytes" + "io" + "math/rand/v2" + "testing" + "testing/synctest" +) + +func TestPipelineReadWriterTo(t *testing.T) { + for range 10 { + synctest.Run(func() { + q := NewRelayReader() + + tickets := []struct { + io.WriteCloser + s string + }{ + {q.Take(), "you"}, + {q.Take(), " say hi,"}, + {q.Take(), " and "}, + {q.Take(), "I say "}, + {q.Take(), "hello"}, + } + + rand.Shuffle(len(tickets), func(i, j int) { + tickets[i], tickets[j] = tickets[j], tickets[i] + }) + + var g Group + for i, t := range tickets { + g.Go(func() { + defer t.Close() + if i%2 == 0 { + // Use [relayWriter.WriteString] + io.WriteString(t.WriteCloser, t.s) + } else { + t.Write([]byte(t.s)) + } + }) + } + + var got bytes.Buffer + var copyErr error // checked at end + g.Go(func() { + _, copyErr = io.Copy(&got, q) + }) + + synctest.Wait() + + q.Close() + g.Wait() + + if copyErr != nil { + t.Fatal(copyErr) + } + + want := "you say hi, and I say hello" + if got.String() != want { + t.Fatalf("got %q, want %q", got.String(), want) + } + }) + } +} diff --git a/server/internal/internal/syncs/syncs.go b/server/internal/internal/syncs/syncs.go new file mode 100644 index 00000000..8f1b1e07 --- /dev/null +++ b/server/internal/internal/syncs/syncs.go @@ -0,0 +1,41 @@ +package syncs + +import ( + "sync" + "sync/atomic" +) + +// Group is a [sync.WaitGroup] with a Go method. +type Group struct { + wg sync.WaitGroup + n atomic.Int64 +} + +func (g *Group) Go(f func()) { + g.wg.Add(1) + go func() { + g.n.Add(1) // Now we are running + defer func() { + g.wg.Done() + g.n.Add(-1) // Now we are done + }() + f() + }() +} + +// Running returns the number of goroutines that are currently running. +// +// If a call to [Running] returns zero, and a call to [Wait] is made without +// any calls to [Go], then [Wait] will return immediately. This is true even if +// a goroutine is started and finishes between the two calls. +// +// It is possible for [Running] to return non-zero and for [Wait] to return +// immediately. This can happen if the all running goroutines finish between +// the two calls. +func (g *Group) Running() int64 { + return g.n.Load() +} + +func (g *Group) Wait() { + g.wg.Wait() +} diff --git a/server/internal/internal/testutil/testutil.go b/server/internal/internal/testutil/testutil.go new file mode 100644 index 00000000..354c2608 --- /dev/null +++ b/server/internal/internal/testutil/testutil.go @@ -0,0 +1,74 @@ +package testutil + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +// Check calls t.Fatal(err) if err is not nil. +func Check(t *testing.T, err error) { + if err != nil { + t.Helper() + t.Fatal(err) + } +} + +// CheckFunc exists so other packages do not need to invent their own type for +// taking a Check function. +type CheckFunc func(err error) + +// Checker returns a check function that +// calls t.Fatal if err is not nil. +func Checker(t *testing.T) (check func(err error)) { + return func(err error) { + if err != nil { + t.Helper() + t.Fatal(err) + } + } +} + +// StopPanic runs f but silently recovers from any panic f causes. +// The normal usage is: +// +// testutil.StopPanic(func() { +// callThatShouldPanic() +// t.Errorf("callThatShouldPanic did not panic") +// }) +func StopPanic(f func()) { + defer func() { recover() }() + f() +} + +// CheckTime calls t.Fatalf if got != want. Included in the error message is +// want.Sub(got) to help diagnose the difference, along with their values in +// UTC. +func CheckTime(t *testing.T, got, want time.Time) { + t.Helper() + if !got.Equal(want) { + t.Fatalf("got %v, want %v (%v)", got.UTC(), want.UTC(), want.Sub(got)) + } +} + +// WriteFile writes data to a file named name. It makes the directory if it +// doesn't exist and sets the file mode to perm. +// +// The name must be a relative path and must not contain .. or start with a /; +// otherwise WriteFile will panic. +func WriteFile[S []byte | string](t testing.TB, name string, data S) { + t.Helper() + + if filepath.IsAbs(name) { + t.Fatalf("WriteFile: name must be a relative path, got %q", name) + } + name = filepath.Clean(name) + dir := filepath.Dir(name) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(name, []byte(data), 0o644); err != nil { + t.Fatal(err) + } +} diff --git a/server/internal/manifest/manifest.go b/server/internal/manifest/manifest.go new file mode 100644 index 00000000..e020d2c0 --- /dev/null +++ b/server/internal/manifest/manifest.go @@ -0,0 +1,118 @@ +// Package manifest provides documentation for the Ollama manifest format. +// This package contains no code. +// +// # Manifests +// +// A manifest is a JSON object that describes a model. The JSON object has a +// single field "layers" which is a list of layers that make up the model. Each +// layer has the following fields: +// +// A layer is a single, logical unit of a model. Layers are stored in the cache +// as files with the name of the digest of the layer. Layers are pushed and +// pulled from the registry as blobs. +// +// A layer is represented as a JSON object with the following fields: +// +// - "digest": The digest of the layer. +// - "mediaType": The media type of the layer. +// - "size": The size of the layer in bytes. +// +// Layers are typically stored in a blob store, such as a registry, and are +// referenced by their digest. This package does not define how layers are +// stored or retrieved. +// +// # Configuration Layer +// +// The configuration of a model is represented as a layer with the media type: +// +// application/vnd.ollama.image.config; type= +// +// The "type" parameter in the media type specifies the format of the +// configuration (e.g., "safetensor" or "gguf"). +// +// There may be only one configuration layer in a model. +// +// # Template Layer +// +// The model template is a layer with the media type: +// +// application/vnd.ollama.image.template; [name=] +// +// The "name" parameter in the media type specifies the name of the template as +// for lookup at runtime. The name is optional and may be omitted. If omitted, +// the template is the default template for the model. +// +// # Tensor Layers +// +// The tensors of a model are represented as layers with the media type: +// +// application/vnd.ollama.image.tensor; name=; dtype=; shape= +// +// The "name" parameter in the media type specifies the name of the tensor as +// defined in the model's configuration and are bound only by the rules for +// names as defined in the configuration format, as represented by the +// configuration's "type". +// +// The "dtype" parameter in the media type specifies the data type of the tensor +// as a string. +// +// TODO: Define more specifically how to represent data types as strings. +// +// The "shape" parameter in the media type specifies the shape of the tensor as +// a comma-separated list of integers; one per dimension. +// +// # Tokenization Layers +// +// The tokenization of a model is represented as a layer with the media type: +// +// application/vnd.ollama.image.tokenizer +// +// The configuration of the tokenizer is represented as a layer with the media type: +// +// application/vnd.ollama.image.tokenizer.config +// +// # Miscellaneous Layers +// +// These extra layer mime types are reserved: +// +// application/vnd.ollama.image.license +// +// This layer contains one of the many licenses for the model in plain text. +// +// # Example Manifest +// +// The following is an example manifest containing a configuration, a model +// template, and two tensors (digests shortened for brevity): +// +// { +// "layers": [{ +// "digest": "sha256:a...", +// "mediaType": "application/vnd.ollama.image.config; type=safetensors", +// "size": 1234 +// },{ +// "digest": "sha256:b...", +// "mediaType": "application/vnd.ollama.image.template", +// "size": 5678 +// },{ +// "digest": "sha256:c...", +// "mediaType": "application/vnd.ollama.image.tensor; name=input; dtype=F32; shape=1,2,3", +// "size": 9012 +// },{ +// "digest": "sha256:d...", +// "mediaType": "application/vnd.ollama.image.tensor; name=output; dtype=I32; shape=4,5,6", +// "size": 3456 +// }] +// } +// +// # Legacy Media Types +// +// The appliaction/vnd.ollama.image.model media type is deprecated, but will +// remain supported for backwards compatibility, for some undefined amount of +// time. New models should use the media types defined above. +// +// # Reserved media types +// +// The media type prefix "application/vnd.ollama.image." is reserved for +// defining new media types for layers known to Ollama. Currently, all other +// prefixes are ignored by official Ollama registry clients. +package manifest From 4df98f3eb58dafaa2a069fbc6173faa124a952c1 Mon Sep 17 00:00:00 2001 From: frob Date: Tue, 25 Feb 2025 17:52:50 +0100 Subject: [PATCH 24/31] Move cgroups fix out of AMD section. (#9072) Co-authored-by: Richard Lyons --- docs/troubleshooting.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 7ef1618e..4275cdf3 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -73,6 +73,10 @@ curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/ +## Linux docker + +If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration. + ## NVIDIA GPU Discovery When Ollama starts up, it takes inventory of the GPUs present in the system to determine compatibility and how much VRAM is available. Sometimes this discovery can fail to find your GPUs. In general, running the latest driver will yield the best results. @@ -100,8 +104,6 @@ On linux, AMD GPU access typically requires `video` and/or `render` group member When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44` -If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration. - If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure. - `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems - `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported From a499390648e9184211a1e9d196cdb20b48355591 Mon Sep 17 00:00:00 2001 From: Pavol Rusnak Date: Tue, 25 Feb 2025 18:54:19 +0100 Subject: [PATCH 25/31] build: support Compute Capability 5.0, 5.2 and 5.3 for CUDA 12.x (#8567) CUDA 12.x still supports Compute Capability 5.0, 5.2 and 5.3, so let's build for these architectures as well --- CMakePresets.json | 2 +- discover/cuda_common.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index c789ad7f..68546bde 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -28,7 +28,7 @@ "name": "CUDA 12", "inherits": [ "CUDA" ], "cacheVariables": { - "CMAKE_CUDA_ARCHITECTURES": "60;61;62;70;72;75;80;86;87;89;90;90a" + "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;62;70;72;75;80;86;87;89;90;90a" } }, { diff --git a/discover/cuda_common.go b/discover/cuda_common.go index 878cee8c..04829529 100644 --- a/discover/cuda_common.go +++ b/discover/cuda_common.go @@ -57,7 +57,8 @@ func cudaVariant(gpuInfo CudaGPUInfo) string { } } - if gpuInfo.computeMajor < 6 || gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) { + // driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers + if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) { return "v11" } return "v12" From b16367b4b2ee22b8bd8ad1ef8c2abf5f8171be8c Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 25 Feb 2025 09:18:44 -0800 Subject: [PATCH 26/31] fix: add back bf16 support this was accidentally removed when moving fs/ggml from its previous location --- fs/ggml/ggml.go | 27 +++++++++++++++++----- fs/ggml/ggml_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 90d1d440..57313859 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -207,11 +207,26 @@ func (t Tensor) block() (n int) { func (t Tensor) blockSize() uint64 { switch t.Kind { - case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16 + case + 0, // F32 + 1, // F16 + 24, // I8 + 25, // I16 + 26, // I32 + 27, // I64 + 28, // F64 + 30: // BF16 return 1 - case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL + case + 2, // Q4_0 + 3, // Q4_1 + 6, // Q5_0 + 7, // Q5_1 + 8, // Q8_0 + 9, // Q8_1 + 20: // IQ4_NL return 32 - default: // All others + default: return 256 } } @@ -235,7 +250,7 @@ func (t Tensor) typeSize() uint64 { case 8: // Q8_0 return 2 + blockSize case 9: // Q8_1 - return 4 + 4 + blockSize + return 2 + 2 + blockSize case 10: // Q2_K return blockSize/16 + blockSize/4 + 2 + 2 case 11: // Q3_K @@ -247,7 +262,7 @@ func (t Tensor) typeSize() uint64 { case 14: // Q6_K return blockSize/2 + blockSize/4 + blockSize/16 + 2 case 15: // Q8_K - return 2 + blockSize + 2*blockSize/16 + return 4 + blockSize + 2*blockSize/16 case 16: // IQ2_XXS return 2 + 2*blockSize/8 case 17: // IQ2_XS @@ -276,6 +291,8 @@ func (t Tensor) typeSize() uint64 { return 8 case 29: // IQ1_M return blockSize/8 + blockSize/16 + blockSize/32 + case 30: // BF16 + return 2 default: return 0 } diff --git a/fs/ggml/ggml_test.go b/fs/ggml/ggml_test.go index 4fcdf085..324e40fa 100644 --- a/fs/ggml/ggml_test.go +++ b/fs/ggml/ggml_test.go @@ -3,6 +3,7 @@ package ggml import ( "maps" "slices" + "strconv" "strings" "testing" @@ -157,3 +158,55 @@ func TestTensorLayers(t *testing.T) { }) } } + +// ref: https://github.com/ggml-org/llama.cpp/blob/a82c9e7c23ef6db48cebfa194dc9cebbc4ac3552/ggml/src/ggml.c#L572 +func TestTensorTypes(t *testing.T) { + cases := []struct { + kind uint32 + blockSize uint64 + typeSize uint64 + }{ + {0, 1, 4}, + {1, 1, 2}, + {2, 32, 18}, + {3, 32, 20}, + {6, 32, 22}, + {7, 32, 24}, + {8, 32, 34}, + {9, 32, 36}, + {10, 256, 84}, + {11, 256, 110}, + {12, 256, 144}, + {13, 256, 176}, + {14, 256, 210}, + {15, 256, 292}, + {16, 256, 66}, + {17, 256, 74}, + {18, 256, 98}, + {19, 256, 50}, + {20, 32, 18}, + {21, 256, 110}, + {22, 256, 82}, + {23, 256, 136}, + {24, 1, 1}, + {25, 1, 2}, + {26, 1, 4}, + {27, 1, 8}, + {28, 1, 8}, + {29, 256, 56}, + {30, 1, 2}, + } + + for _, tt := range cases { + t.Run(strconv.Itoa(int(tt.kind)), func(t *testing.T) { + tensor := Tensor{Kind: tt.kind} + if tensor.blockSize() != tt.blockSize { + t.Errorf("unexpected block size: got=%d want=%d", tensor.blockSize(), tt.blockSize) + } + + if tensor.typeSize() != tt.typeSize { + t.Errorf("unexpected type size: got=%d want=%d", tensor.typeSize(), tt.typeSize) + } + }) + } +} From 888855675e7a29bbf29882804276db4368da1ba9 Mon Sep 17 00:00:00 2001 From: Chuanhui Liu <1601648586@qq.com> Date: Tue, 25 Feb 2025 16:15:47 -0500 Subject: [PATCH 27/31] docs: rocm install link (#9346) --- docs/development.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/development.md b/docs/development.md index 522d106b..6e68c9eb 100644 --- a/docs/development.md +++ b/docs/development.md @@ -41,7 +41,7 @@ Install prerequisites: - [CMake](https://cmake.org/download/) - [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) including the Native Desktop Workload - (Optional) AMD GPU support - - [ROCm](https://rocm.github.io/install.html) + - [ROCm](https://rocm.docs.amd.com/en/latest/) - [Ninja](https://github.com/ninja-build/ninja/releases) - (Optional) NVIDIA GPU support - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network) @@ -127,4 +127,4 @@ Ollama looks for acceleration libraries in the following paths relative to the ` * `.` (macOS) * `build/lib/ollama` (for development) -If the libraries are not found, Ollama will not run with any acceleration libraries. \ No newline at end of file +If the libraries are not found, Ollama will not run with any acceleration libraries. From 6ecd7f64ba36b1d24ea4bb1b73a6dc4234e7d567 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Pekkarinen?= Date: Tue, 25 Feb 2025 23:38:08 +0200 Subject: [PATCH 28/31] docker: upgrade rocm to 6.3.3 (#8211) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit centos-7 images have been deprecated upstream and replaced with almalinux-8 images instead, requiring some small extra work. Signed-off-by: José Pekkarinen --- Dockerfile | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0a8cb99f..09612824 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,17 +2,17 @@ ARG FLAVOR=${TARGETARCH} -ARG ROCMVERSION=6.1.2 +ARG ROCMVERSION=6.3.3 ARG JETPACK5VERSION=r35.4.1 ARG JETPACK6VERSION=r36.2.0 ARG CMAKEVERSION=3.31.2 -FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCMVERSION}-complete AS base-amd64 +FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 RUN sed -i -e 's/mirror.centos.org/vault.centos.org/g' -e 's/^#.*baseurl=http/baseurl=http/g' -e 's/^mirrorlist=http/#mirrorlist=http/g' /etc/yum.repos.d/*.repo \ - && yum install -y yum-utils devtoolset-10-gcc devtoolset-10-gcc-c++ \ - && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo \ + && yum install -y yum-utils gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ \ + && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo \ && curl -s -L https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz | tar -Jx -C /usr/local/bin --strip-components 1 -ENV PATH=/opt/rh/devtoolset-10/root/usr/bin:/opt/rh/devtoolset-11/root/usr/bin:$PATH +ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH FROM --platform=linux/arm64 rockylinux:8 AS base-arm64 # install epel-release for ccache @@ -29,9 +29,7 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml ENV LDFLAGS=-s FROM base AS cpu -# amd64 uses gcc which requires devtoolset-11 for AVX extensions while arm64 uses clang -RUN if [ "$(uname -m)" = "x86_64" ]; then yum install -y devtoolset-11-gcc devtoolset-11-gcc-c++; fi -ENV PATH=/opt/rh/devtoolset-11/root/usr/bin:$PATH +# amd64 uses gcc which requires gcc-toolset-11 for AVX extensions while arm64 uses clang RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'CPU' \ && cmake --build --parallel --preset 'CPU' \ @@ -104,7 +102,7 @@ COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12 COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5 COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6 -FROM --platform=linux/arm64 scratch AS rocm +FROM scratch AS rocm COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm FROM ${FLAVOR} AS archive From e91ae3d47d8153c4b7c10dba031b77d7ae408ef0 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 25 Feb 2025 13:47:36 -0800 Subject: [PATCH 29/31] Update ROCm (6.3 linux, 6.2 windows) and CUDA v12.8 (#9304) * Bump cuda and rocm versions Update ROCm to linux:6.3 win:6.2 and CUDA v12 to 12.8. Yum has some silent failure modes, so largely switch to dnf. * Fix windows build script --- .github/workflows/release.yaml | 8 +-- .github/workflows/test.yaml | 2 +- Dockerfile | 28 +++++----- scripts/build_docker.sh | 2 +- scripts/build_linux.sh | 34 ++++++++++-- scripts/build_windows.ps1 | 94 +++++++++++++++++++--------------- 6 files changed, 104 insertions(+), 64 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 37d525e9..12f36140 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -111,13 +111,13 @@ jobs: - os: windows arch: amd64 preset: 'CUDA 12' - install: https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_551.61_windows.exe - cuda-version: '12.4' + install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe + cuda-version: '12.8' - os: windows arch: amd64 preset: 'ROCm 6' - install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe - rocm-version: '6.1' + install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe + rocm-version: '6.2' runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release env: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 56a2cc4f..431bc328 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -81,7 +81,7 @@ jobs: install: https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe flags: '-DCMAKE_CUDA_ARCHITECTURES=87' - preset: ROCm - install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe + install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe flags: '-DAMDGPU_TARGETS=gfx1010' runs-on: windows steps: diff --git a/Dockerfile b/Dockerfile index 09612824..46d4713e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,20 +4,22 @@ ARG FLAVOR=${TARGETARCH} ARG ROCMVERSION=6.3.3 ARG JETPACK5VERSION=r35.4.1 -ARG JETPACK6VERSION=r36.2.0 +ARG JETPACK6VERSION=r36.4.0 ARG CMAKEVERSION=3.31.2 +# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 -RUN sed -i -e 's/mirror.centos.org/vault.centos.org/g' -e 's/^#.*baseurl=http/baseurl=http/g' -e 's/^mirrorlist=http/#mirrorlist=http/g' /etc/yum.repos.d/*.repo \ - && yum install -y yum-utils gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ \ - && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo \ - && curl -s -L https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz | tar -Jx -C /usr/local/bin --strip-components 1 -ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH +RUN yum install -y yum-utils \ + && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \ + && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \ + && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \ + && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo +ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH -FROM --platform=linux/arm64 rockylinux:8 AS base-arm64 +FROM --platform=linux/arm64 almalinux:8 AS base-arm64 # install epel-release for ccache RUN yum install -y yum-utils epel-release \ - && yum install -y clang ccache \ + && dnf install -y clang ccache \ && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo ENV CC=clang CXX=clang++ @@ -29,7 +31,8 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml ENV LDFLAGS=-s FROM base AS cpu -# amd64 uses gcc which requires gcc-toolset-11 for AVX extensions while arm64 uses clang +RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ +ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'CPU' \ && cmake --build --parallel --preset 'CPU' \ @@ -37,7 +40,7 @@ RUN --mount=type=cache,target=/root/.ccache \ FROM base AS cuda-11 ARG CUDA11VERSION=11.3 -RUN yum install -y cuda-toolkit-${CUDA11VERSION//./-} +RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} ENV PATH=/usr/local/cuda-11/bin:$PATH RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'CUDA 11' \ @@ -45,8 +48,8 @@ RUN --mount=type=cache,target=/root/.ccache \ && cmake --install build --component CUDA --strip --parallel 8 FROM base AS cuda-12 -ARG CUDA12VERSION=12.4 -RUN yum install -y cuda-toolkit-${CUDA12VERSION//./-} +ARG CUDA12VERSION=12.8 +RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} ENV PATH=/usr/local/cuda-12/bin:$PATH RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'CUDA 12' \ @@ -54,6 +57,7 @@ RUN --mount=type=cache,target=/root/.ccache \ && cmake --install build --component CUDA --strip --parallel 8 FROM base AS rocm-6 +ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'ROCm 6' \ && cmake --build --parallel --preset 'ROCm 6' \ diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh index 567eb7c7..1dd8d1f6 100755 --- a/scripts/build_docker.sh +++ b/scripts/build_docker.sh @@ -28,7 +28,7 @@ if echo $PLATFORM | grep "amd64" > /dev/null; then ${LOAD_OR_PUSH} \ --platform=linux/amd64 \ ${OLLAMA_COMMON_BUILD_ARGS} \ - --target runtime-rocm \ + --build-arg FLAVOR=rocm \ -f Dockerfile \ -t ${FINAL_IMAGE_REPO}:$VERSION-rocm \ . diff --git a/scripts/build_linux.sh b/scripts/build_linux.sh index a0c3d2f0..618722d1 100755 --- a/scripts/build_linux.sh +++ b/scripts/build_linux.sh @@ -22,8 +22,34 @@ docker buildx build \ -f Dockerfile \ . -# buildx behavior changes for single vs. multiplatform -if echo $PLATFORM | grep "," > /dev/null ; then - mv -f ./dist/linux_*64/ollama* ./dist/ - rmdir ./dist/linux_*64 +if echo $PLATFORM | grep "amd64" > /dev/null; then + outDir="./dist" + if echo $PLATFORM | grep "," > /dev/null ; then + outDir="./dist/linux_amd64" + fi + docker buildx build \ + --output type=local,dest=${outDir} \ + --platform=linux/amd64 \ + ${OLLAMA_COMMON_BUILD_ARGS} \ + --build-arg FLAVOR=rocm \ + --target archive \ + -f Dockerfile \ + . +fi + +# buildx behavior changes for single vs. multiplatform +echo "Compressing linux tar bundles..." +if echo $PLATFORM | grep "," > /dev/null ; then + tar c -C ./dist/linux_arm64 --exclude cuda_jetpack5 --exclude cuda_jetpack6 . | pigz -9vc >./dist/ollama-linux-arm64.tgz + tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz + tar c -C ./dist/linux_arm64 ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz + tar c -C ./dist/linux_amd64 --exclude rocm . | pigz -9vc >./dist/ollama-linux-amd64.tgz + tar c -C ./dist/linux_amd64 ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz +elif echo $PLATFORM | grep "arm64" > /dev/null ; then + tar c -C ./dist/ --exclude cuda_jetpack5 --exclude cuda_jetpack6 bin lib | pigz -9vc >./dist/ollama-linux-arm64.tgz + tar c -C ./dist/ ./lib/ollama/cuda_jetpack5 | pigz -9vc >./dist/ollama-linux-arm64-jetpack5.tgz + tar c -C ./dist/ ./lib/ollama/cuda_jetpack6 | pigz -9vc >./dist/ollama-linux-arm64-jetpack6.tgz +elif echo $PLATFORM | grep "amd64" > /dev/null ; then + tar c -C ./dist/ --exclude rocm bin lib | pigz -9vc >./dist/ollama-linux-amd64.tgz + tar c -C ./dist/ ./lib/ollama/rocm | pigz -9vc >./dist/ollama-linux-amd64-rocm.tgz fi diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 68f3b11d..465cc551 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -26,6 +26,9 @@ function checkEnv() { $MSVC_INSTALL=(Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation $env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0] } + if (-Not (get-command -ErrorAction silent ninja)) { + $script:NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe) | split-path -parent + } # Locate CUDA versions # Note: this assumes every version found will be built $cudaList=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\" -ea 'silentlycontinue') @@ -75,6 +78,7 @@ function checkEnv() { } else { write-host "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree" } + $script:JOBS=((Get-CimInstance Win32_ComputerSystem).NumberOfLogicalProcessors) } @@ -83,51 +87,57 @@ function buildOllama() { Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}" New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0 - - # Default first, then conditionall ROCm and cuda v11 - write-host "Building Default native backend libraries" - $env:CMAKE_GENERATOR="ninja" - & cmake --preset Default + & cmake --fresh --preset CPU --install-prefix $script:DIST_DIR if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - & cmake --build --preset Default -j 12 + & cmake --build --preset CPU --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component CPU --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - & cmake --install build -j 12 - - # TODO - add steps for v11 and ROCm - # - # if ("$script:CUDA_DIRS".Contains("v11") -and "$script:CUDA_DIRS".Contains("v12")) { - # # We assume the default is v12, so override for v11 - # $origCUDA_PATH=$env:CUDA_PATH - # $hashEnv = @{} - # Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } - # $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $v11="$_" }} - # write-host "$v11" - # # $env:CUDA_PATH=$hashEnv[$v11] - # # $env:CUDACXX=$hashEnv[$v11]+"\bin\nvcc.exe" - # $env:CUDAToolkit_ROOT=$hashEnv[$v11] - # # ls env: - # write-host "Building CUDA v11 backend libraries" - # & cmake --preset "CUDA 11" - # $env:CUDA_PATH=$origCUDA_PATH - # exit(1) - # if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - # # & cmake --build --preset "CUDA 11" -j 12 - # # if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - # } - # if ($env:HIP_PATH) { - # write-host "Building ROCm backend libraries" - # $env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe" - # $env:HIP_PLATFORM="amd" - # $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" - # & cmake --preset "ROCm" - # $env:HIPCXX="" - # $env:HIP_PLATFORM="" - # $env:CMAKE_PREFIX_PATH="" - # if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - # & cmake --build --preset "ROCm" -j 12 - # if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} - # } + $hashEnv = @{} + Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } + if ("$script:CUDA_DIRS".Contains("v11")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $v11="$_" }} + $env:CUDAToolkit_ROOT=$hashEnv[$v11] + write-host "Building CUDA v11 backend libraries" + # Note: cuda v11 requires msvc 2019 so force the older generator + # to avoid 2022 (or newer) from being used as the default + & cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset "CUDA 11" --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component "CUDA" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + if ("$script:CUDA_DIRS".Contains("v12")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }} + $env:CUDAToolkit_ROOT=$hashEnv[$v12] + write-host "Building CUDA v12 backend libraries" + & cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset "CUDA 12" --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component "CUDA" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + if ($env:HIP_PATH) { + write-host "Building ROCm backend libraries" + if ($null -ne $script:NINJA_DIR) { + $env:PATH="$script:NINJA_DIR;$env:PATH" + } + $env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe" + $env:HIP_PLATFORM="amd" + $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" + & cmake --fresh --preset "ROCm 6" -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ --install-prefix $script:DIST_DIR + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + $env:HIPCXX="" + $env:HIP_PLATFORM="" + $env:CMAKE_PREFIX_PATH="" + & cmake --build --preset "ROCm" --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component "HIP" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } } else { write-host "Skipping generate step with OLLAMA_SKIP_GENERATE set" } From 0d694793f25a274a58680ba244e5febc31b96743 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Tue, 25 Feb 2025 14:28:07 -0800 Subject: [PATCH 30/31] .github: always run tests, and other helpful fixes (#9348) During work on our new registry client, I ran into frustrations with CI where a misspelling in a comment caused the linter to fail, which caused the tests to not run, which caused the build to not be cached, which caused the next run to be slow, which caused me to be sad. This commit address these issues, and pulls in some helpful changes we've had in CI on ollama.com for some time now. They are: * Always run tests, even if the other checks fail. Tests are the most important part of CI, and should always run. Failures in tests can be correlated with failures in other checks, and can help surface the root cause of the failure sooner. This is especially important when the failure is platform specific, and the tests are not platform independent. * Check that `go generate` is clean. This prevents 'go generate' abuse regressions. This codebase used to use it to generate platform specific binary build artifacts. Let's make sure that does not happen again and this powerful tool is used correctly, and the generated code is checked in. Also, while adding `go generate` the check, it was revealed that the generated metal code was putting dates in the comments, resulting in non-deterministic builds. This is a bad practice, and this commit fixes that. Git tells us the most important date: the commit date along with other associated changes. * Check that `go mod tidy` is clean. A new job to check that `go mod tidy` is clean was added, to prevent easily preventable merge conflicts or go.mod changes being deferred to a future PR that is unrelated to the change that caused the go.mod to change. * More robust caching. We now cache the go build cache, and the go mod download cache independently. This is because the download cache contains zips that can be unpacked in parallel faster than they can be fetched and extracted by tar. This speeds up the build significantly. The linter is hostile enough. It does not need to also punish us with longer build times due to small failures like misspellings. --- .github/workflows/test.yaml | 80 ++++++++++++++++++- .../src/ggml-metal/ggml-metal-embed.metal | 2 +- ml/backend/ggml/ggml/src/ggml-metal/metal.go | 2 +- 3 files changed, 79 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 431bc328..bb0e8d90 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -140,6 +140,13 @@ jobs: env: CMAKE_GENERATOR: Ninja + go_mod_tidy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: check that 'go mod tidy' is clean + run: go mod tidy --diff || (echo "Please run 'go mod tidy'." && exit 1) + test: strategy: matrix: @@ -149,14 +156,81 @@ jobs: CGO_ENABLED: '1' GOEXPERIMENT: 'synctest' steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 + - name: checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 + + - name: cache restore + uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: + # Note: unlike the other setups, this is only grabbing the mod download + # cache, rather than the whole mod directory, as the download cache + # contains zips that can be unpacked in parallel faster than they can be + # fetched and extracted by tar + path: | + ~/.cache/go-build + ~/go/pkg/mod/cache + ~\AppData\Local\go-build + # NOTE: The -3- here should be incremented when the scheme of data to be + # cached changes (e.g. path above changes). + key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }} + restore-keys: | + ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }} + ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3- + + - name: Setup Go + uses: actions/setup-go@v5 + with: + # The caching strategy of setup-go is less than ideal, and wastes + # time by not saving artifacts due to small failures like the linter + # complaining, etc. This means subsequent have to rebuild their world + # again until all checks pass. For instance, if you mispell a word, + # you're punished until you fix it. This is more hostile than + # helpful. + cache: false + go-version-file: go.mod + + # TODO(bmizerany): replace this heavy tool with just the + # tools/checks/binaries we want and then make them all run in parallel + # across jobs, not on a single tiny vm on Github Actions. - uses: golangci/golangci-lint-action@v6 with: args: --timeout 10m0s -v - - run: go test ./... + + - name: go test + # Do not skip tests in the face of linter errors, or 'go mod tidy' + # checks, which are secondary to the tests. Tests trump linters. + if: always() + run: go test -count=1 -benchtime=1x ./... + + # It is tempting to run this in a platform independent way, but the past + # shows this codebase will see introductions of platform specific code + # generation, and so we need to check this per platform to ensure we + # don't abuse go generate on specific platforms. + - name: check that 'go generate' is clean + run: | + go generate ./... + git diff --name-only --exit-code || (echo "Please run 'go generate ./...'." && exit 1) + + - name: cache save + # Always save the cache, even if the job fails. The artifacts produced + # during the building of test binaries are not all for naught. They can + # be used to speed up subsequent runs. + if: always() + + uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + with: + # Note: unlike the other setups, this is only grabbing the mod download + # cache, rather than the whole mod directory, as the download cache + # contains zips that can be unpacked in parallel faster than they can be + # fetched and extracted by tar + path: | + ~/.cache/go-build + ~/go/pkg/mod/cache + ~\AppData\Local\go-build + # NOTE: The -3- here should be incremented when the scheme of data to be + # cached changes (e.g. path above changes). + key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }} patches: runs-on: ubuntu-latest diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 2e51b87a..f10966df 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -1,4 +1,4 @@ -// Code generated Fri Jan 10 13:05:45 PST 2025. DO NOT EDIT. +// Code generated by go generate. DO NOT EDIT. #define GGML_COMMON_DECL_METAL #define GGML_COMMON_IMPL_METAL #if defined(GGML_METAL_EMBED_LIBRARY) diff --git a/ml/backend/ggml/ggml/src/ggml-metal/metal.go b/ml/backend/ggml/ggml/src/ggml-metal/metal.go index 1025e205..eb65dfde 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/metal.go +++ b/ml/backend/ggml/ggml/src/ggml-metal/metal.go @@ -2,7 +2,7 @@ package metal -//go:generate sh -c "{ echo // Code generated $(date). DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal" +//go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal" // #cgo CPPFLAGS: -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include // #cgo LDFLAGS: -framework Metal -framework MetalKit From 3ad4bc8afe34bd32b37f56678927ba31fbcd98d4 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 25 Feb 2025 14:33:03 -0800 Subject: [PATCH 31/31] llama: removed unused 'vendoring' file (#9351) --- llama/vendoring | 1 - 1 file changed, 1 deletion(-) delete mode 100644 llama/vendoring diff --git a/llama/vendoring b/llama/vendoring deleted file mode 100644 index 5fdb7cdc..00000000 --- a/llama/vendoring +++ /dev/null @@ -1 +0,0 @@ -LLAMACPP_BASE_COMMIT=46e3556e01b824e52395fb050b29804b6cff2a7c