Mirror of https://github.com/likelovewant/ollama-for-amd.git (synced 2025-12-23 23:18:26 +00:00)
Merge branch 'ollama:main' into main
@@ -12,7 +12,7 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
@@ -86,10 +86,11 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel 8

FROM base AS build
ARG GOVERSION=1.23.4
RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
WORKDIR /go/src/github.com/ollama/ollama
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
README.md (10 changed lines)
@@ -1,5 +1,5 @@
<div align="center">
<a href="https://ollama.com" />
<a href="https://ollama.com">
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</a>
</div>
@@ -86,7 +86,7 @@ Here are some example models that can be downloaded:
| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` |
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
| Phi 4 | 14B | 9.1GB | `ollama run phi4` |
| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` |
| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` |
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
@@ -97,7 +97,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Solar | 10.7B | 6.1GB | `ollama run solar` |
| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |

> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -409,6 +409,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)

### Cloud
@@ -533,6 +535,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)

### Extensions & Plugins
cmd/cmd.go (17 changed lines)
@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/runner"
@@ -256,6 +255,7 @@ func StopHandler(cmd *cobra.Command, args []string) error {
if strings.Contains(err.Error(), "not found") {
return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
}
return err
}
return nil
}
@@ -338,10 +338,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}

// TODO(jessegross): We should either find another way to know if this is
// a vision model or remove the logic. Also consider that other modalities will
// need different behavior anyways.
opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
if len(info.ProjectorInfo) != 0 {
opts.MultiModal = true
}
for k := range info.ModelInfo {
if strings.Contains(k, ".vision.") {
opts.MultiModal = true
break
}
}

opts.ParentModel = info.Details.ParentModel

if interactive {
@@ -1274,7 +1280,6 @@ func NewCLI() *cobra.Command {

runnerCmd := &cobra.Command{
Use: "runner",
Short: llama.PrintSystemInfo(),
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
return runner.Execute(os.Args[1:])
@@ -118,6 +118,35 @@ To run tests, use `go test`:
go test ./...
```

> NOTE: In rare circumstances, you may need to change a package using the new
> "synctest" package in go1.24.
>
> If you do not have the "synctest" package enabled, you will not see build or
> test failures resulting from your change(s), if any, locally, but CI will
> break.
>
> If you see failures in CI, you can either keep pushing changes to see if the
> CI build passes, or you can enable the "synctest" package locally to see the
> failures before pushing.
>
> To enable the "synctest" package for testing, run the following command:
>
> ```shell
> GOEXPERIMENT=synctest go test ./...
> ```
>
> If you wish to enable synctest for all go commands, you can set the
> `GOEXPERIMENT` environment variable in your shell profile or by using:
>
> ```shell
> go env -w GOEXPERIMENT=synctest
> ```
>
> This enables the "synctest" package for all go commands without needing
> to set it in every shell session.
>
> The synctest package is not required for production builds.
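
For a concrete sense of what synctest provides, here is a minimal sketch of a test that depends on it. The file, package name, and test are illustrative and not part of the repository; the example only assumes Go 1.24 with `GOEXPERIMENT=synctest` enabled.

```go
//go:build goexperiment.synctest

package example

import (
	"testing"
	"testing/synctest"
	"time"
)

// TestFakeClock demonstrates the synctest "bubble": inside synctest.Run,
// time.Sleep advances a fake clock instead of blocking, so an "hour"
// elapses immediately and the test stays deterministic.
func TestFakeClock(t *testing.T) {
	synctest.Run(func() {
		start := time.Now()
		time.Sleep(time.Hour)
		if elapsed := time.Since(start); elapsed != time.Hour {
			t.Errorf("expected exactly one hour to elapse, got %v", elapsed)
		}
	})
}
```

Because the file is guarded by the `goexperiment.synctest` build tag, it is simply skipped when the experiment is not enabled.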
## Library detection

Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
@@ -565,6 +565,43 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
return
}

func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
switch llm.KV().Architecture() {
case "mllama":
for _, layer := range llm.Tensors().GroupLayers()["v"] {
weights += layer.Size()
}

kv := func(n string) uint64 {
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
return uint64(v)
}

return 0
}

imageSize := kv("image_size")

maxNumTiles := kv("max_num_tiles")
embeddingLength := kv("embedding_length")
headCount := kv("attention.head_count")

numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}

numPaddedPatches := numPatches + 8 - (numPatches%8)%8

graphSize = 4 * (8 +
imageSize*imageSize*kv("num_channels")*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
}
return weights, graphSize
}
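
To make the arithmetic above concrete, here is a small, self-contained sketch that evaluates the same expression with hand-picked hyperparameters. The numbers are assumptions chosen for illustration, not values read from a real mllama GGUF.

```go
// Illustrative only: the hyperparameters below are assumed for the sake of
// the arithmetic, not read from a real model file.
package main

import "fmt"

func main() {
	var (
		imageSize       uint64 = 560
		patchSize       uint64 = 14
		maxNumTiles     uint64 = 4
		embeddingLength uint64 = 1280
		headCount       uint64 = 16
		numChannels     uint64 = 3
	)

	numPatches := (imageSize / patchSize) * (imageSize / patchSize) // 40*40 = 1600
	numPatches++                                                    // class embedding present

	// Mirrors the padding expression used in VisionGraphSize.
	numPaddedPatches := numPatches + 8 - (numPatches%8)%8

	graphSize := 4 * (8 +
		imageSize*imageSize*numChannels*maxNumTiles +
		embeddingLength*numPatches*maxNumTiles +
		9*embeddingLength*numPaddedPatches*maxNumTiles +
		numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)

	fmt.Printf("estimated vision graph size: %.1f MiB\n", float64(graphSize)/(1<<20))
}
```

The dominant term is the attention map over padded patches, (numPaddedPatches*maxNumTiles)^2 * headCount, which is why tile count and head count have an outsized effect on the estimate.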

// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
@@ -21,18 +21,6 @@ package llama

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);

typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
COMPILER inline get_compiler() {
#if defined(__clang__)
return COMP_CLANG;
#elif defined(__GNUC__)
return COMP_GCC;
#else
return UNKNOWN_COMPILER;
#endif
}

*/
import "C"
@@ -72,19 +60,6 @@ func BackendInit() {
C.llama_backend_init()
}

func PrintSystemInfo() string {
var compiler string
switch C.get_compiler() {
case C.COMP_UNKNOWN:
compiler = "cgo(unknown_compiler)"
case C.COMP_GCC:
compiler = "cgo(gcc)"
case C.COMP_CLANG:
compiler = "cgo(clang)"
}
return C.GoString(C.llama_print_system_info()) + compiler
}

func GetModelArch(modelPath string) (string, error) {
mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp))
@@ -1,69 +0,0 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Yang <mxyng@pm.me>
|
||||
Date: Tue, 11 Feb 2025 14:06:36 -0800
|
||||
Subject: [PATCH] try/catch backend load
|
||||
|
||||
---
|
||||
ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++-----------------
|
||||
1 file changed, 23 insertions(+), 22 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
||||
index 98d5e14d..1c19129a 100644
|
||||
--- a/ggml/src/ggml-backend-reg.cpp
|
||||
+++ b/ggml/src/ggml-backend-reg.cpp
|
||||
@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
}
|
||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||
for (const auto & entry : dir_it) {
|
||||
- if (entry.is_regular_file()) {
|
||||
- std::wstring filename = entry.path().filename().wstring();
|
||||
- std::wstring ext = entry.path().extension().wstring();
|
||||
- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
- if (!handle && !silent) {
|
||||
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
- }
|
||||
- if (handle) {
|
||||
+ try {
|
||||
+ if (entry.is_regular_file()) {
|
||||
+ std::wstring filename = entry.path().filename().wstring();
|
||||
+ std::wstring ext = entry.path().extension().wstring();
|
||||
+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
+ if (!handle) {
|
||||
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
+ continue;
|
||||
+ }
|
||||
+
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
- if (score_fn) {
|
||||
- int s = score_fn();
|
||||
-#ifndef NDEBUG
|
||||
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
-#endif
|
||||
- if (s > best_score) {
|
||||
- best_score = s;
|
||||
- best_path = entry.path().wstring();
|
||||
- }
|
||||
- } else {
|
||||
- if (!silent) {
|
||||
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
- }
|
||||
+ if (!score_fn) {
|
||||
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
+ continue;
|
||||
+ }
|
||||
+
|
||||
+ int s = score_fn();
|
||||
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
+ if (s > best_score) {
|
||||
+ best_score = s;
|
||||
+ best_path = entry.path().wstring();
|
||||
}
|
||||
}
|
||||
}
|
||||
+ } catch (const std::exception & e) {
|
||||
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,11 +4,11 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500
|
||||
Subject: [PATCH] use std::filesystem::path instead of wstring
|
||||
|
||||
---
|
||||
ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++--------------------
|
||||
1 file changed, 58 insertions(+), 86 deletions(-)
|
||||
ggml/src/ggml-backend-reg.cpp | 199 +++++++++++++++-------------------
|
||||
1 file changed, 88 insertions(+), 111 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
||||
index 1c19129a..c854e6bb 100644
|
||||
index 98d5e14d..799af5f3 100644
|
||||
--- a/ggml/src/ggml-backend-reg.cpp
|
||||
+++ b/ggml/src/ggml-backend-reg.cpp
|
||||
@@ -66,26 +66,6 @@
|
||||
@@ -264,47 +264,55 @@ index 1c19129a..c854e6bb 100644
|
||||
for (const auto & search_path : search_paths) {
|
||||
if (!fs::exists(search_path)) {
|
||||
continue;
|
||||
@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
@@ -513,29 +485,26 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||
for (const auto & entry : dir_it) {
|
||||
try {
|
||||
if (entry.is_regular_file()) {
|
||||
- std::wstring filename = entry.path().filename().wstring();
|
||||
- std::wstring ext = entry.path().extension().wstring();
|
||||
+ std::string filename = entry.path().filename().string();
|
||||
+ std::string ext = entry.path().extension().string();
|
||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
|
||||
if (!handle) {
|
||||
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
if (!score_fn) {
|
||||
- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
int s = score_fn();
|
||||
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
|
||||
if (s > best_score) {
|
||||
best_score = s;
|
||||
- best_path = entry.path().wstring();
|
||||
+ best_path = entry.path();
|
||||
}
|
||||
if (entry.is_regular_file()) {
|
||||
- std::wstring filename = entry.path().filename().wstring();
|
||||
- std::wstring ext = entry.path().extension().wstring();
|
||||
+ std::string filename = entry.path().filename().string();
|
||||
+ std::string ext = entry.path().extension().string();
|
||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
- if (!handle && !silent) {
|
||||
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
|
||||
+ if (!handle) {
|
||||
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
+ continue;
|
||||
}
|
||||
- if (handle) {
|
||||
- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
- if (score_fn) {
|
||||
- int s = score_fn();
|
||||
-#ifndef NDEBUG
|
||||
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
-#endif
|
||||
- if (s > best_score) {
|
||||
- best_score = s;
|
||||
- best_path = entry.path().wstring();
|
||||
- }
|
||||
- } else {
|
||||
- if (!silent) {
|
||||
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
- }
|
||||
- }
|
||||
+
|
||||
+ auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
+ if (!score_fn) {
|
||||
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
+ continue;
|
||||
+ }
|
||||
+
|
||||
+ int s = score_fn();
|
||||
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
|
||||
+ if (s > best_score) {
|
||||
+ best_score = s;
|
||||
+ best_path = entry.path();
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
|
||||
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
@@ -545,7 +514,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
if (best_score == 0) {
|
||||
// try to load the base backend
|
||||
for (const auto & search_path : search_paths) {
|
||||
@@ -313,3 +321,49 @@ index 1c19129a..c854e6bb 100644
|
||||
if (fs::exists(path)) {
|
||||
return get_reg().load_backend(path, silent);
|
||||
}
|
||||
@@ -560,6 +529,14 @@ void ggml_backend_load_all() {
|
||||
ggml_backend_load_all_from_path(nullptr);
|
||||
}
|
||||
|
||||
+static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
|
||||
+ try {
|
||||
+ ggml_backend_load_best(name, silent, user_search_path);
|
||||
+ } catch (const std::exception & e) {
|
||||
+ GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||
#ifdef NDEBUG
|
||||
bool silent = true;
|
||||
@@ -567,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||
bool silent = false;
|
||||
#endif
|
||||
|
||||
- ggml_backend_load_best("blas", silent, dir_path);
|
||||
- ggml_backend_load_best("cann", silent, dir_path);
|
||||
- ggml_backend_load_best("cuda", silent, dir_path);
|
||||
- ggml_backend_load_best("hip", silent, dir_path);
|
||||
- ggml_backend_load_best("kompute", silent, dir_path);
|
||||
- ggml_backend_load_best("metal", silent, dir_path);
|
||||
- ggml_backend_load_best("rpc", silent, dir_path);
|
||||
- ggml_backend_load_best("sycl", silent, dir_path);
|
||||
- ggml_backend_load_best("vulkan", silent, dir_path);
|
||||
- ggml_backend_load_best("opencl", silent, dir_path);
|
||||
- ggml_backend_load_best("musa", silent, dir_path);
|
||||
- ggml_backend_load_best("cpu", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("blas", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("cann", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("cuda", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("hip", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("kompute", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("metal", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("rpc", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("sycl", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("vulkan", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("opencl", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("musa", silent, dir_path);
|
||||
+ ggml_backend_try_load_best("cpu", silent, dir_path);
|
||||
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
||||
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
|
||||
if (backend_path) {
|
||||
@@ -115,6 +115,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
// multimodal models require at least 2048 context
|
||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||
}
|
||||
if projectorWeights == 0 && projectorGraph == 0 {
|
||||
projectorWeights, projectorGraph = f.VisionGraphSize()
|
||||
}
|
||||
|
||||
layers := f.Tensors().GroupLayers()
|
||||
// add one layer worth of memory as a buffer
|
||||
|
||||
llm/server.go (278 changed lines)
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type LlamaServer interface {
|
||||
@@ -54,8 +55,15 @@ type llmServer struct {
|
||||
options api.Options
|
||||
numParallel int
|
||||
modelPath string
|
||||
modelLock sync.Mutex // Temporary until we switch fully to Go server
|
||||
model *llama.Model // If non-nil, the runner is a new Go server
|
||||
|
||||
// llamaModel is an instance of the cgo llama.cpp model definition
|
||||
// nil if this server is running the new engine
|
||||
llamaModel *llama.Model
|
||||
llamaModelLock sync.Mutex
|
||||
|
||||
// textProcessor handles text encoding/decoding for the model in the Ollama engine
|
||||
// nil if this server is running the llama.cpp based engine
|
||||
textProcessor model.TextProcessor
|
||||
|
||||
estimate MemoryEstimate
|
||||
totalLayers uint64
|
||||
@@ -89,7 +97,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
||||
|
||||
// NewLlamaServer will run a server for the given GPUs
|
||||
// The gpu list must be a single family.
|
||||
func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
systemInfo := discover.GetSystemInfo()
|
||||
systemTotalMemory := systemInfo.System.TotalMemory
|
||||
systemFreeMemory := systemInfo.System.FreeMemory
|
||||
@@ -130,7 +138,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
slog.Info("offload", "", estimate)
|
||||
|
||||
params := []string{
|
||||
"--model", model,
|
||||
"--model", modelPath,
|
||||
"--ctx-size", strconv.Itoa(opts.NumCtx),
|
||||
"--batch-size", strconv.Itoa(opts.NumBatch),
|
||||
}
|
||||
@@ -153,11 +161,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
}
|
||||
}
|
||||
|
||||
if len(projectors) > 0 {
|
||||
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
|
||||
params = append(params, "--mmproj", projectors[0])
|
||||
}
|
||||
|
||||
defaultThreads := systemInfo.GetOptimalThreadCount()
|
||||
if opts.NumThread > 0 {
|
||||
params = append(params, "--threads", strconv.Itoa(opts.NumThread))
|
||||
@@ -257,6 +260,34 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
}
|
||||
}
|
||||
slog.Debug("compatible gpu libraries", "compatible", compatible)
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
}
|
||||
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
var llamaModel *llama.Model
|
||||
var textProcessor model.TextProcessor
|
||||
if envconfig.NewEngine() {
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
if err != nil {
|
||||
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
|
||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||
}
|
||||
}
|
||||
if textProcessor == nil {
|
||||
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(projectors) > 0 && llamaModel != nil {
|
||||
params = append(params, "--mmproj", projectors[0])
|
||||
}
|
||||
|
||||
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
|
||||
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
|
||||
@@ -275,7 +306,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
|
||||
}
|
||||
finalParams := []string{"runner"}
|
||||
if envconfig.NewEngine() {
|
||||
if textProcessor != nil {
|
||||
// New engine
|
||||
// TODO - if we have failure to load scenarios, add logic to retry with the old runner
|
||||
finalParams = append(finalParams, "--ollama-engine")
|
||||
}
|
||||
finalParams = append(finalParams, params...)
|
||||
@@ -315,28 +348,20 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
// finally, add the root library path
|
||||
libraryPaths = append(libraryPaths, discover.LibOllamaPath)
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
}
|
||||
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
// TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access
|
||||
s := &llmServer{
|
||||
port: port,
|
||||
cmd: exec.Command(exe, finalParams...),
|
||||
status: NewStatusWriter(os.Stderr),
|
||||
options: opts,
|
||||
modelPath: model,
|
||||
estimate: estimate,
|
||||
numParallel: numParallel,
|
||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||
totalLayers: f.KV().BlockCount() + 1,
|
||||
gpus: gpus,
|
||||
done: make(chan error, 1),
|
||||
port: port,
|
||||
cmd: exec.Command(exe, finalParams...),
|
||||
status: NewStatusWriter(os.Stderr),
|
||||
options: opts,
|
||||
modelPath: modelPath,
|
||||
llamaModel: llamaModel,
|
||||
textProcessor: textProcessor,
|
||||
estimate: estimate,
|
||||
numParallel: numParallel,
|
||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||
totalLayers: f.KV().BlockCount() + 1,
|
||||
gpus: gpus,
|
||||
done: make(chan error, 1),
|
||||
}
|
||||
|
||||
s.cmd.Env = os.Environ()
|
||||
@@ -405,6 +430,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
}
|
||||
err := fmt.Errorf("error starting runner: %v %s", err, msg)
|
||||
if len(compatible) == 0 {
|
||||
if llamaModel != nil {
|
||||
llama.FreeModel(llamaModel)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -701,24 +729,29 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
|
||||
if len(req.Format) > 0 {
|
||||
switch string(req.Format) {
|
||||
case `null`, `""`:
|
||||
// Field was set, but "missing" a value. We accept
|
||||
// these as "not set".
|
||||
break
|
||||
case `"json"`:
|
||||
request["grammar"] = grammarJSON
|
||||
default:
|
||||
if req.Format[0] != '{' {
|
||||
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
||||
}
|
||||
format := string(req.Format)
|
||||
if format != `null` && format != `""` {
|
||||
if s.textProcessor != nil {
|
||||
// New engine handles this on the backend
|
||||
request["format"] = req.Format
|
||||
} else {
|
||||
// old engine
|
||||
switch format {
|
||||
case `"json"`:
|
||||
request["grammar"] = grammarJSON
|
||||
default:
|
||||
if req.Format[0] != '{' {
|
||||
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
||||
}
|
||||
|
||||
// User provided a JSON schema
|
||||
g := llama.SchemaToGrammar(req.Format)
|
||||
if g == nil {
|
||||
return fmt.Errorf("invalid JSON schema in format")
|
||||
// User provided a JSON schema
|
||||
g := llama.SchemaToGrammar(req.Format)
|
||||
if g == nil {
|
||||
return fmt.Errorf("invalid JSON schema in format")
|
||||
}
|
||||
request["grammar"] = string(g)
|
||||
}
|
||||
}
|
||||
request["grammar"] = string(g)
|
||||
}
|
||||
}
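
The practical effect of this block: when the Ollama engine (textProcessor) is in use, the raw `format` value is forwarded for the backend to enforce, while the llama.cpp-based engine still needs a GBNF grammar. Below is a minimal sketch of the conversion step the old path performs, using the same `llama.SchemaToGrammar` call shown above; the schema literal is just an example, and the snippet assumes it is built inside the ollama repository.

```go
// Convert a JSON-Schema "format" value into the grammar handed to the
// llama.cpp-based engine. Sketch only; the schema is illustrative.
package main

import (
	"fmt"

	"github.com/ollama/ollama/llama"
)

func main() {
	schema := []byte(`{"type":"object","properties":{"answer":{"type":"string"}},"required":["answer"]}`)
	g := llama.SchemaToGrammar(schema)
	if g == nil {
		fmt.Println("invalid JSON schema in format")
		return
	}
	fmt.Println(string(g)) // the GBNF grammar passed as request["grammar"]
}
```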
|
||||
|
||||
@@ -933,64 +966,25 @@ type TokenizeResponse struct {
|
||||
}
|
||||
|
||||
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
s.modelLock.Lock()
|
||||
defer s.modelLock.Unlock()
|
||||
if s.model != nil {
|
||||
return s.model.Tokenize(content, false, true)
|
||||
}
|
||||
s.llamaModelLock.Lock()
|
||||
defer s.llamaModelLock.Unlock()
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
if s.llamaModel != nil {
|
||||
return s.llamaModel.Tokenize(content, false, true)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(TokenizeRequest{Content: content})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling encode data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do encode request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
if s.model == nil {
|
||||
slog.Debug("new runner detected, loading model for cgo tokenization")
|
||||
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.model = m
|
||||
if s.textProcessor != nil {
|
||||
tokens, err := s.textProcessor.Encode(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.model.Tokenize(content, false, true)
|
||||
toks := make([]int, len(tokens))
|
||||
for i, t := range tokens {
|
||||
toks[i] = int(t)
|
||||
}
|
||||
return toks, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read encode request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm encode error: %s", body)
|
||||
return nil, fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var encoded TokenizeResponse
|
||||
if err := json.Unmarshal(body, &encoded); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
return encoded.Tokens, nil
|
||||
// not reached
|
||||
return nil, fmt.Errorf("no tokenizer configured")
|
||||
}
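
With this change, tokenization no longer round-trips over HTTP to the runner: requests are answered in-process, either by the vocab-only llama.cpp model or by the Ollama-engine text processor. A hedged sketch of how a caller might use the `LlamaServer` interface to round-trip a prompt; server construction is omitted, and only the method signatures defined in this file are assumed.

```go
package example

import (
	"context"

	"github.com/ollama/ollama/llm"
)

// roundTrip tokenizes a prompt and immediately detokenizes it again.
// srv is assumed to be an llm.LlamaServer returned by llm.NewLlamaServer.
func roundTrip(ctx context.Context, srv llm.LlamaServer, prompt string) (string, error) {
	tokens, err := srv.Tokenize(ctx, prompt) // served in-process, no HTTP call
	if err != nil {
		return "", err
	}
	return srv.Detokenize(ctx, tokens)
}
```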
|
||||
|
||||
type DetokenizeRequest struct {
|
||||
@@ -1002,80 +996,38 @@ type DetokenizeResponse struct {
|
||||
}
|
||||
|
||||
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
s.modelLock.Lock()
|
||||
defer s.modelLock.Unlock()
|
||||
if s.model != nil {
|
||||
s.llamaModelLock.Lock()
|
||||
defer s.llamaModelLock.Unlock()
|
||||
|
||||
if s.llamaModel != nil {
|
||||
var resp string
|
||||
for _, token := range tokens {
|
||||
resp += s.model.TokenToPiece(token)
|
||||
resp += s.llamaModel.TokenToPiece(token)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling decode data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("do decode request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
if s.model == nil {
|
||||
slog.Debug("new runner detected, loading model for cgo tokenization")
|
||||
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
s.model = m
|
||||
if s.textProcessor != nil {
|
||||
toks := make([]int32, len(tokens))
|
||||
for i, t := range tokens {
|
||||
toks[i] = int32(t)
|
||||
}
|
||||
var resp string
|
||||
for _, token := range tokens {
|
||||
resp += s.model.TokenToPiece(token)
|
||||
content, err := s.textProcessor.Decode(toks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp, nil
|
||||
return content, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read decode request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm decode error: %s", body)
|
||||
return "", fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var decoded DetokenizeResponse
|
||||
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||
return "", fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
return decoded.Content, nil
|
||||
// not reached
|
||||
return "", fmt.Errorf("no tokenizer configured")
|
||||
}
|
||||
|
||||
func (s *llmServer) Close() error {
|
||||
s.modelLock.Lock()
|
||||
if s.model != nil {
|
||||
llama.FreeModel(s.model)
|
||||
s.model = nil
|
||||
s.llamaModelLock.Lock()
|
||||
if s.llamaModel != nil {
|
||||
llama.FreeModel(s.llamaModel)
|
||||
s.llamaModel = nil
|
||||
}
|
||||
s.modelLock.Unlock()
|
||||
s.llamaModelLock.Unlock()
|
||||
|
||||
if s.cmd != nil {
|
||||
slog.Debug("stopping llama server")
|
||||
|
||||
@@ -24,7 +24,6 @@ type Backend interface {
|
||||
Config() Config
|
||||
Get(name string) Tensor
|
||||
NewContext() Context
|
||||
SystemInfo() string
|
||||
}
|
||||
|
||||
// BackendCacheConfig should be implemented by backends that need special output
|
||||
|
||||
@@ -1,27 +1,11 @@
|
||||
package ggml
|
||||
|
||||
/*
|
||||
#cgo CPPFLAGS: -I${SRCDIR}/ggml/include
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-backend.h"
|
||||
static struct ggml_backend_feature * getBackendFeatures(void *fp, ggml_backend_reg_t reg) {return ((ggml_backend_get_features_t)(fp))(reg);}
|
||||
static struct ggml_backend_feature * getNextBackendFeatures(struct ggml_backend_feature * feature) { return &feature[1];}
|
||||
|
||||
typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
|
||||
COMPILER inline get_compiler() {
|
||||
#if defined(__clang__)
|
||||
return COMP_CLANG;
|
||||
#elif defined(__GNUC__)
|
||||
return COMP_GCC;
|
||||
#else
|
||||
return UNKNOWN_COMPILER;
|
||||
#endif
|
||||
}
|
||||
|
||||
*/
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
|
||||
// #include <stdlib.h>
|
||||
// #include <stdint.h>
|
||||
// #include "ggml.h"
|
||||
// #include "ggml-cpu.h"
|
||||
// #include "ggml-backend.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
@@ -729,34 +713,3 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Backend) SystemInfo() string {
|
||||
var compiler string
|
||||
switch C.get_compiler() {
|
||||
case C.COMP_UNKNOWN:
|
||||
compiler = "cgo(unknown_compiler)"
|
||||
case C.COMP_GCC:
|
||||
compiler = "cgo(gcc)"
|
||||
case C.COMP_CLANG:
|
||||
compiler = "cgo(clang)"
|
||||
}
|
||||
|
||||
var s string
|
||||
for i := range C.ggml_backend_reg_count() {
|
||||
reg := C.ggml_backend_reg_get(i)
|
||||
fName := C.CString("ggml_backend_get_features")
|
||||
defer C.free(unsafe.Pointer(fName))
|
||||
get_features_fn := C.ggml_backend_reg_get_proc_address(reg, fName)
|
||||
if get_features_fn != nil {
|
||||
s += C.GoString(C.ggml_backend_reg_name(reg))
|
||||
s += " : "
|
||||
for features := C.getBackendFeatures(get_features_fn, reg); features.name != nil; features = C.getNextBackendFeatures(features) {
|
||||
s += C.GoString(features.name)
|
||||
s += " = "
|
||||
s += C.GoString(features.value)
|
||||
s += " | "
|
||||
}
|
||||
}
|
||||
}
|
||||
return s + compiler
|
||||
}
|
||||
|
||||
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp (vendored, 74 changed lines)
@@ -484,33 +484,29 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
}
|
||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||
for (const auto & entry : dir_it) {
|
||||
try {
|
||||
if (entry.is_regular_file()) {
|
||||
std::string filename = entry.path().filename().string();
|
||||
std::string ext = entry.path().extension().string();
|
||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
dl_handle_ptr handle { dl_load_library(entry.path()) };
|
||||
if (!handle) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
continue;
|
||||
}
|
||||
if (entry.is_regular_file()) {
|
||||
std::string filename = entry.path().filename().string();
|
||||
std::string ext = entry.path().extension().string();
|
||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
dl_handle_ptr handle { dl_load_library(entry.path()) };
|
||||
if (!handle) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
if (!score_fn) {
|
||||
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
continue;
|
||||
}
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
if (!score_fn) {
|
||||
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
int s = score_fn();
|
||||
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
|
||||
if (s > best_score) {
|
||||
best_score = s;
|
||||
best_path = entry.path();
|
||||
}
|
||||
int s = score_fn();
|
||||
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
|
||||
if (s > best_score) {
|
||||
best_score = s;
|
||||
best_path = entry.path();
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -533,6 +529,14 @@ void ggml_backend_load_all() {
|
||||
ggml_backend_load_all_from_path(nullptr);
|
||||
}
|
||||
|
||||
static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
|
||||
try {
|
||||
ggml_backend_load_best(name, silent, user_search_path);
|
||||
} catch (const std::exception & e) {
|
||||
GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||
#ifdef NDEBUG
|
||||
bool silent = true;
|
||||
@@ -540,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||
bool silent = false;
|
||||
#endif
|
||||
|
||||
ggml_backend_load_best("blas", silent, dir_path);
|
||||
ggml_backend_load_best("cann", silent, dir_path);
|
||||
ggml_backend_load_best("cuda", silent, dir_path);
|
||||
ggml_backend_load_best("hip", silent, dir_path);
|
||||
ggml_backend_load_best("kompute", silent, dir_path);
|
||||
ggml_backend_load_best("metal", silent, dir_path);
|
||||
ggml_backend_load_best("rpc", silent, dir_path);
|
||||
ggml_backend_load_best("sycl", silent, dir_path);
|
||||
ggml_backend_load_best("vulkan", silent, dir_path);
|
||||
ggml_backend_load_best("opencl", silent, dir_path);
|
||||
ggml_backend_load_best("musa", silent, dir_path);
|
||||
ggml_backend_load_best("cpu", silent, dir_path);
|
||||
ggml_backend_try_load_best("blas", silent, dir_path);
|
||||
ggml_backend_try_load_best("cann", silent, dir_path);
|
||||
ggml_backend_try_load_best("cuda", silent, dir_path);
|
||||
ggml_backend_try_load_best("hip", silent, dir_path);
|
||||
ggml_backend_try_load_best("kompute", silent, dir_path);
|
||||
ggml_backend_try_load_best("metal", silent, dir_path);
|
||||
ggml_backend_try_load_best("rpc", silent, dir_path);
|
||||
ggml_backend_try_load_best("sycl", silent, dir_path);
|
||||
ggml_backend_try_load_best("vulkan", silent, dir_path);
|
||||
ggml_backend_try_load_best("opencl", silent, dir_path);
|
||||
ggml_backend_try_load_best("musa", silent, dir_path);
|
||||
ggml_backend_try_load_best("cpu", silent, dir_path);
|
||||
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
||||
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
|
||||
if (backend_path) {
|
||||
|
||||
@@ -7,6 +7,20 @@ package ggml
|
||||
// #include <stdlib.h>
|
||||
// #include "ggml-backend.h"
|
||||
// extern void sink(int level, char *text, void *user_data);
|
||||
// static struct ggml_backend_feature * first_feature(ggml_backend_get_features_t fp, ggml_backend_reg_t reg) { return fp(reg); }
|
||||
// static struct ggml_backend_feature * next_feature(struct ggml_backend_feature * feature) { return &feature[1]; }
|
||||
/*
|
||||
typedef enum { COMPILER_CLANG, COMPILER_GNUC, COMPILER_UNKNOWN } COMPILER;
|
||||
static COMPILER compiler_name(void) {
|
||||
#if defined(__clang__)
|
||||
return COMPILER_CLANG;
|
||||
#elif defined(__GNUC__)
|
||||
return COMPILER_GNUC;
|
||||
#else
|
||||
return COMPILER_UNKNOWN;
|
||||
#endif
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
@@ -16,6 +30,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
@@ -90,4 +105,43 @@ var OnceLoad = sync.OnceFunc(func() {
|
||||
visited[abspath] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("system", "", system{})
|
||||
})
|
||||
|
||||
type system struct{}
|
||||
|
||||
func (system) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
names := make(map[string]int)
|
||||
for i := range C.ggml_backend_dev_count() {
|
||||
r := C.ggml_backend_dev_backend_reg(C.ggml_backend_dev_get(i))
|
||||
|
||||
func() {
|
||||
fName := C.CString("ggml_backend_get_features")
|
||||
defer C.free(unsafe.Pointer(fName))
|
||||
|
||||
if fn := C.ggml_backend_reg_get_proc_address(r, fName); fn != nil {
|
||||
var features []any
|
||||
for f := C.first_feature(C.ggml_backend_get_features_t(fn), r); f.name != nil; f = C.next_feature(f) {
|
||||
features = append(features, C.GoString(f.name), C.GoString(f.value))
|
||||
}
|
||||
|
||||
name := C.GoString(C.ggml_backend_reg_name(r))
|
||||
attrs = append(attrs, slog.Group(name+"."+strconv.Itoa(names[name]), features...))
|
||||
names[name] += 1
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
switch C.compiler_name() {
|
||||
case C.COMPILER_CLANG:
|
||||
attrs = append(attrs, slog.String("compiler", "cgo(clang)"))
|
||||
case C.COMPILER_GNUC:
|
||||
attrs = append(attrs, slog.String("compiler", "cgo(gcc)"))
|
||||
default:
|
||||
attrs = append(attrs, slog.String("compiler", "cgo(unknown)"))
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
_ "golang.org/x/image/tiff"
|
||||
_ "golang.org/x/image/webp"
|
||||
|
||||
fs "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
@@ -100,6 +101,36 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
meta, _, err := fs.Decode(r, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getTextProcessor(meta.KV())
|
||||
}
|
||||
|
||||
func getTextProcessor(kv fs.KV) (TextProcessor, error) {
|
||||
arch := kv.Architecture()
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
||||
}
|
||||
m, err := f(kv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tp, ok := m.(TextProcessor)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%v is not a TextProcessor", m)
|
||||
}
|
||||
return tp, nil
|
||||
}
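
`NewTextProcessor` decodes only the GGUF metadata (`fs.Decode(r, -1)`) and then asks the registered architecture for its tokenizer, so no tensor data is loaded. The sketch below shows one way it might be used to tokenize without loading weights; the blank import path that registers the model architectures is an assumption about the package layout, and the file path is a placeholder.

```go
// Hypothetical helper: load only the tokenizer for a GGUF and round-trip a
// string through it. Fails if the architecture is not registered with the
// new Ollama engine.
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/model"
	_ "github.com/ollama/ollama/model/models" // register architectures (path assumed)
)

func main() {
	tp, err := model.NewTextProcessor("/path/to/model.gguf")
	if err != nil {
		log.Fatal(err)
	}
	tokens, err := tp.Encode("hello world")
	if err != nil {
		log.Fatal(err)
	}
	text, err := tp.Decode(tokens)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(tokens), text)
}
```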
|
||||
|
||||
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
t := v.Type()
|
||||
|
||||
|
||||
@@ -3,9 +3,11 @@ package model
|
||||
import (
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
fs "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@@ -134,3 +136,40 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
||||
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTextProcessor(t *testing.T) {
|
||||
tp, err := getTextProcessor(fs.KV{})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
} else if tp != nil {
|
||||
t.Error("expected nil tp")
|
||||
}
|
||||
|
||||
models["dummy"] = func(ml.Config) (Model, error) {
|
||||
return notTextProcessorModel{}, nil
|
||||
}
|
||||
tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
} else if tp != nil {
|
||||
t.Error("expected nil tp")
|
||||
}
|
||||
}
|
||||
|
||||
type notTextProcessorModel struct{}
|
||||
|
||||
func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (notTextProcessorModel) Backend() ml.Backend {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (notTextProcessorModel) Config() config {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package llama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@@ -29,6 +31,10 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
|
||||
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package mllama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@@ -25,6 +27,10 @@ const (
|
||||
)
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
// Verify unified config
|
||||
if c.Uint("vision.block_count") == 0 {
|
||||
return nil, fmt.Errorf("non-unified vision model not supported")
|
||||
}
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
|
||||
@@ -931,7 +931,6 @@ func Execute(args []string) error {
|
||||
slog.Info("starting go runner")
|
||||
|
||||
llama.BackendInit()
|
||||
slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads)
|
||||
|
||||
server := &Server{
|
||||
batchSize: *batchSize,
|
||||
|
||||
@@ -786,8 +786,6 @@ func (s *Server) loadModel(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads)
|
||||
|
||||
// TODO(jessegross): LoRA loading
|
||||
if lpath.String() != "" {
|
||||
panic("loras are not yet implemented")
|
||||
|
||||
@@ -77,11 +77,12 @@ if [ -d "$OLLAMA_INSTALL_DIR/lib/ollama" ] ; then
|
||||
fi
|
||||
status "Installing ollama to $OLLAMA_INSTALL_DIR"
|
||||
$SUDO install -o0 -g0 -m755 -d $BINDIR
|
||||
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR"
|
||||
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
|
||||
status "Downloading Linux ${ARCH} bundle"
|
||||
curl --fail --show-error --location --progress-bar \
|
||||
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
|
||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||
|
||||
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
|
||||
status "Making ollama accessible in the PATH in $BINDIR"
|
||||
$SUDO ln -sf "$OLLAMA_INSTALL_DIR/ollama" "$BINDIR/ollama"
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -73,19 +74,22 @@ const (
|
||||
DefaultMaxChunkSize = 8 << 20
|
||||
)
|
||||
|
||||
// DefaultCache returns a new disk cache for storing models. If the
|
||||
// OLLAMA_MODELS environment variable is set, it uses that directory;
|
||||
// otherwise, it uses $HOME/.ollama/models.
|
||||
func DefaultCache() (*blob.DiskCache, error) {
|
||||
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
||||
dir := os.Getenv("OLLAMA_MODELS")
|
||||
if dir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
home = cmp.Or(home, ".")
|
||||
dir = filepath.Join(home, ".ollama", "models")
|
||||
}
|
||||
return blob.Open(dir)
|
||||
})
|
||||
|
||||
// DefaultCache returns the default cache used by the registry. It is
|
||||
// configured from the OLLAMA_MODELS environment variable, or defaults to
|
||||
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
|
||||
// it uses the current working directory.
|
||||
func DefaultCache() (*blob.DiskCache, error) {
|
||||
return defaultCache()
|
||||
}
|
||||
|
||||
// Error is the standard error returned by Ollama APIs. It can represent a
|
||||
@@ -168,6 +172,10 @@ func CompleteName(name string) string {
|
||||
// Registry is a client for performing push and pull operations against an
|
||||
// Ollama registry.
|
||||
type Registry struct {
|
||||
// Cache is the cache used to store models. If nil, [DefaultCache] is
|
||||
// used.
Cache *blob.DiskCache

// UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient.
UserAgent string
@@ -206,18 +214,28 @@ type Registry struct {
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64

// Mask, if set, is the name used to convert non-fully qualified
// names to fully qualified names. If empty, the default mask
// ("registry.ollama.ai/library/_:latest") is used.
// Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used.
Mask string
}

func (r *Registry) completeName(name string) names.Name {
func (r *Registry) cache() (*blob.DiskCache, error) {
if r.Cache != nil {
return r.Cache, nil
}
return defaultCache()
}

func (r *Registry) parseName(name string) (names.Name, error) {
mask := defaultMask
if r.Mask != "" {
mask = names.Parse(r.Mask)
}
return names.Merge(names.Parse(name), mask)
n := names.Merge(names.Parse(name), mask)
if !n.IsFullyQualified() {
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
}
return n, nil
}

// DefaultRegistry returns a new Registry configured from the environment. The
@@ -278,12 +296,17 @@ type PushParams struct {
}

// Push pushes the model with the name in the cache to the remote registry.
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
if p == nil {
p = &PushParams{}
}

m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
c, err := r.cache()
if err != nil {
return err
}

m, err := r.ResolveLocal(cmp.Or(p.From, name))
if err != nil {
return err
}
@@ -306,7 +329,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *

t := traceFromContext(ctx)

scheme, n, _, err := parseName(name, r.Mask)
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
@@ -399,8 +422,8 @@ func canRetry(err error) bool {
// chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := parseName(name, r.Mask)
func (r *Registry) Pull(ctx context.Context, name string) error {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
@@ -413,6 +436,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}

c, err := r.cache()
if err != nil {
return err
}

exists := func(l *Layer) bool {
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
@@ -550,10 +578,14 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err

// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
n := r.completeName(name)
if !n.IsFullyQualified() {
return false, fmt.Errorf("%w: %q", ErrNameInvalid, name)
func (r *Registry) Unlink(name string) (ok bool, _ error) {
n, err := r.parseName(name)
if err != nil {
return false, err
}
c, err := r.cache()
if err != nil {
return false, err
}
return c.Unlink(n.String())
}
@@ -626,14 +658,18 @@ type Layer struct {
Size int64 `json:"size"`
}

// ResolveLocal resolves a name to a Manifest in the local cache. The name is
// parsed using [names.Split] but the scheme is ignored.
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name, r.Mask)
// ResolveLocal resolves a name to a Manifest in the local cache.
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
_, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
c, err := r.cache()
if err != nil {
return nil, err
}
if !d.IsValid() {
// No digest, so resolve the manifest by name.
d, err = c.Resolve(n.String())
if err != nil {
return nil, err
@@ -655,7 +691,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro

// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := parseName(name, r.Mask)
scheme, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
@@ -859,7 +895,7 @@ var supportedSchemes = []string{

var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))

// parseName parses and validates an extended name, returning the scheme, name,
// parseNameExtended parses and validates an extended name, returning the scheme, name,
// and digest.
//
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
@@ -870,8 +906,8 @@ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Jo
//
// If the name is not, once merged with the mask, fully qualified,
// [ErrNameInvalid] wrapped with a display friendly message is returned.
func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := names.Split(s)
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := splitExtended(s)
scheme = cmp.Or(scheme, "https")
if !slices.Contains(supportedSchemes, scheme) {
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
@@ -894,13 +930,33 @@ func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Diges
}
}

maskName := defaultMask
if mask != "" {
maskName = names.Parse(mask)
}
n := names.Merge(names.Parse(name), maskName)
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
n, err := r.parseName(name)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
return scheme, n, d, nil
}

// splitExtended splits an extended name string into its scheme, name, and digest
// parts.
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
// model@digest
// @digest
func splitExtended(s string) (scheme, name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
s = s[i+3:]
}
i = strings.LastIndex(s, "@")
if i >= 0 {
digest = s[i+1:]
s = s[:i]
}
return scheme, s, digest
}

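Taken together, this hunk means Registry now owns its disk cache: Push, Pull, ResolveLocal, and Unlink take only a model name, and the cache comes from the Cache field or defaultCache(). A minimal caller sketch under those assumptions; the model name "smol" is only a placeholder, and since the package is internal this compiles only inside the ollama module:

    package main

    import (
        "context"
        "log"

        "github.com/ollama/ollama/server/internal/client/ollama"
    )

    func main() {
        // DefaultRegistry configures the Registry, including its cache, from the environment.
        rc, err := ollama.DefaultRegistry()
        if err != nil {
            log.Fatal(err)
        }

        // No *blob.DiskCache argument anymore; short names are completed via Mask/DefaultMask.
        if err := rc.Pull(context.Background(), "smol"); err != nil {
            log.Fatal(err)
        }
    }
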
@@ -2,6 +2,7 @@ package ollama

import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
@@ -72,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()

c, err := blob.Open(t.TempDir())
if err != nil {
t.Fatal(err)
@@ -85,13 +87,14 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
}

r := &Registry{
Cache: c,
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}

link := func(name string, manifest string) {
_, n, _, err := parseName(name, r.Mask)
n, err := r.parseName(name)
if err != nil {
panic(err)
}
@@ -151,55 +154,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
}

func TestPushZero(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "empty", nil)
rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), "empty", nil)
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
}
}

func TestPushSingle(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "single", nil)
rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), "single", nil)
testutil.Check(t, err)
}

func TestPushMultiple(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "multiple", nil)
rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), "multiple", nil)
testutil.Check(t, err)
}

func TestPushNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("unexpected request: %v", r)
})
err := rc.Push(t.Context(), c, "notfound", nil)
err := rc.Push(t.Context(), "notfound", nil)
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
}
}

func TestPushNullLayer(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "null", nil)
rc, _ := newClient(t, nil)
err := rc.Push(t.Context(), "null", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
}

func TestPushSizeMismatch(t *testing.T) {
rc, c := newClient(t, nil)
rc, _ := newClient(t, nil)
ctx, _ := withTraceUnexpected(t.Context())
got := rc.Push(ctx, c, "sizemismatch", nil)
got := rc.Push(ctx, "sizemismatch", nil)
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
t.Errorf("err = %v; want size mismatch", got)
}
}

func TestPushInvalid(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "invalid", nil)
rc, _ := newClient(t, nil)
err := rc.Push(t.Context(), "invalid", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
@@ -207,7 +210,7 @@ func TestPushInvalid(t *testing.T) {

func TestPushExistsAtRemote(t *testing.T) {
var pushed bool
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/uploads/") {
if !pushed {
// First push. Return an uploadURL.
@@ -235,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) {

check := testutil.Checker(t)

err := rc.Push(ctx, c, "single", nil)
err := rc.Push(ctx, "single", nil)
check(err)

if !errors.Is(errors.Join(errs...), nil) {
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
}

err = rc.Push(ctx, c, "single", nil)
err = rc.Push(ctx, "single", nil)
check(err)
}

func TestPushRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(500)
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
return
}
})
got := rc.Push(t.Context(), c, "single", nil)
got := rc.Push(t.Context(), "single", nil)
checkErrCode(t, got, 500, "blob_error")
}

func TestPushLocationError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", ":///x")
w.WriteHeader(http.StatusAccepted)
})
got := rc.Push(t.Context(), c, "single", nil)
got := rc.Push(t.Context(), "single", nil)
wantContains := "invalid upload URL"
if got == nil || !strings.Contains(got.Error(), wantContains) {
t.Errorf("err = %v; want to contain %v", got, wantContains)
@@ -271,14 +274,14 @@ func TestPushLocationError(t *testing.T) {
}

func TestPushUploadRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Host == "blob.store" {
w.WriteHeader(499) // force RoundTrip error on upload
return
}
w.Header().Set("Location", "http://blob.store/blobs/123")
})
got := rc.Push(t.Context(), c, "single", nil)
got := rc.Push(t.Context(), "single", nil)
if !errors.Is(got, errRoundTrip) {
t.Errorf("got = %v; want %v", got, errRoundTrip)
}
@@ -294,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
os.Remove(c.GetFile(l.Digest))
},
})
got := rc.Push(ctx, c, "single", nil)
got := rc.Push(ctx, "single", nil)
if !errors.Is(got, fs.ErrNotExist) {
t.Errorf("got = %v; want fs.ErrNotExist", got)
}
}

func TestPushCommitRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
panic("unexpected")
}
w.WriteHeader(499) // force RoundTrip error
})
err := rc.Push(t.Context(), c, "zero", nil)
err := rc.Push(t.Context(), "zero", nil)
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
@@ -321,8 +324,8 @@ func checkNotExist(t *testing.T, err error) {
}

func TestRegistryPullInvalidName(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Pull(t.Context(), c, "://")
rc, _ := newClient(t, nil)
err := rc.Pull(t.Context(), "://")
if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
}
@@ -337,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
}

for _, resp := range cases {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp)
})
err := rc.Pull(t.Context(), c, "x")
err := rc.Pull(t.Context(), "x")
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err)
}
@@ -363,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) {
})

// Confirm that the layer does not exist locally
_, err := rc.ResolveLocal(c, "model")
_, err := rc.ResolveLocal("model")
checkNotExist(t, err)

_, err = c.Get(d)
checkNotExist(t, err)

err = rc.Pull(t.Context(), c, "model")
err = rc.Pull(t.Context(), "model")
check(err)

mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := rc.ResolveLocal(c, "model")
mg, err := rc.ResolveLocal("model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
@@ -399,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) {

func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists")
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called
return
@@ -422,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()

err := rc.Pull(ctx, c, "single")
err := rc.Pull(ctx, "single")
testutil.Check(t, err)

want := []int64{6}
@@ -435,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) {
}

func TestRegistryPullManifestNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
err := rc.Pull(t.Context(), c, "notfound")
err := rc.Pull(t.Context(), "notfound")
checkErrCode(t, err, 404, "")
}

func TestRegistryPullResolveRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
})
err := rc.Pull(t.Context(), c, "single")
err := rc.Pull(t.Context(), "single")
checkErrCode(t, err, 500, "an_error")
}

func TestRegistryPullResolveRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error
return
}
})
err := rc.Pull(t.Context(), c, "single")
err := rc.Pull(t.Context(), "single")
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
@@ -511,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {

// Check that we pull all layers that we can.

err := rc.Pull(ctx, c, "mixed")
err := rc.Pull(ctx, "mixed")
if err != nil {
t.Fatal(err)
}
@@ -529,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
}

func TestRegistryPullChunking(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
@@ -567,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) {
},
})

err := rc.Pull(ctx, c, "remote")
err := rc.Pull(ctx, "remote")
testutil.Check(t, err)

want := []int64{0, 3, 6}
@@ -709,25 +712,16 @@ func TestErrorUnmarshal(t *testing.T) {
//
// It is only for testing error messages, not that all invalids and valids are
// covered. Those are in other tests for names.Name and blob.Digest.
func TestParseNameErrors(t *testing.T) {
func TestParseNameExtendedErrors(t *testing.T) {
cases := []struct {
name string
err error
want string
}{
{"x", nil, ""},
{"x@", nil, ""},

{"", ErrNameInvalid, `invalid or missing name: ""`},
{"://", ErrNameInvalid, `invalid or missing name: "://"`},
{"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`},

{"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
{"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
}
}{}

var r Registry
for _, tt := range cases {
_, _, _, err := parseName(tt.name, DefaultMask)
_, _, _, err := r.parseNameExtended(tt.name)
if !errors.Is(err, tt.err) {
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
}
@@ -736,3 +730,89 @@ func TestParseNameErrors(t *testing.T) {
}
}
}

func TestParseNameExtended(t *testing.T) {
cases := []struct {
in string
scheme string
name string
digest string
err string
}{
{in: "http://m", scheme: "http", name: "m"},
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
{in: "http+insecure://m", err: "unsupported scheme"},

{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},

{in: "", err: "invalid or missing name"},
{in: "m", scheme: "https", name: "m"},
{in: "://", err: "invalid or missing name"},
{in: "@sha256:deadbeef", err: "invalid digest"},
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
var r Registry
scheme, n, digest, err := r.parseNameExtended(tt.in)
if err != nil {
if tt.err == "" {
t.Errorf("err = %v; want nil", err)
} else if !strings.Contains(err.Error(), tt.err) {
t.Errorf("err = %v; want %q", err, tt.err)
}
} else if tt.err != "" {
t.Errorf("err = nil; want %q", tt.err)
}
if err == nil && !n.IsFullyQualified() {
t.Errorf("name = %q; want fully qualified", n)
}

if scheme != tt.scheme {
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
}

// smoke-test name is superset of tt.name
if !strings.Contains(n.String(), tt.name) {
t.Errorf("name = %q; want %q", n, tt.name)
}

tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
if digest.String() != tt.digest {
t.Errorf("digest = %q; want %q", digest, tt.digest)
}
})
}
}

func TestUnlink(t *testing.T) {
t.Run("found by name", func(t *testing.T) {
rc, _ := newClient(t, nil)

// confirm linked
_, err := rc.ResolveLocal("single")
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// unlink
_, err = rc.Unlink("single")
testutil.Check(t, err)

// confirm unlinked
_, err = rc.ResolveLocal("single")
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want fs.ErrNotExist", err)
}
})
t.Run("not found by name", func(t *testing.T) {
rc, _ := newClient(t, nil)
ok, err := rc.Unlink("manifestNotFound")
if err != nil {
t.Fatal(err)
}
if ok {
t.Error("expected not found")
}
})
}

@@ -6,6 +6,9 @@ import (

// Trace is a set of functions that are called to report progress during blob
// downloads and uploads.
//
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
// and [Registry.Pull].
type Trace struct {
// Update is called during [Registry.Push] and [Registry.Pull] to
// report the progress of blob uploads and downloads.

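For context, a sketch of how a caller might attach a Trace before pulling. Only the purpose of Update is documented in this hunk, so the callback signature (layer, bytes received, error) is an assumption here, as are the imports ("context", "log", and the ollama client package):

    // pullWithProgress wraps Pull with a progress-reporting Trace.
    // The Update signature is assumed, not shown in this hunk.
    func pullWithProgress(ctx context.Context, rc *ollama.Registry, name string) error {
        ctx = ollama.WithTrace(ctx, &ollama.Trace{
            Update: func(l *ollama.Layer, n int64, err error) {
                if err == nil {
                    log.Printf("%s: %d/%d bytes", l.Digest, n, l.Size)
                }
            },
        })
        return rc.Pull(ctx, name)
    }
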
@@ -63,25 +63,28 @@ func main() {
}
flag.Parse()

c, err := ollama.DefaultCache()
if err != nil {
log.Fatal(err)
}

rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}

ctx := context.Background()

err = func() error {
err := func() error {
switch cmd := flag.Arg(0); cmd {
case "pull":
return cmdPull(ctx, rc, c)
rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}

return cmdPull(ctx, rc)
case "push":
return cmdPush(ctx, rc, c)
rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}
return cmdPush(ctx, rc)
case "import":
c, err := ollama.DefaultCache()
if err != nil {
log.Fatal(err)
}
return cmdImport(ctx, c)
default:
if cmd == "" {
@@ -99,7 +102,7 @@ func main() {
}
}

func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
model := flag.Arg(1)
if model == "" {
flag.Usage()
@@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error

errc := make(chan error)
go func() {
errc <- rc.Pull(ctx, c, model)
errc <- rc.Pull(ctx, model)
}()

t := time.NewTicker(time.Second)
@@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
}
}

func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("push", flag.ExitOnError)
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
@@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
}

from := cmp.Or(*flagFrom, model)
m, err := rc.ResolveLocal(c, from)
m, err := rc.ResolveLocal(from)
if err != nil {
return err
}
@@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
},
})

return rc.Push(ctx, c, model, &ollama.PushParams{
return rc.Push(ctx, model, &ollama.PushParams{
From: from,
})
}

@@ -1,3 +1,5 @@
//go:build goexperiment.synctest

package backoff

import (

@@ -1,3 +1,5 @@
//go:build goexperiment.synctest

package syncs

import (

@@ -11,7 +11,6 @@ import (
"log/slog"
"net/http"

"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
)

@@ -27,12 +26,15 @@ import (
// directly to the blob disk cache.
type Local struct {
Client *ollama.Registry // required
Cache *blob.DiskCache // required
Logger *slog.Logger // required

// Fallback, if set, is used to handle requests that are not handled by
// this handler.
Fallback http.Handler

// Prune, if set, is called to prune the local disk cache after a model
// is deleted.
Prune func() error // optional
}

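With Cache and Prune now fields on Local, construction mirrors the routes.go hunk further down. A condensed sketch, with the fallback handler and prune function as placeholders rather than names from this diff:

    rs := &registry.Local{
        Cache:    c,               // *blob.DiskCache opened by the caller
        Client:   rc,              // *ollama.Registry
        Logger:   slog.Default(),
        Fallback: existingHandler, // placeholder http.Handler for routes this handler skips
        Prune:    prune,           // optional func() error, run after a successful delete
    }
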
// serverError is like ollama.Error, but with a Status field for the HTTP
@@ -199,14 +201,17 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if err != nil {
return err
}
ok, err := s.Client.Unlink(s.Cache, p.model())
ok, err := s.Client.Unlink(p.model())
if err != nil {
return err
}
if !ok {
return &serverError{404, "not_found", "model not found"}
}
return nil
if s.Prune == nil {
return nil
}
return s.Prune()
}

func decodeUserJSON[T any](r io.Reader) (T, error) {

@@ -42,10 +42,10 @@ func newTestServer(t *testing.T) *Local {
t.Fatal(err)
}
rc := &ollama.Registry{
Cache: c,
HTTPClient: panicOnRoundTrip,
}
l := &Local{
Cache: c,
Client: rc,
Logger: testutil.Slogger(t),
}
@@ -87,7 +87,7 @@ func TestServerDelete(t *testing.T) {

s := newTestServer(t)

_, err := s.Client.ResolveLocal(s.Cache, "smol")
_, err := s.Client.ResolveLocal("smol")
check(err)

got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
@@ -95,7 +95,7 @@ func TestServerDelete(t *testing.T) {
t.Fatalf("Code = %d; want 200", got.Code)
}

_, err = s.Client.ResolveLocal(s.Cache, "smol")
_, err = s.Client.ResolveLocal("smol")
if err == nil {
t.Fatal("expected smol to have been deleted")
}

@@ -10,7 +10,6 @@ import (
"strings"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/template"
@@ -93,7 +92,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
var imgData llm.ImageData

if isMllama {
if envconfig.NewEngine() {
if len(m.ProjectorPaths) == 0 {
imgData = llm.ImageData{
ID: len(images),
Data: i,

@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@@ -206,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {

images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
if isMllama && !envconfig.NewEngine() {
if isMllama && len(model.ProjectorPaths) > 0 {
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
@@ -1129,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}
}

func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
corsConfig := cors.DefaultConfig()
corsConfig.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true
@@ -1197,10 +1196,11 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha

// wrap old with new
rs := &registry.Local{
Cache: c,
Client: rc,
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Fallback: r,

Prune: PruneLayers,
}

return rs, nil
@@ -1258,16 +1258,12 @@ func Serve(ln net.Listener) error {

s := &Server{addr: ln.Addr()}

c, err := ollama.DefaultCache()
if err != nil {
return err
}
rc, err := ollama.DefaultRegistry()
if err != nil {
return err
}

h, err := s.GenerateRoutes(c, rc)
h, err := s.GenerateRoutes(rc)
if err != nil {
return err
}

@@ -23,7 +23,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) {
modelsDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", modelsDir)

c, err := blob.Open(modelsDir)
if err != nil {
t.Fatalf("failed to open models dir: %v", err)
}

rc := &ollama.Registry{
// This is a temporary measure to allow us to move forward,
// surfacing any code contacting ollama.com we do not intended
@@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) {
}

s := &Server{}
router, err := s.GenerateRoutes(c, rc)
router, err := s.GenerateRoutes(rc)
if err != nil {
t.Fatalf("failed to generate routes: %v", err)
}