diff --git a/Dockerfile b/Dockerfile index 46d4713e..4136fca7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,7 +12,7 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base RUN yum install -y yum-utils \ && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \ && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \ - && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \ + && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \ && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH @@ -86,10 +86,11 @@ RUN --mount=type=cache,target=/root/.ccache \ && cmake --install build --component CUDA --strip --parallel 8 FROM base AS build -ARG GOVERSION=1.23.4 -RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local -ENV PATH=/usr/local/go/bin:$PATH WORKDIR /go/src/github.com/ollama/ollama +COPY go.mod go.sum . +RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local +ENV PATH=/usr/local/go/bin:$PATH +RUN go mod download COPY . . ARG GOFLAGS="'-ldflags=-w -s'" ENV CGO_ENABLED=1 diff --git a/README.md b/README.md index 09c720ad..7f7994ca 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
ollama
@@ -86,7 +86,7 @@ Here are some example models that can be downloaded: | Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` | | Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` | | Phi 4 | 14B | 9.1GB | `ollama run phi4` | -| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` | +| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` | | Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` | | Gemma 2 | 9B | 5.5GB | `ollama run gemma2` | | Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` | @@ -97,7 +97,7 @@ Here are some example models that can be downloaded: | Code Llama | 7B | 3.8GB | `ollama run codellama` | | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` | | LLaVA | 7B | 4.5GB | `ollama run llava` | -| Solar | 10.7B | 6.1GB | `ollama run solar` | +| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` | > [!NOTE] > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models. @@ -409,6 +409,8 @@ See the [API documentation](./docs/api.md) for all endpoints. - [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models) - [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms) - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool) +- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration) +- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) ### Cloud @@ -533,6 +535,8 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Maid](https://github.com/Mobile-Artificial-Intelligence/maid) - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama) - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption) +- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device) +- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) ### Extensions & Plugins diff --git a/cmd/cmd.go b/cmd/cmd.go index 80ece4c6..c22a08f4 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -34,7 +34,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" - "github.com/ollama/ollama/llama" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/runner" @@ -256,6 +255,7 @@ func StopHandler(cmd *cobra.Command, args []string) error { if strings.Contains(err.Error(), "not found") { return fmt.Errorf("couldn't find model \"%s\" to stop", args[0]) } + return err } return nil } @@ -338,10 +338,16 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - // TODO(jessegross): We should either find another way to know if this is - // a vision model or remove the logic. Also consider that other modalities will - // need different behavior anyways. 
- opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
+ if len(info.ProjectorInfo) != 0 {
+ opts.MultiModal = true
+ }
+ for k := range info.ModelInfo {
+ if strings.Contains(k, ".vision.") {
+ opts.MultiModal = true
+ break
+ }
+ }
+
 opts.ParentModel = info.Details.ParentModel
 if interactive {
@@ -1274,7 +1280,6 @@ func NewCLI() *cobra.Command {
 runnerCmd := &cobra.Command{
 Use: "runner",
- Short: llama.PrintSystemInfo(),
 Hidden: true,
 RunE: func(cmd *cobra.Command, args []string) error {
 return runner.Execute(os.Args[1:])
diff --git a/docs/development.md b/docs/development.md
index eb67dbfa..cf6d91e2 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -118,6 +118,35 @@ To run tests, use `go test`:
 go test ./...
 ```
+> NOTE: In rare circumstances, you may need to change a package that uses the new
+> "synctest" package in go1.24.
+>
+> If you do not have the "synctest" package enabled, you will not see build or
+> test failures resulting from your change(s) locally, but CI will
+> break.
+>
+> If you see failures in CI, you can either keep pushing changes to see if the
+> CI build passes, or you can enable the "synctest" package locally to see the
+> failures before pushing.
+>
+> To enable the "synctest" package for testing, run the following command:
+>
+> ```shell
+> GOEXPERIMENT=synctest go test ./...
+> ```
+>
+> If you wish to enable synctest for all go commands, you can set the
+> `GOEXPERIMENT` environment variable in your shell profile or by running:
+>
+> ```shell
+> go env -w GOEXPERIMENT=synctest
+> ```
+>
+> This enables the "synctest" package for all go commands without needing
+> to set it in every shell session.
+>
+> The synctest package is not required for production builds.
+
 ## Library detection
 Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index b9f9cc17..8662c3b0 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -565,6 +565,43 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 return
 }
+func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
+ switch llm.KV().Architecture() {
+ case "mllama":
+ for _, layer := range llm.Tensors().GroupLayers()["v"] {
+ weights += layer.Size()
+ }
+
+ kv := func(n string) uint64 {
+ if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
+ return uint64(v)
+ }
+
+ return 0
+ }
+
+ imageSize := kv("image_size")
+
+ maxNumTiles := kv("max_num_tiles")
+ embeddingLength := kv("embedding_length")
+ headCount := kv("attention.head_count")
+
+ numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
+ if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
+ numPatches++
+ }
+
+ numPaddedPatches := numPatches + 8 - (numPatches%8)%8
+
+ graphSize = 4 * (8 +
+ imageSize*imageSize*kv("num_channels")*maxNumTiles +
+ embeddingLength*numPatches*maxNumTiles +
+ 9*embeddingLength*numPaddedPatches*maxNumTiles +
+ numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
+ }
+ return weights, graphSize
+}
+
 // SupportsKVCacheType checks if the requested cache type is supported
 func (f GGML) SupportsKVCacheType(cacheType string) bool {
 return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
 }
diff --git a/llama/llama.go b/llama/llama.go
index 0c4fca43..bb5028bd 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -21,18 +21,6 @@ package llama
 extern bool llamaProgressCallback(float progress, void *user_data);
 extern void
llamaLog(int level, char* text, void* user_data); - -typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER; -COMPILER inline get_compiler() { -#if defined(__clang__) - return COMP_CLANG; -#elif defined(__GNUC__) - return COMP_GCC; -#else - return UNKNOWN_COMPILER; -#endif -} - */ import "C" @@ -72,19 +60,6 @@ func BackendInit() { C.llama_backend_init() } -func PrintSystemInfo() string { - var compiler string - switch C.get_compiler() { - case C.COMP_UNKNOWN: - compiler = "cgo(unknown_compiler)" - case C.COMP_GCC: - compiler = "cgo(gcc)" - case C.COMP_CLANG: - compiler = "cgo(clang)" - } - return C.GoString(C.llama_print_system_info()) + compiler -} - func GetModelArch(modelPath string) (string, error) { mp := C.CString(modelPath) defer C.free(unsafe.Pointer(mp)) diff --git a/llama/patches/0015-try-catch-backend-load.patch b/llama/patches/0015-try-catch-backend-load.patch deleted file mode 100644 index 9aea6183..00000000 --- a/llama/patches/0015-try-catch-backend-load.patch +++ /dev/null @@ -1,69 +0,0 @@ -From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: Michael Yang -Date: Tue, 11 Feb 2025 14:06:36 -0800 -Subject: [PATCH] try/catch backend load - ---- - ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++----------------- - 1 file changed, 23 insertions(+), 22 deletions(-) - -diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index 98d5e14d..1c19129a 100644 ---- a/ggml/src/ggml-backend-reg.cpp -+++ b/ggml/src/ggml-backend-reg.cpp -@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, - } - fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); - for (const auto & entry : dir_it) { -- if (entry.is_regular_file()) { -- std::wstring filename = entry.path().filename().wstring(); -- std::wstring ext = entry.path().extension().wstring(); -- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { -- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; -- if (!handle && !silent) { -- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -- } -- if (handle) { -+ try { -+ if (entry.is_regular_file()) { -+ std::wstring filename = entry.path().filename().wstring(); -+ std::wstring ext = entry.path().extension().wstring(); -+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { -+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; -+ if (!handle) { -+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -+ continue; -+ } -+ - auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); -- if (score_fn) { -- int s = score_fn(); --#ifndef NDEBUG -- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); --#endif -- if (s > best_score) { -- best_score = s; -- best_path = entry.path().wstring(); -- } -- } else { -- if (!silent) { -- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -- } -+ if (!score_fn) { -+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -+ continue; -+ } -+ -+ int s = score_fn(); -+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); -+ if (s > best_score) { -+ best_score = s; -+ best_path = entry.path().wstring(); - } - } - } -+ } 
catch (const std::exception & e) { -+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what()); - } - } - } diff --git a/llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch b/llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch similarity index 67% rename from llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch rename to llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch index d60066c1..e72d78ac 100644 --- a/llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch +++ b/llama/patches/0015-use-std-filesystem-path-instead-of-wstring.patch @@ -4,11 +4,11 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500 Subject: [PATCH] use std::filesystem::path instead of wstring --- - ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++-------------------- - 1 file changed, 58 insertions(+), 86 deletions(-) + ggml/src/ggml-backend-reg.cpp | 199 +++++++++++++++------------------- + 1 file changed, 88 insertions(+), 111 deletions(-) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index 1c19129a..c854e6bb 100644 +index 98d5e14d..799af5f3 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -66,26 +66,6 @@ @@ -264,47 +264,55 @@ index 1c19129a..c854e6bb 100644 for (const auto & search_path : search_paths) { if (!fs::exists(search_path)) { continue; -@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, +@@ -513,29 +485,26 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, + fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); for (const auto & entry : dir_it) { - try { - if (entry.is_regular_file()) { -- std::wstring filename = entry.path().filename().wstring(); -- std::wstring ext = entry.path().extension().wstring(); -+ std::string filename = entry.path().filename().string(); -+ std::string ext = entry.path().extension().string(); - if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { -- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; -+ dl_handle_ptr handle { dl_load_library(entry.path()) }; - if (!handle) { -- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); - continue; - } - - auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); - if (!score_fn) { -- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); -+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); - continue; - } - - int s = score_fn(); -- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); -+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); - if (s > best_score) { - best_score = s; -- best_path = entry.path().wstring(); -+ best_path = entry.path(); - } + if (entry.is_regular_file()) { +- std::wstring filename = entry.path().filename().wstring(); +- std::wstring ext = entry.path().extension().wstring(); ++ std::string filename = entry.path().filename().string(); ++ std::string ext = entry.path().extension().string(); + if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { +- dl_handle_ptr handle { 
dl_load_library(entry.path().wstring()) }; +- if (!handle && !silent) { +- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); ++ dl_handle_ptr handle { dl_load_library(entry.path()) }; ++ if (!handle) { ++ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); ++ continue; + } +- if (handle) { +- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); +- if (score_fn) { +- int s = score_fn(); +-#ifndef NDEBUG +- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); +-#endif +- if (s > best_score) { +- best_score = s; +- best_path = entry.path().wstring(); +- } +- } else { +- if (!silent) { +- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); +- } +- } ++ ++ auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); ++ if (!score_fn) { ++ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); ++ continue; ++ } ++ ++ int s = score_fn(); ++ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); ++ if (s > best_score) { ++ best_score = s; ++ best_path = entry.path(); } } - } catch (const std::exception & e) { -- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what()); -+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what()); } - } - } -@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, +@@ -545,7 +514,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, if (best_score == 0) { // try to load the base backend for (const auto & search_path : search_paths) { @@ -313,3 +321,49 @@ index 1c19129a..c854e6bb 100644 if (fs::exists(path)) { return get_reg().load_backend(path, silent); } +@@ -560,6 +529,14 @@ void ggml_backend_load_all() { + ggml_backend_load_all_from_path(nullptr); + } + ++static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) { ++ try { ++ ggml_backend_load_best(name, silent, user_search_path); ++ } catch (const std::exception & e) { ++ GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what()); ++ } ++} ++ + void ggml_backend_load_all_from_path(const char * dir_path) { + #ifdef NDEBUG + bool silent = true; +@@ -567,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) { + bool silent = false; + #endif + +- ggml_backend_load_best("blas", silent, dir_path); +- ggml_backend_load_best("cann", silent, dir_path); +- ggml_backend_load_best("cuda", silent, dir_path); +- ggml_backend_load_best("hip", silent, dir_path); +- ggml_backend_load_best("kompute", silent, dir_path); +- ggml_backend_load_best("metal", silent, dir_path); +- ggml_backend_load_best("rpc", silent, dir_path); +- ggml_backend_load_best("sycl", silent, dir_path); +- ggml_backend_load_best("vulkan", silent, dir_path); +- ggml_backend_load_best("opencl", silent, dir_path); +- ggml_backend_load_best("musa", silent, dir_path); +- ggml_backend_load_best("cpu", silent, dir_path); ++ ggml_backend_try_load_best("blas", silent, dir_path); ++ ggml_backend_try_load_best("cann", silent, dir_path); ++ ggml_backend_try_load_best("cuda", silent, dir_path); ++ ggml_backend_try_load_best("hip", silent, dir_path); ++ 
ggml_backend_try_load_best("kompute", silent, dir_path); ++ ggml_backend_try_load_best("metal", silent, dir_path); ++ ggml_backend_try_load_best("rpc", silent, dir_path); ++ ggml_backend_try_load_best("sycl", silent, dir_path); ++ ggml_backend_try_load_best("vulkan", silent, dir_path); ++ ggml_backend_try_load_best("opencl", silent, dir_path); ++ ggml_backend_try_load_best("musa", silent, dir_path); ++ ggml_backend_try_load_best("cpu", silent, dir_path); + // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend + const char * backend_path = std::getenv("GGML_BACKEND_PATH"); + if (backend_path) { diff --git a/llama/patches/0017-remove-amx.patch b/llama/patches/0016-remove-amx.patch similarity index 100% rename from llama/patches/0017-remove-amx.patch rename to llama/patches/0016-remove-amx.patch diff --git a/llama/patches/0018-fix-clip-compiler-error.patch b/llama/patches/0017-fix-clip-compiler-error.patch similarity index 100% rename from llama/patches/0018-fix-clip-compiler-error.patch rename to llama/patches/0017-fix-clip-compiler-error.patch diff --git a/llama/patches/0019-add-phi4-support.patch b/llama/patches/0018-add-phi4-support.patch similarity index 100% rename from llama/patches/0019-add-phi4-support.patch rename to llama/patches/0018-add-phi4-support.patch diff --git a/llm/memory.go b/llm/memory.go index 1da4d2c0..40104eca 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -115,6 +115,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin // multimodal models require at least 2048 context opts.NumCtx = max(opts.NumCtx, 2048) } + if projectorWeights == 0 && projectorGraph == 0 { + projectorWeights, projectorGraph = f.VisionGraphSize() + } layers := f.Tensors().GroupLayers() // add one layer worth of memory as a buffer diff --git a/llm/server.go b/llm/server.go index fd027a53..09690a5f 100644 --- a/llm/server.go +++ b/llm/server.go @@ -30,6 +30,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/model" ) type LlamaServer interface { @@ -54,8 +55,15 @@ type llmServer struct { options api.Options numParallel int modelPath string - modelLock sync.Mutex // Temporary until we switch fully to Go server - model *llama.Model // If non-nil, the runner is a new Go server + + // llamaModel is an instance of the cgo llama.cpp model definition + // nil if this server is running the new engine + llamaModel *llama.Model + llamaModelLock sync.Mutex + + // textProcessor handles text encoding/decoding for the model in the Ollama engine + // nil if this server is running the llama.cpp based engine + textProcessor model.TextProcessor estimate MemoryEstimate totalLayers uint64 @@ -89,7 +97,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) { // NewLlamaServer will run a server for the given GPUs // The gpu list must be a single family. 
-func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) { +func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) { systemInfo := discover.GetSystemInfo() systemTotalMemory := systemInfo.System.TotalMemory systemFreeMemory := systemInfo.System.FreeMemory @@ -130,7 +138,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt slog.Info("offload", "", estimate) params := []string{ - "--model", model, + "--model", modelPath, "--ctx-size", strconv.Itoa(opts.NumCtx), "--batch-size", strconv.Itoa(opts.NumBatch), } @@ -153,11 +161,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt } } - if len(projectors) > 0 { - // TODO: applying multiple projectors is not supported by the llama.cpp server yet - params = append(params, "--mmproj", projectors[0]) - } - defaultThreads := systemInfo.GetOptimalThreadCount() if opts.NumThread > 0 { params = append(params, "--threads", strconv.Itoa(opts.NumThread)) @@ -257,6 +260,34 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt } } slog.Debug("compatible gpu libraries", "compatible", compatible) + exe, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("unable to lookup executable path: %w", err) + } + + if eval, err := filepath.EvalSymlinks(exe); err == nil { + exe = eval + } + + var llamaModel *llama.Model + var textProcessor model.TextProcessor + if envconfig.NewEngine() { + textProcessor, err = model.NewTextProcessor(modelPath) + if err != nil { + // To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner + slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err) + } + } + if textProcessor == nil { + llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true}) + if err != nil { + return nil, err + } + } + + if len(projectors) > 0 && llamaModel != nil { + params = append(params, "--mmproj", projectors[0]) + } // iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc. // adding each library's respective path to the LD_LIBRARY_PATH, until finally running @@ -275,7 +306,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range } finalParams := []string{"runner"} - if envconfig.NewEngine() { + if textProcessor != nil { + // New engine + // TODO - if we have failure to load scenarios, add logic to retry with the old runner finalParams = append(finalParams, "--ollama-engine") } finalParams = append(finalParams, params...) 
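Reviewer note: to make the control flow in the hunk above easier to follow, here is a small, self-contained Go sketch of the engine-selection pattern it introduces — prefer the new engine when requested (envconfig.NewEngine()), fall back to the llama.cpp runner if the text processor cannot be built, and only pass `--ollama-engine` when the new engine is actually in use. The names below (`textProcessor`, `llamaModel`, `newTextProcessor`, `loadVocabOnly`, `chooseRunner`) are hypothetical stand-ins for illustration, not the real `model`/`llama` package APIs.

```go
package main

import (
	"errors"
	"fmt"
)

// Hypothetical stand-ins for the two loaders used in the hunk above: the new
// engine's text processor and the llama.cpp vocab-only model.
type textProcessor struct{ path string }
type llamaModel struct{ path string }

var errUnsupported = errors.New("model not yet supported by Ollama engine")

// newTextProcessor plays the role of model.NewTextProcessor: it only succeeds
// for architectures the new engine understands.
func newTextProcessor(path string, supported bool) (*textProcessor, error) {
	if !supported {
		return nil, errUnsupported
	}
	return &textProcessor{path: path}, nil
}

// loadVocabOnly plays the role of a vocab-only llama.cpp model load, which the
// compatibility path keeps around for tokenize/detokenize.
func loadVocabOnly(path string) (*llamaModel, error) {
	return &llamaModel{path: path}, nil
}

// chooseRunner mirrors the fallback: try the new engine first, switch to
// compatibility mode if the model is not supported, and add --ollama-engine
// only when the text processor was actually created.
func chooseRunner(path string, newEngine, supported bool) ([]string, error) {
	var tp *textProcessor
	var lm *llamaModel

	if newEngine {
		var err error
		if tp, err = newTextProcessor(path, supported); err != nil {
			// Not fatal: fall back to the old runner instead of failing.
			fmt.Println("switching to compatibility mode:", err)
			tp = nil
		}
	}
	if tp == nil {
		var err error
		if lm, err = loadVocabOnly(path); err != nil {
			return nil, err
		}
	}
	_ = lm // later used for Tokenize/Detokenize in the llama.cpp path

	args := []string{"runner"}
	if tp != nil {
		args = append(args, "--ollama-engine")
	}
	return args, nil
}

func main() {
	args, _ := chooseRunner("model.gguf", true, false)
	fmt.Println(args) // [runner] — unsupported model fell back to llama.cpp
}
```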
@@ -315,28 +348,20 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt // finally, add the root library path libraryPaths = append(libraryPaths, discover.LibOllamaPath) - exe, err := os.Executable() - if err != nil { - return nil, fmt.Errorf("unable to lookup executable path: %w", err) - } - - if eval, err := filepath.EvalSymlinks(exe); err == nil { - exe = eval - } - - // TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access s := &llmServer{ - port: port, - cmd: exec.Command(exe, finalParams...), - status: NewStatusWriter(os.Stderr), - options: opts, - modelPath: model, - estimate: estimate, - numParallel: numParallel, - sem: semaphore.NewWeighted(int64(numParallel)), - totalLayers: f.KV().BlockCount() + 1, - gpus: gpus, - done: make(chan error, 1), + port: port, + cmd: exec.Command(exe, finalParams...), + status: NewStatusWriter(os.Stderr), + options: opts, + modelPath: modelPath, + llamaModel: llamaModel, + textProcessor: textProcessor, + estimate: estimate, + numParallel: numParallel, + sem: semaphore.NewWeighted(int64(numParallel)), + totalLayers: f.KV().BlockCount() + 1, + gpus: gpus, + done: make(chan error, 1), } s.cmd.Env = os.Environ() @@ -405,6 +430,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt } err := fmt.Errorf("error starting runner: %v %s", err, msg) if len(compatible) == 0 { + if llamaModel != nil { + llama.FreeModel(llamaModel) + } return nil, err } @@ -701,24 +729,29 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu } if len(req.Format) > 0 { - switch string(req.Format) { - case `null`, `""`: - // Field was set, but "missing" a value. We accept - // these as "not set". - break - case `"json"`: - request["grammar"] = grammarJSON - default: - if req.Format[0] != '{' { - return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format) - } + format := string(req.Format) + if format != `null` && format != `""` { + if s.textProcessor != nil { + // New engine handles this on the backend + request["format"] = req.Format + } else { + // old engine + switch format { + case `"json"`: + request["grammar"] = grammarJSON + default: + if req.Format[0] != '{' { + return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format) + } - // User provided a JSON schema - g := llama.SchemaToGrammar(req.Format) - if g == nil { - return fmt.Errorf("invalid JSON schema in format") + // User provided a JSON schema + g := llama.SchemaToGrammar(req.Format) + if g == nil { + return fmt.Errorf("invalid JSON schema in format") + } + request["grammar"] = string(g) + } } - request["grammar"] = string(g) } } @@ -933,64 +966,25 @@ type TokenizeResponse struct { } func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) { - s.modelLock.Lock() - defer s.modelLock.Unlock() - if s.model != nil { - return s.model.Tokenize(content, false, true) - } + s.llamaModelLock.Lock() + defer s.llamaModelLock.Unlock() - // Make sure the server is ready - status, err := s.getServerStatus(ctx) - if err != nil { - return nil, err - } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable { - return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) + if s.llamaModel != nil { + return s.llamaModel.Tokenize(content, false, true) } - - data, err := json.Marshal(TokenizeRequest{Content: content}) - if err != nil { - return nil, 
fmt.Errorf("marshaling encode data: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data)) - if err != nil { - return nil, fmt.Errorf("encode request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("do encode request: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode == http.StatusNotFound { - if s.model == nil { - slog.Debug("new runner detected, loading model for cgo tokenization") - m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true}) - if err != nil { - return nil, err - } - s.model = m + if s.textProcessor != nil { + tokens, err := s.textProcessor.Encode(content) + if err != nil { + return nil, err } - return s.model.Tokenize(content, false, true) + toks := make([]int, len(tokens)) + for i, t := range tokens { + toks[i] = int(t) + } + return toks, nil } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read encode request: %w", err) - } - - if resp.StatusCode >= 400 { - log.Printf("llm encode error: %s", body) - return nil, fmt.Errorf("%s", body) - } - - var encoded TokenizeResponse - if err := json.Unmarshal(body, &encoded); err != nil { - return nil, fmt.Errorf("unmarshal encode response: %w", err) - } - - return encoded.Tokens, nil + // not reached + return nil, fmt.Errorf("no tokenizer configured") } type DetokenizeRequest struct { @@ -1002,80 +996,38 @@ type DetokenizeResponse struct { } func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) { - s.modelLock.Lock() - defer s.modelLock.Unlock() - if s.model != nil { + s.llamaModelLock.Lock() + defer s.llamaModelLock.Unlock() + + if s.llamaModel != nil { var resp string for _, token := range tokens { - resp += s.model.TokenToPiece(token) + resp += s.llamaModel.TokenToPiece(token) } return resp, nil } - // Make sure the server is ready - status, err := s.getServerStatus(ctx) - if err != nil { - return "", err - } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable { - return "", fmt.Errorf("unexpected server status: %s", status.ToString()) - } - - data, err := json.Marshal(DetokenizeRequest{Tokens: tokens}) - if err != nil { - return "", fmt.Errorf("marshaling decode data: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data)) - if err != nil { - return "", fmt.Errorf("decode request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", fmt.Errorf("do decode request: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode == http.StatusNotFound { - if s.model == nil { - slog.Debug("new runner detected, loading model for cgo tokenization") - m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true}) - if err != nil { - return "", err - } - s.model = m + if s.textProcessor != nil { + toks := make([]int32, len(tokens)) + for i, t := range tokens { + toks[i] = int32(t) } - var resp string - for _, token := range tokens { - resp += s.model.TokenToPiece(token) + content, err := s.textProcessor.Decode(toks) + if err != nil { + return "", err } - return resp, nil + return content, nil } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("read decode request: %w", err) 
- } - - if resp.StatusCode >= 400 { - log.Printf("llm decode error: %s", body) - return "", fmt.Errorf("%s", body) - } - - var decoded DetokenizeResponse - if err := json.Unmarshal(body, &decoded); err != nil { - return "", fmt.Errorf("unmarshal encode response: %w", err) - } - - return decoded.Content, nil + // not reached + return "", fmt.Errorf("no tokenizer configured") } func (s *llmServer) Close() error { - s.modelLock.Lock() - if s.model != nil { - llama.FreeModel(s.model) - s.model = nil + s.llamaModelLock.Lock() + if s.llamaModel != nil { + llama.FreeModel(s.llamaModel) + s.llamaModel = nil } - s.modelLock.Unlock() + s.llamaModelLock.Unlock() if s.cmd != nil { slog.Debug("stopping llama server") diff --git a/ml/backend.go b/ml/backend.go index 83b7a8c9..3ef8a1ac 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -24,7 +24,6 @@ type Backend interface { Config() Config Get(name string) Tensor NewContext() Context - SystemInfo() string } // BackendCacheConfig should be implemented by backends that need special output diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index f4948fca..2d8ddf99 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1,27 +1,11 @@ package ggml -/* -#cgo CPPFLAGS: -I${SRCDIR}/ggml/include -#include -#include -#include "ggml.h" -#include "ggml-cpu.h" -#include "ggml-backend.h" -static struct ggml_backend_feature * getBackendFeatures(void *fp, ggml_backend_reg_t reg) {return ((ggml_backend_get_features_t)(fp))(reg);} -static struct ggml_backend_feature * getNextBackendFeatures(struct ggml_backend_feature * feature) { return &feature[1];} - -typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER; -COMPILER inline get_compiler() { -#if defined(__clang__) - return COMP_CLANG; -#elif defined(__GNUC__) - return COMP_GCC; -#else - return UNKNOWN_COMPILER; -#endif -} - -*/ +// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include +// #include +// #include +// #include "ggml.h" +// #include "ggml-cpu.h" +// #include "ggml-backend.h" import "C" import ( @@ -729,34 +713,3 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) } } - -func (b *Backend) SystemInfo() string { - var compiler string - switch C.get_compiler() { - case C.COMP_UNKNOWN: - compiler = "cgo(unknown_compiler)" - case C.COMP_GCC: - compiler = "cgo(gcc)" - case C.COMP_CLANG: - compiler = "cgo(clang)" - } - - var s string - for i := range C.ggml_backend_reg_count() { - reg := C.ggml_backend_reg_get(i) - fName := C.CString("ggml_backend_get_features") - defer C.free(unsafe.Pointer(fName)) - get_features_fn := C.ggml_backend_reg_get_proc_address(reg, fName) - if get_features_fn != nil { - s += C.GoString(C.ggml_backend_reg_name(reg)) - s += " : " - for features := C.getBackendFeatures(get_features_fn, reg); features.name != nil; features = C.getNextBackendFeatures(features) { - s += C.GoString(features.name) - s += " = " - s += C.GoString(features.value) - s += " | " - } - } - } - return s + compiler -} diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp index c854e6bb..799af5f3 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp @@ -484,33 +484,29 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, } fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); for (const auto & entry : dir_it) { - try { - if (entry.is_regular_file()) { 
- std::string filename = entry.path().filename().string(); - std::string ext = entry.path().extension().string(); - if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { - dl_handle_ptr handle { dl_load_library(entry.path()) }; - if (!handle) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); - continue; - } + if (entry.is_regular_file()) { + std::string filename = entry.path().filename().string(); + std::string ext = entry.path().extension().string(); + if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { + dl_handle_ptr handle { dl_load_library(entry.path()) }; + if (!handle) { + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); + continue; + } - auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); - if (!score_fn) { - GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); - continue; - } + auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); + if (!score_fn) { + GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); + continue; + } - int s = score_fn(); - GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); - if (s > best_score) { - best_score = s; - best_path = entry.path(); - } + int s = score_fn(); + GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); + if (s > best_score) { + best_score = s; + best_path = entry.path(); } } - } catch (const std::exception & e) { - GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what()); } } } @@ -533,6 +529,14 @@ void ggml_backend_load_all() { ggml_backend_load_all_from_path(nullptr); } +static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) { + try { + ggml_backend_load_best(name, silent, user_search_path); + } catch (const std::exception & e) { + GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what()); + } +} + void ggml_backend_load_all_from_path(const char * dir_path) { #ifdef NDEBUG bool silent = true; @@ -540,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) { bool silent = false; #endif - ggml_backend_load_best("blas", silent, dir_path); - ggml_backend_load_best("cann", silent, dir_path); - ggml_backend_load_best("cuda", silent, dir_path); - ggml_backend_load_best("hip", silent, dir_path); - ggml_backend_load_best("kompute", silent, dir_path); - ggml_backend_load_best("metal", silent, dir_path); - ggml_backend_load_best("rpc", silent, dir_path); - ggml_backend_load_best("sycl", silent, dir_path); - ggml_backend_load_best("vulkan", silent, dir_path); - ggml_backend_load_best("opencl", silent, dir_path); - ggml_backend_load_best("musa", silent, dir_path); - ggml_backend_load_best("cpu", silent, dir_path); + ggml_backend_try_load_best("blas", silent, dir_path); + ggml_backend_try_load_best("cann", silent, dir_path); + ggml_backend_try_load_best("cuda", silent, dir_path); + ggml_backend_try_load_best("hip", silent, dir_path); + ggml_backend_try_load_best("kompute", silent, dir_path); + ggml_backend_try_load_best("metal", silent, dir_path); + ggml_backend_try_load_best("rpc", silent, dir_path); + ggml_backend_try_load_best("sycl", silent, dir_path); + ggml_backend_try_load_best("vulkan", silent, dir_path); + 
ggml_backend_try_load_best("opencl", silent, dir_path); + ggml_backend_try_load_best("musa", silent, dir_path); + ggml_backend_try_load_best("cpu", silent, dir_path); // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend const char * backend_path = std::getenv("GGML_BACKEND_PATH"); if (backend_path) { diff --git a/ml/backend/ggml/ggml/src/ggml.go b/ml/backend/ggml/ggml/src/ggml.go index 85c693eb..afc1e1ed 100644 --- a/ml/backend/ggml/ggml/src/ggml.go +++ b/ml/backend/ggml/ggml/src/ggml.go @@ -7,6 +7,20 @@ package ggml // #include // #include "ggml-backend.h" // extern void sink(int level, char *text, void *user_data); +// static struct ggml_backend_feature * first_feature(ggml_backend_get_features_t fp, ggml_backend_reg_t reg) { return fp(reg); } +// static struct ggml_backend_feature * next_feature(struct ggml_backend_feature * feature) { return &feature[1]; } +/* +typedef enum { COMPILER_CLANG, COMPILER_GNUC, COMPILER_UNKNOWN } COMPILER; +static COMPILER compiler_name(void) { +#if defined(__clang__) + return COMPILER_CLANG; +#elif defined(__GNUC__) + return COMPILER_GNUC; +#else + return COMPILER_UNKNOWN; +#endif +} +*/ import "C" import ( @@ -16,6 +30,7 @@ import ( "os" "path/filepath" "runtime" + "strconv" "strings" "sync" "unsafe" @@ -90,4 +105,43 @@ var OnceLoad = sync.OnceFunc(func() { visited[abspath] = struct{}{} } } + + slog.Info("system", "", system{}) }) + +type system struct{} + +func (system) LogValue() slog.Value { + var attrs []slog.Attr + names := make(map[string]int) + for i := range C.ggml_backend_dev_count() { + r := C.ggml_backend_dev_backend_reg(C.ggml_backend_dev_get(i)) + + func() { + fName := C.CString("ggml_backend_get_features") + defer C.free(unsafe.Pointer(fName)) + + if fn := C.ggml_backend_reg_get_proc_address(r, fName); fn != nil { + var features []any + for f := C.first_feature(C.ggml_backend_get_features_t(fn), r); f.name != nil; f = C.next_feature(f) { + features = append(features, C.GoString(f.name), C.GoString(f.value)) + } + + name := C.GoString(C.ggml_backend_reg_name(r)) + attrs = append(attrs, slog.Group(name+"."+strconv.Itoa(names[name]), features...)) + names[name] += 1 + } + }() + } + + switch C.compiler_name() { + case C.COMPILER_CLANG: + attrs = append(attrs, slog.String("compiler", "cgo(clang)")) + case C.COMPILER_GNUC: + attrs = append(attrs, slog.String("compiler", "cgo(gcc)")) + default: + attrs = append(attrs, slog.String("compiler", "cgo(unknown)")) + } + + return slog.GroupValue(attrs...) 
+} diff --git a/model/model.go b/model/model.go index 16020b35..f8ed8741 100644 --- a/model/model.go +++ b/model/model.go @@ -16,6 +16,7 @@ import ( _ "golang.org/x/image/tiff" _ "golang.org/x/image/webp" + fs "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" @@ -100,6 +101,36 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { return m, nil } +func NewTextProcessor(s string) (TextProcessor, error) { + r, err := os.Open(s) + if err != nil { + return nil, err + } + defer r.Close() + meta, _, err := fs.Decode(r, -1) + if err != nil { + return nil, err + } + return getTextProcessor(meta.KV()) +} + +func getTextProcessor(kv fs.KV) (TextProcessor, error) { + arch := kv.Architecture() + f, ok := models[arch] + if !ok { + return nil, fmt.Errorf("unsupported model architecture %q", arch) + } + m, err := f(kv) + if err != nil { + return nil, err + } + tp, ok := m.(TextProcessor) + if !ok { + return nil, fmt.Errorf("%v is not a TextProcessor", m) + } + return tp, nil +} + func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { t := v.Type() diff --git a/model/model_test.go b/model/model_test.go index 02b8aa3c..8761817e 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -3,9 +3,11 @@ package model import ( "reflect" "slices" + "strings" "testing" "github.com/google/go-cmp/cmp" + fs "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/nn" @@ -134,3 +136,40 @@ func TestPopulateFieldsAlternateName(t *testing.T) { t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) } } + +func TestGetTextProcessor(t *testing.T) { + tp, err := getTextProcessor(fs.KV{}) + if err == nil { + t.Error("expected error") + } else if !strings.Contains(err.Error(), "unsupported model architecture") { + t.Errorf("unexpected error: %v", err) + } else if tp != nil { + t.Error("expected nil tp") + } + + models["dummy"] = func(ml.Config) (Model, error) { + return notTextProcessorModel{}, nil + } + tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"}) + if err == nil { + t.Error("expected error") + } else if !strings.Contains(err.Error(), "not a TextProcessor") { + t.Errorf("unexpected error: %v", err) + } else if tp != nil { + t.Error("expected nil tp") + } +} + +type notTextProcessorModel struct{} + +func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) { + panic("unimplemented") +} + +func (notTextProcessorModel) Backend() ml.Backend { + panic("unimplemented") +} + +func (notTextProcessorModel) Config() config { + panic("unimplemented") +} diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 9bf6f497..2f254a28 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -1,7 +1,9 @@ package llama import ( + "fmt" "math" + "strings" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" @@ -29,6 +31,10 @@ type Model struct { } func New(c ml.Config) (model.Model, error) { + if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { + return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) + } + m := Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), diff --git a/model/models/mllama/model.go 
b/model/models/mllama/model.go index 743f4c32..8fee0cdb 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -1,6 +1,8 @@ package mllama import ( + "fmt" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -25,6 +27,10 @@ const ( ) func New(c ml.Config) (model.Model, error) { + // Verify unified config + if c.Uint("vision.block_count") == 0 { + return nil, fmt.Errorf("non-unified vision model not supported") + } m := Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 82880c98..8662afc1 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -931,7 +931,6 @@ func Execute(args []string) error { slog.Info("starting go runner") llama.BackendInit() - slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads) server := &Server{ batchSize: *batchSize, diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 5705931a..1a4bbf19 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -786,8 +786,6 @@ func (s *Server) loadModel( panic(err) } - slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads) - // TODO(jessegross): LoRA loading if lpath.String() != "" { panic("loras are not yet implemented") diff --git a/scripts/install.sh b/scripts/install.sh index 9e146e50..9c232400 100644 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -77,11 +77,12 @@ if [ -d "$OLLAMA_INSTALL_DIR/lib/ollama" ] ; then fi status "Installing ollama to $OLLAMA_INSTALL_DIR" $SUDO install -o0 -g0 -m755 -d $BINDIR -$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR" +$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama" status "Downloading Linux ${ARCH} bundle" curl --fail --show-error --location --progress-bar \ "https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \ $SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR" + if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then status "Making ollama accessible in the PATH in $BINDIR" $SUDO ln -sf "$OLLAMA_INSTALL_DIR/ollama" "$BINDIR/ollama" diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index 82a8bbca..007de5e8 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -27,6 +27,7 @@ import ( "slices" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -73,19 +74,22 @@ const ( DefaultMaxChunkSize = 8 << 20 ) -// DefaultCache returns a new disk cache for storing models. If the -// OLLAMA_MODELS environment variable is set, it uses that directory; -// otherwise, it uses $HOME/.ollama/models. -func DefaultCache() (*blob.DiskCache, error) { +var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) { dir := os.Getenv("OLLAMA_MODELS") if dir == "" { - home, err := os.UserHomeDir() - if err != nil { - return nil, err - } + home, _ := os.UserHomeDir() + home = cmp.Or(home, ".") dir = filepath.Join(home, ".ollama", "models") } return blob.Open(dir) +}) + +// DefaultCache returns the default cache used by the registry. It is +// configured from the OLLAMA_MODELS environment variable, or defaults to +// $HOME/.ollama/models, or, if an error occurs obtaining the home directory, +// it uses the current working directory. 
+func DefaultCache() (*blob.DiskCache, error) { + return defaultCache() } // Error is the standard error returned by Ollama APIs. It can represent a @@ -168,6 +172,10 @@ func CompleteName(name string) string { // Registry is a client for performing push and pull operations against an // Ollama registry. type Registry struct { + // Cache is the cache used to store models. If nil, [DefaultCache] is + // used. + Cache *blob.DiskCache + // UserAgent is the User-Agent header to send with requests to the // registry. If empty, the User-Agent is determined by HTTPClient. UserAgent string @@ -206,18 +214,28 @@ type Registry struct { // It is only used when a layer is larger than [MaxChunkingThreshold]. MaxChunkSize int64 - // Mask, if set, is the name used to convert non-fully qualified - // names to fully qualified names. If empty, the default mask - // ("registry.ollama.ai/library/_:latest") is used. + // Mask, if set, is the name used to convert non-fully qualified names + // to fully qualified names. If empty, [DefaultMask] is used. Mask string } -func (r *Registry) completeName(name string) names.Name { +func (r *Registry) cache() (*blob.DiskCache, error) { + if r.Cache != nil { + return r.Cache, nil + } + return defaultCache() +} + +func (r *Registry) parseName(name string) (names.Name, error) { mask := defaultMask if r.Mask != "" { mask = names.Parse(r.Mask) } - return names.Merge(names.Parse(name), mask) + n := names.Merge(names.Parse(name), mask) + if !n.IsFullyQualified() { + return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name) + } + return n, nil } // DefaultRegistry returns a new Registry configured from the environment. The @@ -278,12 +296,17 @@ type PushParams struct { } // Push pushes the model with the name in the cache to the remote registry. -func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error { +func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { if p == nil { p = &PushParams{} } - m, err := r.ResolveLocal(c, cmp.Or(p.From, name)) + c, err := r.cache() + if err != nil { + return err + } + + m, err := r.ResolveLocal(cmp.Or(p.From, name)) if err != nil { return err } @@ -306,7 +329,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p * t := traceFromContext(ctx) - scheme, n, _, err := parseName(name, r.Mask) + scheme, n, _, err := r.parseNameExtended(name) if err != nil { // This should never happen since ResolveLocal should have // already validated the name. @@ -399,8 +422,8 @@ func canRetry(err error) bool { // chunks of the specified size, and then reassembled and verified. This is // typically slower than splitting the model up across layers, and is mostly // utilized for layers of type equal to "application/vnd.ollama.image". 
-func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error { - scheme, n, _, err := parseName(name, r.Mask) +func (r *Registry) Pull(ctx context.Context, name string) error { + scheme, n, _, err := r.parseNameExtended(name) if err != nil { return err } @@ -413,6 +436,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err return fmt.Errorf("%w: no layers", ErrManifestInvalid) } + c, err := r.cache() + if err != nil { + return err + } + exists := func(l *Layer) bool { info, err := c.Get(l.Digest) return err == nil && info.Size == l.Size @@ -550,10 +578,14 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified // before attempting to unlink the model. -func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) { - n := r.completeName(name) - if !n.IsFullyQualified() { - return false, fmt.Errorf("%w: %q", ErrNameInvalid, name) +func (r *Registry) Unlink(name string) (ok bool, _ error) { + n, err := r.parseName(name) + if err != nil { + return false, err + } + c, err := r.cache() + if err != nil { + return false, err } return c.Unlink(n.String()) } @@ -626,14 +658,18 @@ type Layer struct { Size int64 `json:"size"` } -// ResolveLocal resolves a name to a Manifest in the local cache. The name is -// parsed using [names.Split] but the scheme is ignored. -func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) { - _, n, d, err := parseName(name, r.Mask) +// ResolveLocal resolves a name to a Manifest in the local cache. +func (r *Registry) ResolveLocal(name string) (*Manifest, error) { + _, n, d, err := r.parseNameExtended(name) + if err != nil { + return nil, err + } + c, err := r.cache() if err != nil { return nil, err } if !d.IsValid() { + // No digest, so resolve the manifest by name. d, err = c.Resolve(n.String()) if err != nil { return nil, err @@ -655,7 +691,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro // Resolve resolves a name to a Manifest in the remote registry. func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) { - scheme, n, d, err := parseName(name, r.Mask) + scheme, n, d, err := r.parseNameExtended(name) if err != nil { return nil, err } @@ -859,7 +895,7 @@ var supportedSchemes = []string{ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", ")) -// parseName parses and validates an extended name, returning the scheme, name, +// parseNameExtended parses and validates an extended name, returning the scheme, name, // and digest. // // If the scheme is empty, scheme will be "https". If an unsupported scheme is @@ -870,8 +906,8 @@ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Jo // // If the name is not, once merged with the mask, fully qualified, // [ErrNameInvalid] wrapped with a display friendly message is returned. 
-func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) { - scheme, name, digest := names.Split(s) +func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) { + scheme, name, digest := splitExtended(s) scheme = cmp.Or(scheme, "https") if !slices.Contains(supportedSchemes, scheme) { err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage) @@ -894,13 +930,33 @@ func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Diges } } - maskName := defaultMask - if mask != "" { - maskName = names.Parse(mask) - } - n := names.Merge(names.Parse(name), maskName) - if !n.IsFullyQualified() { - return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s) + n, err := r.parseName(name) + if err != nil { + return "", names.Name{}, blob.Digest{}, err } return scheme, n, d, nil } + +// splitExtended splits an extended name string into its scheme, name, and digest +// parts. +// +// Examples: +// +// http://ollama.com/bmizerany/smol:latest@digest +// https://ollama.com/bmizerany/smol:latest +// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme. +// model@digest +// @digest +func splitExtended(s string) (scheme, name, digest string) { + i := strings.Index(s, "://") + if i >= 0 { + scheme = s[:i] + s = s[i+3:] + } + i = strings.LastIndex(s, "@") + if i >= 0 { + digest = s[i+1:] + s = s[:i] + } + return scheme, s, digest +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index 20a1f159..b9b4271b 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -2,6 +2,7 @@ package ollama import ( "bytes" + "cmp" "context" "encoding/json" "errors" @@ -72,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error // To simulate a network error, pass a handler that returns a 499 status code. 
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { t.Helper() + c, err := blob.Open(t.TempDir()) if err != nil { t.Fatal(err) @@ -85,13 +87,14 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { } r := &Registry{ + Cache: c, HTTPClient: &http.Client{ Transport: recordRoundTripper(h), }, } link := func(name string, manifest string) { - _, n, _, err := parseName(name, r.Mask) + n, err := r.parseName(name) if err != nil { panic(err) } @@ -151,55 +154,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) { } func TestPushZero(t *testing.T) { - rc, c := newClient(t, okHandler) - err := rc.Push(t.Context(), c, "empty", nil) + rc, _ := newClient(t, okHandler) + err := rc.Push(t.Context(), "empty", nil) if !errors.Is(err, ErrManifestInvalid) { t.Errorf("err = %v; want %v", err, ErrManifestInvalid) } } func TestPushSingle(t *testing.T) { - rc, c := newClient(t, okHandler) - err := rc.Push(t.Context(), c, "single", nil) + rc, _ := newClient(t, okHandler) + err := rc.Push(t.Context(), "single", nil) testutil.Check(t, err) } func TestPushMultiple(t *testing.T) { - rc, c := newClient(t, okHandler) - err := rc.Push(t.Context(), c, "multiple", nil) + rc, _ := newClient(t, okHandler) + err := rc.Push(t.Context(), "multiple", nil) testutil.Check(t, err) } func TestPushNotFound(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { t.Errorf("unexpected request: %v", r) }) - err := rc.Push(t.Context(), c, "notfound", nil) + err := rc.Push(t.Context(), "notfound", nil) if !errors.Is(err, fs.ErrNotExist) { t.Errorf("err = %v; want %v", err, fs.ErrNotExist) } } func TestPushNullLayer(t *testing.T) { - rc, c := newClient(t, nil) - err := rc.Push(t.Context(), c, "null", nil) + rc, _ := newClient(t, nil) + err := rc.Push(t.Context(), "null", nil) if err == nil || !strings.Contains(err.Error(), "invalid manifest") { t.Errorf("err = %v; want invalid manifest", err) } } func TestPushSizeMismatch(t *testing.T) { - rc, c := newClient(t, nil) + rc, _ := newClient(t, nil) ctx, _ := withTraceUnexpected(t.Context()) - got := rc.Push(ctx, c, "sizemismatch", nil) + got := rc.Push(ctx, "sizemismatch", nil) if got == nil || !strings.Contains(got.Error(), "size mismatch") { t.Errorf("err = %v; want size mismatch", got) } } func TestPushInvalid(t *testing.T) { - rc, c := newClient(t, nil) - err := rc.Push(t.Context(), c, "invalid", nil) + rc, _ := newClient(t, nil) + err := rc.Push(t.Context(), "invalid", nil) if err == nil || !strings.Contains(err.Error(), "invalid manifest") { t.Errorf("err = %v; want invalid manifest", err) } @@ -207,7 +210,7 @@ func TestPushInvalid(t *testing.T) { func TestPushExistsAtRemote(t *testing.T) { var pushed bool - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/uploads/") { if !pushed { // First push. Return an uploadURL. 
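Reviewer note: the test churn in this file follows from one API change in registry.go — the `*blob.DiskCache` argument is gone from `Push`/`Pull`/`Unlink`/`ResolveLocal`, and the cache now either comes from the `Registry.Cache` field (as `newClient` sets with `Cache: c`) or is created lazily once via `sync.OnceValues`. Below is a minimal, self-contained sketch of that pattern under stated assumptions: `diskCache` and `openCache` are hypothetical stand-ins for the real `blob` package.

```go
package main

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"sync"
)

// diskCache and openCache are hypothetical stand-ins for blob.DiskCache and
// blob.Open.
type diskCache struct{ dir string }

func openCache(dir string) (*diskCache, error) { return &diskCache{dir: dir}, nil }

// defaultCache mirrors the sync.OnceValues pattern in the diff: the default
// cache location is resolved once and shared by every Registry whose Cache
// field is nil.
var defaultCache = sync.OnceValues(func() (*diskCache, error) {
	dir := os.Getenv("OLLAMA_MODELS")
	if dir == "" {
		home, err := os.UserHomeDir()
		if err != nil {
			home = "." // fall back to the current working directory
		}
		dir = filepath.Join(home, ".ollama", "models")
	}
	return openCache(dir)
})

// Registry is a trimmed stand-in: the cache lives on the struct instead of
// being threaded through every method call.
type Registry struct {
	Cache *diskCache
}

func (r *Registry) cache() (*diskCache, error) {
	if r.Cache != nil {
		return r.Cache, nil
	}
	return defaultCache()
}

// Pull shows the new method shape: (ctx, name), with no cache argument.
func (r *Registry) Pull(ctx context.Context, name string) error {
	c, err := r.cache()
	if err != nil {
		return err
	}
	fmt.Printf("pulling %s into %s\n", name, c.dir)
	return nil
}

func main() {
	ctx := context.Background()

	// Tests pin the cache explicitly, as newClient now does with Cache: c.
	c, _ := openCache(os.TempDir())
	_ = (&Registry{Cache: c}).Pull(ctx, "library/smol:latest")

	// Production callers leave Cache nil and share the lazy default.
	_ = (&Registry{}).Pull(ctx, "library/smol:latest")
}
```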
@@ -235,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) { check := testutil.Checker(t) - err := rc.Push(ctx, c, "single", nil) + err := rc.Push(ctx, "single", nil) check(err) if !errors.Is(errors.Join(errs...), nil) { t.Errorf("errs = %v; want %v", errs, []error{ErrCached}) } - err = rc.Push(ctx, c, "single", nil) + err = rc.Push(ctx, "single", nil) check(err) } func TestPushRemoteError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/blobs/") { w.WriteHeader(500) io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`) return } }) - got := rc.Push(t.Context(), c, "single", nil) + got := rc.Push(t.Context(), "single", nil) checkErrCode(t, got, 500, "blob_error") } func TestPushLocationError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Location", ":///x") w.WriteHeader(http.StatusAccepted) }) - got := rc.Push(t.Context(), c, "single", nil) + got := rc.Push(t.Context(), "single", nil) wantContains := "invalid upload URL" if got == nil || !strings.Contains(got.Error(), wantContains) { t.Errorf("err = %v; want to contain %v", got, wantContains) @@ -271,14 +274,14 @@ func TestPushLocationError(t *testing.T) { } func TestPushUploadRoundtripError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if r.Host == "blob.store" { w.WriteHeader(499) // force RoundTrip error on upload return } w.Header().Set("Location", "http://blob.store/blobs/123") }) - got := rc.Push(t.Context(), c, "single", nil) + got := rc.Push(t.Context(), "single", nil) if !errors.Is(got, errRoundTrip) { t.Errorf("got = %v; want %v", got, errRoundTrip) } @@ -294,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) { os.Remove(c.GetFile(l.Digest)) }, }) - got := rc.Push(ctx, c, "single", nil) + got := rc.Push(ctx, "single", nil) if !errors.Is(got, fs.ErrNotExist) { t.Errorf("got = %v; want fs.ErrNotExist", got) } } func TestPushCommitRoundtripError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/blobs/") { panic("unexpected") } w.WriteHeader(499) // force RoundTrip error }) - err := rc.Push(t.Context(), c, "zero", nil) + err := rc.Push(t.Context(), "zero", nil) if !errors.Is(err, errRoundTrip) { t.Errorf("err = %v; want %v", err, errRoundTrip) } @@ -321,8 +324,8 @@ func checkNotExist(t *testing.T, err error) { } func TestRegistryPullInvalidName(t *testing.T) { - rc, c := newClient(t, nil) - err := rc.Pull(t.Context(), c, "://") + rc, _ := newClient(t, nil) + err := rc.Pull(t.Context(), "://") if !errors.Is(err, ErrNameInvalid) { t.Errorf("err = %v; want %v", err, ErrNameInvalid) } @@ -337,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) { } for _, resp := range cases { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, resp) }) - err := rc.Pull(t.Context(), c, "x") + err := rc.Pull(t.Context(), "x") if !errors.Is(err, ErrManifestInvalid) { t.Errorf("err = %v; want invalid manifest", err) } @@ -363,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) { }) // Confirm that the 
layer does not exist locally - _, err := rc.ResolveLocal(c, "model") + _, err := rc.ResolveLocal("model") checkNotExist(t, err) _, err = c.Get(d) checkNotExist(t, err) - err = rc.Pull(t.Context(), c, "model") + err = rc.Pull(t.Context(), "model") check(err) mw, err := rc.Resolve(t.Context(), "model") check(err) - mg, err := rc.ResolveLocal(c, "model") + mg, err := rc.ResolveLocal("model") check(err) if !reflect.DeepEqual(mw, mg) { t.Errorf("mw = %v; mg = %v", mw, mg) @@ -399,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) { func TestRegistryPullCached(t *testing.T) { cached := blob.DigestFromBytes("exists") - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/blobs/") { w.WriteHeader(499) // should not be called return @@ -422,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() - err := rc.Pull(ctx, c, "single") + err := rc.Pull(ctx, "single") testutil.Check(t, err) want := []int64{6} @@ -435,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) { } func TestRegistryPullManifestNotFound(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) }) - err := rc.Pull(t.Context(), c, "notfound") + err := rc.Pull(t.Context(), "notfound") checkErrCode(t, err, 404, "") } func TestRegistryPullResolveRemoteError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) io.WriteString(w, `{"errors":[{"code":"an_error"}]}`) }) - err := rc.Pull(t.Context(), c, "single") + err := rc.Pull(t.Context(), "single") checkErrCode(t, err, 500, "an_error") } func TestRegistryPullResolveRoundtripError(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/manifests/") { w.WriteHeader(499) // force RoundTrip error return } }) - err := rc.Pull(t.Context(), c, "single") + err := rc.Pull(t.Context(), "single") if !errors.Is(err, errRoundTrip) { t.Errorf("err = %v; want %v", err, errRoundTrip) } @@ -511,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) { // Check that we pull all layers that we can. - err := rc.Pull(ctx, c, "mixed") + err := rc.Pull(ctx, "mixed") if err != nil { t.Fatal(err) } @@ -529,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) { } func TestRegistryPullChunking(t *testing.T) { - rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { + rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range")) if r.URL.Host != "blob.store" { // The production registry redirects to the blob store. @@ -567,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) { }, }) - err := rc.Pull(ctx, c, "remote") + err := rc.Pull(ctx, "remote") testutil.Check(t, err) want := []int64{0, 3, 6} @@ -709,25 +712,16 @@ func TestErrorUnmarshal(t *testing.T) { // // It is only for testing error messages, not that all invalids and valids are // covered. Those are in other tests for names.Name and blob.Digest. 
-func TestParseNameErrors(t *testing.T) { +func TestParseNameExtendedErrors(t *testing.T) { cases := []struct { name string err error want string - }{ - {"x", nil, ""}, - {"x@", nil, ""}, - - {"", ErrNameInvalid, `invalid or missing name: ""`}, - {"://", ErrNameInvalid, `invalid or missing name: "://"`}, - {"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`}, - - {"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`}, - {"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`}, - } + }{} + var r Registry for _, tt := range cases { - _, _, _, err := parseName(tt.name, DefaultMask) + _, _, _, err := r.parseNameExtended(tt.name) if !errors.Is(err, tt.err) { t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err) } @@ -736,3 +730,89 @@ func TestParseNameErrors(t *testing.T) { } } } + +func TestParseNameExtended(t *testing.T) { + cases := []struct { + in string + scheme string + name string + digest string + err string + }{ + {in: "http://m", scheme: "http", name: "m"}, + {in: "https+insecure://m", scheme: "https+insecure", name: "m"}, + {in: "http+insecure://m", err: "unsupported scheme"}, + + {in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"}, + + {in: "", err: "invalid or missing name"}, + {in: "m", scheme: "https", name: "m"}, + {in: "://", err: "invalid or missing name"}, + {in: "@sha256:deadbeef", err: "invalid digest"}, + {in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"}, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + var r Registry + scheme, n, digest, err := r.parseNameExtended(tt.in) + if err != nil { + if tt.err == "" { + t.Errorf("err = %v; want nil", err) + } else if !strings.Contains(err.Error(), tt.err) { + t.Errorf("err = %v; want %q", err, tt.err) + } + } else if tt.err != "" { + t.Errorf("err = nil; want %q", tt.err) + } + if err == nil && !n.IsFullyQualified() { + t.Errorf("name = %q; want fully qualified", n) + } + + if scheme != tt.scheme { + t.Errorf("scheme = %q; want %q", scheme, tt.scheme) + } + + // smoke-test name is superset of tt.name + if !strings.Contains(n.String(), tt.name) { + t.Errorf("name = %q; want %q", n, tt.name) + } + + tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String()) + if digest.String() != tt.digest { + t.Errorf("digest = %q; want %q", digest, tt.digest) + } + }) + } +} + +func TestUnlink(t *testing.T) { + t.Run("found by name", func(t *testing.T) { + rc, _ := newClient(t, nil) + + // confirm linked + _, err := rc.ResolveLocal("single") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + // unlink + _, err = rc.Unlink("single") + testutil.Check(t, err) + + // confirm unlinked + _, err = rc.ResolveLocal("single") + if !errors.Is(err, fs.ErrNotExist) { + t.Errorf("err = %v; want fs.ErrNotExist", err) + } + }) + t.Run("not found by name", func(t *testing.T) { + rc, _ := newClient(t, nil) + ok, err := rc.Unlink("manifestNotFound") + if err != nil { + t.Fatal(err) + } + if ok { + t.Error("expected not found") + } + }) +} diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go index 8e53040a..e300870b 100644 --- a/server/internal/client/ollama/trace.go +++ b/server/internal/client/ollama/trace.go @@ -6,6 +6,9 @@ import ( // Trace is a set of functions that are called to report progress during blob // downloads and uploads. 
+// +// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push] +// and [Registry.Pull]. type Trace struct { // Update is called during [Registry.Push] and [Registry.Pull] to // report the progress of blob uploads and downloads. diff --git a/server/internal/cmd/opp/opp.go b/server/internal/cmd/opp/opp.go index c21e71d5..6976927c 100644 --- a/server/internal/cmd/opp/opp.go +++ b/server/internal/cmd/opp/opp.go @@ -63,25 +63,28 @@ func main() { } flag.Parse() - c, err := ollama.DefaultCache() - if err != nil { - log.Fatal(err) - } - - rc, err := ollama.DefaultRegistry() - if err != nil { - log.Fatal(err) - } - ctx := context.Background() - err = func() error { + err := func() error { switch cmd := flag.Arg(0); cmd { case "pull": - return cmdPull(ctx, rc, c) + rc, err := ollama.DefaultRegistry() + if err != nil { + log.Fatal(err) + } + + return cmdPull(ctx, rc) case "push": - return cmdPush(ctx, rc, c) + rc, err := ollama.DefaultRegistry() + if err != nil { + log.Fatal(err) + } + return cmdPush(ctx, rc) case "import": + c, err := ollama.DefaultCache() + if err != nil { + log.Fatal(err) + } return cmdImport(ctx, c) default: if cmd == "" { @@ -99,7 +102,7 @@ func main() { } } -func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { +func cmdPull(ctx context.Context, rc *ollama.Registry) error { model := flag.Arg(1) if model == "" { flag.Usage() @@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error errc := make(chan error) go func() { - errc <- rc.Pull(ctx, c, model) + errc <- rc.Pull(ctx, model) }() t := time.NewTicker(time.Second) @@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error } } -func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { +func cmdPush(ctx context.Context, rc *ollama.Registry) error { args := flag.Args()[1:] flag := flag.NewFlagSet("push", flag.ExitOnError) flagFrom := flag.String("from", "", "Use the manifest from a model by another name.") @@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error } from := cmp.Or(*flagFrom, model) - m, err := rc.ResolveLocal(c, from) + m, err := rc.ResolveLocal(from) if err != nil { return err } @@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error }, }) - return rc.Push(ctx, c, model, &ollama.PushParams{ + return rc.Push(ctx, model, &ollama.PushParams{ From: from, }) } diff --git a/server/internal/internal/backoff/backoff_test.go b/server/internal/internal/backoff/backoff_test.go index bb8438a7..11ace22a 100644 --- a/server/internal/internal/backoff/backoff_test.go +++ b/server/internal/internal/backoff/backoff_test.go @@ -1,3 +1,5 @@ +//go:build goexperiment.synctest + package backoff import ( diff --git a/server/internal/internal/syncs/line_test.go b/server/internal/internal/syncs/line_test.go index d5216026..94114a56 100644 --- a/server/internal/internal/syncs/line_test.go +++ b/server/internal/internal/syncs/line_test.go @@ -1,3 +1,5 @@ +//go:build goexperiment.synctest + package syncs import ( diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go index 6ea590a7..4d44aa8d 100644 --- a/server/internal/registry/server.go +++ b/server/internal/registry/server.go @@ -11,7 +11,6 @@ import ( "log/slog" "net/http" - "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" ) @@ -27,12 +26,15 @@ import ( // 
directly to the blob disk cache. type Local struct { Client *ollama.Registry // required - Cache *blob.DiskCache // required Logger *slog.Logger // required // Fallback, if set, is used to handle requests that are not handled by // this handler. Fallback http.Handler + + // Prune, if set, is called to prune the local disk cache after a model + // is deleted. + Prune func() error // optional } // serverError is like ollama.Error, but with a Status field for the HTTP @@ -199,14 +201,17 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { if err != nil { return err } - ok, err := s.Client.Unlink(s.Cache, p.model()) + ok, err := s.Client.Unlink(p.model()) if err != nil { return err } if !ok { return &serverError{404, "not_found", "model not found"} } - return nil + if s.Prune == nil { + return nil + } + return s.Prune() } func decodeUserJSON[T any](r io.Reader) (T, error) { diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go index 7ba13d50..e44d88c0 100644 --- a/server/internal/registry/server_test.go +++ b/server/internal/registry/server_test.go @@ -42,10 +42,10 @@ func newTestServer(t *testing.T) *Local { t.Fatal(err) } rc := &ollama.Registry{ + Cache: c, HTTPClient: panicOnRoundTrip, } l := &Local{ - Cache: c, Client: rc, Logger: testutil.Slogger(t), } @@ -87,7 +87,7 @@ func TestServerDelete(t *testing.T) { s := newTestServer(t) - _, err := s.Client.ResolveLocal(s.Cache, "smol") + _, err := s.Client.ResolveLocal("smol") check(err) got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`) @@ -95,7 +95,7 @@ func TestServerDelete(t *testing.T) { t.Fatalf("Code = %d; want 200", got.Code) } - _, err = s.Client.ResolveLocal(s.Cache, "smol") + _, err = s.Client.ResolveLocal("smol") if err == nil { t.Fatal("expected smol to have been deleted") } diff --git a/server/prompt.go b/server/prompt.go index 233dffd6..5b5b958f 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/model/models/mllama" "github.com/ollama/ollama/template" @@ -93,7 +92,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. 
var imgData llm.ImageData if isMllama { - if envconfig.NewEngine() { + if len(m.ProjectorPaths) == 0 { imgData = llm.ImageData{ ID: len(images), Data: i, diff --git a/server/routes.go b/server/routes.go index ff42000f..73e94dc6 100644 --- a/server/routes.go +++ b/server/routes.go @@ -34,7 +34,6 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/model/models/mllama" "github.com/ollama/ollama/openai" - "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" @@ -206,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { images := make([]llm.ImageData, len(req.Images)) for i := range req.Images { - if isMllama && !envconfig.NewEngine() { + if isMllama && len(model.ProjectorPaths) > 0 { data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i])) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"}) @@ -1129,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc { } } -func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) { +func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { corsConfig := cors.DefaultConfig() corsConfig.AllowWildcard = true corsConfig.AllowBrowserExtensions = true @@ -1197,10 +1196,11 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha // wrap old with new rs := ®istry.Local{ - Cache: c, Client: rc, Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() Fallback: r, + + Prune: PruneLayers, } return rs, nil @@ -1258,16 +1258,12 @@ func Serve(ln net.Listener) error { s := &Server{addr: ln.Addr()} - c, err := ollama.DefaultCache() - if err != nil { - return err - } rc, err := ollama.DefaultRegistry() if err != nil { return err } - h, err := s.GenerateRoutes(c, rc) + h, err := s.GenerateRoutes(rc) if err != nil { return err } diff --git a/server/routes_test.go b/server/routes_test.go index 0dd782f4..e13c4b59 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -23,7 +23,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/openai" - "github.com/ollama/ollama/server/internal/cache/blob" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) { modelsDir := t.TempDir() t.Setenv("OLLAMA_MODELS", modelsDir) - c, err := blob.Open(modelsDir) - if err != nil { - t.Fatalf("failed to open models dir: %v", err) - } - rc := &ollama.Registry{ // This is a temporary measure to allow us to move forward, // surfacing any code contacting ollama.com we do not intended @@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) { } s := &Server{} - router, err := s.GenerateRoutes(c, rc) + router, err := s.GenerateRoutes(rc) if err != nil { t.Fatalf("failed to generate routes: %v", err) }
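The trace.go comment added earlier in this diff points callers at WithTrace for progress reporting during pushes and pulls. A hedged sketch of that pattern, assuming the Update callback shape implied by the tests (a *Layer with a Digest, the byte count n, and an error) and that DefaultRegistry, as used by the opp command above, comes back with its Cache already configured; the model name is a placeholder:

```go
package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/server/internal/client/ollama"
)

func main() {
	rc, err := ollama.DefaultRegistry() // assumed to own its cache after this change
	if err != nil {
		log.Fatal(err)
	}

	// Attach a Trace to the context so Pull reports per-layer progress.
	ctx := ollama.WithTrace(context.Background(), &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			if err != nil {
				log.Printf("%s: %v", l.Digest, err)
				return
			}
			log.Printf("%s: %d bytes transferred", l.Digest, n)
		},
	})

	if err := rc.Pull(ctx, "ollama.com/bmizerany/smol:latest"); err != nil {
		log.Fatal(err)
	}
}
```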