Compare commits

...

128 Commits

Author SHA1 Message Date
likelovewant
ff50cfb582 Merge branch 'ollama:main' into main 2024-06-14 00:59:47 +08:00
Patrick Devine
c69bc19e46 move OLLAMA_HOST to envconfig (#5009) 2024-06-12 18:48:16 -04:00
Michael Yang
bba5d177aa Merge pull request #5004 from ollama/mxyng/fix-templates
fix: multiple templates when creating from model
2024-06-12 14:39:29 -07:00
Michael Yang
c16f8af911 fix: multiple templates when creating from model
multiple templates may appear in a model if a model is created from
another model that 1) has an autodetected template and 2) defines a
custom template
2024-06-12 13:35:49 -07:00
likelovewant
edaec3183a Merge branch 'ollama:main' into main 2024-06-12 15:35:25 +08:00
Michael Yang
217f60c3d9 Merge pull request #4987 from ollama/mxyng/revert-byte-order
Revert "Merge pull request #4938 from ollama/mxyng/fix-byte-order"
2024-06-11 16:04:20 -07:00
Michael Yang
7bdcd1da94 Revert "Merge pull request #4938 from ollama/mxyng/fix-byte-order"
This reverts commit f5f245cc15, reversing
changes made to 94d37fdcae.

this change broke gguf v2 which is incorrectly detected as big endian
2024-06-11 15:56:17 -07:00
Jeffrey Morgan
ead259d877 llm: fix seed value not being applied to requests (#4986) 2024-06-11 14:24:41 -07:00
James Montgomery
2ff45d571d Add Ollama-hpp to Community Libraries in README. (#4983) 2024-06-11 11:15:05 -07:00
Michael Yang
0f3cf1d42e Merge pull request #4715 from ollama/mxyng/utf16-parser
proper utf16 support
2024-06-10 11:41:29 -07:00
Michael Yang
5bc029c529 Merge pull request #4921 from ollama/mxyng/import-md
update import.md
2024-06-10 11:41:09 -07:00
Michael Yang
e9a9c6a8e8 Merge pull request #4965 from ollama/mxyng/skip-layer-remove
fix: skip removing layers that no longer exist
2024-06-10 11:40:03 -07:00
Michael Yang
515f497e6d fix: skip removing layers that no longer exist 2024-06-10 11:32:19 -07:00
Michael Yang
b27268aaef add test 2024-06-10 11:32:15 -07:00
Michael Yang
f5f245cc15 Merge pull request #4938 from ollama/mxyng/fix-byte-order
fix parsing big endian gguf
2024-06-10 09:38:12 -07:00
Jim Scardelis
94d37fdcae fix: examples/langchain-python-rag-privategpt/requirements.txt (#3382) 2024-06-09 10:58:09 -07:00
Craig Hughes
b84aea1685 Critical fix from llama.cpp JSON grammar to forbid un-escaped escape characters inside strings, which breaks parsing. (#3782) 2024-06-09 10:57:09 -07:00
Napuh
896495de7b Add instructions to easily install specific versions on faq.md (#4084)
* Added instructions to easily install specific versions on faq.md

* Small typo

* Moved instructions on how to install specific version to linux.md

* Update docs/linux.md

* Update docs/linux.md

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-06-09 10:49:03 -07:00
dcasota
5528dd9d11 Error handling load_single_document() in ingest.py (#4852)
load_single_document() handles
- corrupt files
- empty (zero byte) files
- unsupported file extensions
2024-06-09 10:41:07 -07:00
Jeffrey Morgan
943172cbf4 Update api.md 2024-06-08 23:04:32 -07:00
likelovewant
1b5848cbf2 remove gfx906 has conflicts with gfx906:xnack- 2024-06-09 11:46:22 +08:00
likelovewant
76026b4a35 Merge branch 'ollama:main' into main 2024-06-09 10:10:23 +08:00
Nischal Jain
85169e8d6f Added headless-ollama (#4612) 2024-06-08 18:51:16 -07:00
Jeffrey Morgan
34f142797a llm: always add bos token to prompt (#4941)
* fix embedding by adding fixes from llama.cpp upstream

* remove assert

---------

Co-authored-by: Jesper Ek <deadbeef84@gmail.com>
2024-06-08 18:47:10 -07:00
Erhan
46a7f1e74a Update README.md with LangChainRust (#4854) 2024-06-08 17:29:36 -07:00
Michael Yang
620d5c569e fix parsing big endian gguf 2024-06-08 12:35:26 -07:00
Michael Yang
b9ce7bf75e update import.md 2024-06-07 16:45:15 -07:00
Daniel Hiltgen
cddc63381c Merge pull request #4909 from dhiltgen/oneapi_disable
Add ability to skip oneapi generate
2024-06-07 14:07:15 -07:00
Michael Yang
385a32ecb5 Merge pull request #4910 from ollama/mxyng/detect-chat-template
fix create model when template detection errors
2024-06-07 11:07:39 -07:00
Michael Yang
030e765e76 fix create model when template detection errors 2024-06-07 10:51:35 -07:00
Daniel Hiltgen
ab8c929e20 Add ability to skip oneapi generate
This follows the same pattern for cuda and rocm to allow
disabling the build even when we detect the dependent libraries
2024-06-07 08:32:49 -07:00
likelovewant
27e7397b11 Update gen_windows.ps1 2024-06-07 17:35:15 +08:00
likelovewant
a6390a8992 Merge branch 'ollama:main' into main 2024-06-07 17:25:53 +08:00
Jeffrey Morgan
ce0dc33cb8 llm: patch to fix qwen 2 temporarily on nvidia (#4897) 2024-06-06 23:14:33 -07:00
Michael Yang
78f81fc0e5 Merge pull request #4800 from ollama/mxyng/detect-chat-template
detect chat template from KV
2024-06-06 16:17:18 -07:00
Michael Yang
9b6c2e6eb6 detect chat template from KV 2024-06-06 16:03:47 -07:00
royjhan
1a29e9a879 API app/browser access (#4879)
* API app/browser access

* Add tauri (resolves #2291, #4791, #3799, #4388)
2024-06-06 15:19:03 -07:00
royjhan
4bf1da4944 Separate ListResponse and ModelResponse for api/tags vs api/ps (#4842)
* Remove false time fields

* Struct Separation for List and Process

* Remove Marshaler
2024-06-06 10:11:45 -07:00
Blake Mizerany
de5beb06b3 server: skip blob verification for already verified blobs 2024-06-05 16:39:11 -07:00
Sam
98e65929dc docs(tools): add gollama (#4829) 2024-06-05 14:13:39 -07:00
Michael Yang
66ab48772f proper utf16 support 2024-06-05 13:11:50 -07:00
Michael Yang
22fcf8f7de Merge pull request #3737 from ollama/mxyng/modelname-4
update create handler to use model.Name
2024-06-05 12:05:05 -07:00
royjhan
28c7813ac4 API PS Documentation (#4822)
* API PS Documentation
2024-06-05 11:06:53 -07:00
Kartikeya Mishra
1d8616d30f docs: update to add LLocal.in to web & desktop integrations (#4719) 2024-06-04 14:43:59 -07:00
Michael Yang
d61ef8b954 update create handler to use model.Name 2024-06-04 13:28:25 -07:00
Michael Yang
89d9900152 Merge pull request #4570 from ollama/mxyng/slices
lint some of the things
2024-06-04 13:27:05 -07:00
Michael
4a048715b6 local wording was confusing people
local wording was confusing people -- Ollama runs on cloud providers
2024-06-04 13:25:25 -07:00
Michael Yang
6297f85606 gofmt, goimports 2024-06-04 13:20:24 -07:00
Michael Yang
ed56428dd7 warn on intrange, usestdlibvars 2024-06-04 11:52:48 -07:00
Michael Yang
ad40b92b6a disable intrange 2024-06-04 11:35:30 -07:00
Michael Yang
8ce4032e72 more lint 2024-06-04 11:13:30 -07:00
Michael Yang
42660466f8 no usestdlibvars 2024-06-04 11:13:30 -07:00
Michael Yang
e919f6811f lint windows 2024-06-04 11:13:30 -07:00
Michael Yang
bf7edb0d5d lint linux 2024-06-04 11:13:30 -07:00
Michael Yang
f38353d6b9 stdin.fd 2024-06-04 11:13:30 -07:00
Michael Yang
201d853fdf nolintlint 2024-06-04 11:13:30 -07:00
Michael Yang
e40145a39d lint 2024-06-04 11:13:30 -07:00
Michael Yang
c895a7d13f some gocritic 2024-06-04 11:13:30 -07:00
Michael Yang
dad7a987ae nosprintfhostport 2024-06-04 11:13:30 -07:00
Michael Yang
8ffb51749f nolintlint 2024-06-04 11:13:30 -07:00
Michael Yang
55f6eba049 gofmt 2024-06-04 11:13:30 -07:00
Michael Yang
04f3c12bb7 replace x/exp/slices with slices 2024-06-04 11:13:30 -07:00
Shubham
60323e0805 add embed model command and fix question invoke (#4766)
* add embed model command and fix question invoke

* Update docs/tutorials/langchainpy.md

Co-authored-by: Kim Hallberg <hallberg.kim@gmail.com>

* Update docs/tutorials/langchainpy.md

---------

Co-authored-by: Kim Hallberg <hallberg.kim@gmail.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-06-03 22:20:48 -07:00
likelovewant
71ae05239e Update README.md 2024-06-03 15:54:48 +08:00
likelovewant
a4a435bf8f Update amd_windows.go 2024-06-03 14:55:48 +08:00
likelovewant
2490a69f7b Merge branch 'ollama:main' into main 2024-06-03 14:15:05 +08:00
Jeffrey Morgan
d4a86102fd update welcome prompt in windows to llama3 (#4779) 2024-06-01 21:05:51 -07:00
Jeffrey Morgan
476fb8e892 Limit GPU lib search for now (#4777)
* fix oneapi errors on windows 10
2024-06-01 19:24:33 -07:00
Michael Yang
829ff87bd1 revert tokenize ffi (#4761)
* Revert "use `int32_t` for call to tokenize (#4738)"

This reverts commit 763bb65dbb.

* Revert "vocab only"

This reverts commit bf54c845e9.

* Revert "use ffi for tokenizing/detokenizing"

This reverts commit 26a00a0410.
2024-05-31 18:54:21 -07:00
Josh
f6b622c4b3 Merge pull request #4733 from ollama/jyan/isvalidname
added IsValidNamespace function
2024-05-31 14:08:45 -07:00
Josh Yan
2e4da8eec2 added tests for IsValidNamespace 2024-05-31 11:48:07 -07:00
likelovewant
16ce79eb3b Merge branch 'ollama:main' into main 2024-05-31 18:43:24 +08:00
Jeffrey Morgan
763bb65dbb use int32_t for call to tokenize (#4738)
* use `int32_t` for call to tokenize

* variable naming

* cleanup

* fix crash
2024-05-30 21:43:30 -07:00
Jeffrey Morgan
7ca9605f54 speed up tests by only building static lib (#4740) 2024-05-30 21:43:15 -07:00
Michael Yang
eb2c443a79 Merge pull request #4736 from ollama/mxyng/vocab-only
vocab only for tokenize
2024-05-30 17:21:00 -07:00
Michael Yang
278e25ea44 Merge pull request #4737 from ollama/mxyng/less-generate
only generate on relevant changes
2024-05-30 17:17:50 -07:00
Jeffrey Morgan
a50a87a7b8 partial offloading: allow flash attention and disable mmap (#4734)
* partial offloading: allow flash attention and disable mmap

* allow mmap with num_gpu=0
2024-05-30 16:58:01 -07:00
Michael Yang
98085015d5 only generate on relevant changes 2024-05-30 16:54:11 -07:00
Michael Yang
bf54c845e9 vocab only 2024-05-30 16:49:28 -07:00
Josh Yan
c365f195a8 directly use isvalidpart 2024-05-30 16:40:04 -07:00
Josh
e91d0ef737 Merge pull request #4728 from ollama/jyan/japanese
fixed japanese characters deleted at end of line
2024-05-30 16:25:12 -07:00
Jeffrey Morgan
22f5c12ced Update llama.cpp submodule to 5921b8f0 (#4731)
* update llama.cpp submodule to `5921b8f089d3b7bda86aac5a66825df6a6c10603`

* add patch
2024-05-30 16:20:22 -07:00
Josh Yan
298c996e54 added IsValidNamespace function 2024-05-30 16:02:07 -07:00
Daniel Hiltgen
0fc0cfc6d2 Merge pull request #4594 from dhiltgen/doc_container_workarounds
Add isolated gpu test to troubleshooting
2024-05-30 13:10:54 -07:00
Josh Yan
914f68f021 replaced duplicate call with variable 2024-05-30 10:38:07 -07:00
Josh Yan
bd1d119ba9 fixed japanese characters deleted at end of line 2024-05-30 10:24:21 -07:00
Lei Jitang
a03be18189 Fix OLLAMA_LLM_LIBRARY with wrong map name and add more env vars to help message (#4663)
* envconfig/config.go: Fix wrong description of OLLAMA_LLM_LIBRARY

Signed-off-by: Lei Jitang <leijitang@outlook.com>

* serve: Add more env to help message of ollama serve

Add more enviroment variables to `ollama serve --help`
to let users know what can be configurated.

Signed-off-by: Lei Jitang <leijitang@outlook.com>

---------

Signed-off-by: Lei Jitang <leijitang@outlook.com>
2024-05-30 09:36:51 -07:00
Michael Yang
96bc232b43 Merge pull request #4413 from ollama/mxyng/name-check
check if name exists before create/pull/copy
2024-05-29 12:06:58 -07:00
Michael Yang
bca7b12284 Merge pull request #3718 from ollama/mxyng/modelname-3
update delete handler to use model.Name
2024-05-29 12:02:07 -07:00
Michael Yang
32cb1960c1 Merge pull request #4380 from ollama/mxyng/tokenize
use tokenize/detokenize
2024-05-29 12:00:59 -07:00
Michael Yang
de781b37c8 rm unused infill 2024-05-29 11:26:47 -07:00
Michael Yang
3e21799377 rm unused system prompt 2024-05-29 11:26:47 -07:00
Michael Yang
26a00a0410 use ffi for tokenizing/detokenizing 2024-05-29 11:26:47 -07:00
likelovewant
cafde1f8ce Merge branch 'ollama:main' into main 2024-05-29 19:33:39 +08:00
Daniel Hiltgen
646371f56d Merge pull request #3278 from zhewang1-intc/rebase_ollama_main
Enabling ollama to run on Intel GPUs with SYCL backend
2024-05-28 16:30:50 -07:00
Jeffrey Morgan
1f5008544b Update install.sh 2024-05-28 15:01:22 -07:00
Jeffrey Morgan
45cbfc5aee fix wsl2 status check for nvidia cards (#4689) 2024-05-28 14:49:46 -07:00
Jeffrey Morgan
6d423b383b Improve install experience on WSL2 and Linux (#4653) 2024-05-28 14:41:50 -07:00
Josh
ad897080a2 working on integration of multi-byte and multi-width runes (#4549)
* integrated runewidth for display management - fixed cursor movement for mutli-width char

* updated input and deletion of multi-byte chars

* fixed line history with some exceptions

* improved insert and add

* fixed issues with moving across lines

* end of line extra space tracking'

* saved changes

* fixed end of line issues with empty spaces

* worked some more

* worked on end of line

* fixed failed test

* fixed minor inserting bug

* fixed movement hotkeys

* adjusted hotkeys

* removed comments

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* deleted comments and duplicate code

* removed duplicate code

* added comments, refactored add function to use addChar

* added helper to retrieve lineSpacing, renamed lineFlags for clarity

* fixed remove()

---------

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2024-05-28 12:04:03 -07:00
Jeffrey Morgan
b7d316d98d fix nvidia detection in install script (#4683) 2024-05-28 09:59:36 -07:00
Daniel Hiltgen
d7339fad52 Merge pull request #4682 from dhiltgen/more_time
Give the final model loading more time
2024-05-28 09:36:02 -07:00
Daniel Hiltgen
92c81e8117 Give the final model loading more time
On some systems, 1 minute isn't sufficient to finish the load after it
hits 100% This creates 2 distinct timers, although they're both set to
the same value for now so we can refine the timeouts further.
2024-05-28 09:08:10 -07:00
Tai
9db0996ed4 Add OllamaSpring Project to Readme (#4672)
* Add OllamaSpring Project to Readme

* Update README.md

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-05-27 19:58:26 -07:00
Orfeo Ciano
6f43898b17 Adds olpaka flutter client (#4647)
* Adds olpaka flutter client

* Update README.md

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-05-27 17:22:01 -07:00
Lei Jitang
7487229c34 llm/server.go: Fix 2 minor typos (#4661)
Signed-off-by: Lei Jitang <leijitang@outlook.com>
2024-05-27 17:21:10 -07:00
Rayan Mostovoi
8a8e7afa96 small fix on examples/python-simplechat/client.py to actually get a streamed response and get tokens printed as we receive it (#4671) 2024-05-27 17:19:20 -07:00
Jeffrey Morgan
c79f8c9c39 Ensure nvidia and nvidia_uvm kernel modules are loaded in install.sh script and at startup (#4652)
* ensure kernel modules are loaded in `install.sh` script and at startup

* indentation

* use `SUDO` variable

* restart if nouveau is detected

* consistent success message for AMD
2024-05-26 14:57:17 -07:00
Jeffrey Morgan
485016bfbb Update install.sh 2024-05-26 11:46:00 -07:00
likelovewant
2a80d6f743 Merge branch 'ollama:main' into main 2024-05-26 11:57:21 +08:00
Daniel Hiltgen
0165ba1651 Merge pull request #4638 from dhiltgen/better_error
Report better warning on client closed abort of load
2024-05-25 14:32:28 -07:00
Daniel Hiltgen
c4209d6d21 Report better warning on client closed abort of load
If the client closes the connection before we finish loading the model
we abort, so lets make the log message clearer why to help users
understand this failure mode
2024-05-25 09:23:28 -07:00
Michael Yang
6adca97f37 Merge pull request #4619 from noxer/patch-1
Fix download retry issue
2024-05-24 17:21:57 -07:00
Michael Yang
9a3c8003c8 Merge pull request #4624 from ollama/mxyng/fix-5
fix q5_0, q5_1
2024-05-24 16:11:21 -07:00
Michael Yang
d51f15257c Update llm/ggml.go
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2024-05-24 16:10:43 -07:00
Michael Yang
8f440d579a fix q5_0, q5_1 2024-05-24 16:01:46 -07:00
Patrick Devine
4cc3be3035 Move envconfig and consolidate env vars (#4608) 2024-05-24 14:57:15 -07:00
Tim Scheuermann
db2ffa79f1 Fix download retry issue 2024-05-24 20:30:42 +02:00
likelovewant
73c49d57e8 Update amd_windows.go
remove this will broken the installer build
2024-05-24 20:06:28 +08:00
likelovewant
6b50b2f3bf Update gen_windows.ps1 2024-05-24 15:42:29 +08:00
Wang,Zhe
fd5971be0b support ollama run on Intel GPUs 2024-05-24 11:18:27 +08:00
Daniel Hiltgen
f77713bf1f Add isolated gpu test to troubleshooting 2024-05-23 09:33:25 -07:00
Michael Yang
85a57006d1 check if name exists before create/pull/copy 2024-05-14 14:58:58 -07:00
Michael Yang
c5e892cb3e update tests 2024-05-14 14:56:31 -07:00
Michael Yang
81fb06f530 more resilient Manifests 2024-05-14 14:08:24 -07:00
Michael Yang
a385382ff5 filepath.Join 2024-05-14 14:08:24 -07:00
Michael Yang
b8772a353f remove DeleteModel 2024-05-14 14:08:24 -07:00
Michael Yang
c2714fcbfd routes: use Manifests for ListHandler 2024-05-14 14:08:24 -07:00
Michael Yang
a2fc933fed update delete handler to use model.Name 2024-05-14 14:08:24 -07:00
114 changed files with 3493 additions and 1190 deletions

View File

@@ -34,13 +34,13 @@ jobs:
git diff-tree -r --no-commit-id --name-only \
$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
${{ github.event.pull_request.head.sha }} \
| xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
}
{
echo GENERATE=$(changed llm/)
echo GENERATE_CUDA=$(changed llm/)
echo GENERATE_ROCM=$(changed llm/)
echo GENERATE=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
echo GENERATE_CUDA=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
echo GENERATE_ROCM=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
} >>$GITHUB_OUTPUT
generate:
@@ -269,9 +269,9 @@ jobs:
mkdir -p llm/build/darwin/$ARCH/stub/bin
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'macos-') }}
- uses: golangci/golangci-lint-action@v4
- uses: golangci/golangci-lint-action@v6
with:
args: --timeout 8m0s -v
args: --timeout 8m0s -v ${{ startsWith(matrix.os, 'windows-') && '' || '--disable gofmt --disable goimports' }}
test:
strategy:
matrix:
@@ -287,6 +287,8 @@ jobs:
GOARCH: ${{ matrix.arch }}
CGO_ENABLED: '1'
OLLAMA_CPU_TARGET: 'static'
OLLAMA_SKIP_CPU_GENERATE: '1'
OLLAMA_SKIP_METAL_GENERATE: '1'
steps:
- uses: actions/checkout@v4
with:

View File

@@ -9,9 +9,26 @@ linters:
- contextcheck
- exportloopref
- gocheckcompilerdirectives
# FIXME: for some reason this errors on windows
# conditionally enable this on linux/macos
# - gofmt
# - goimports
- intrange
- misspell
- nilerr
- nolintlint
- nosprintfhostport
- testifylint
- unconvert
- unused
- wastedassign
- whitespace
- usestdlibvars
severity:
default-severity: error
rules:
- linters:
- gofmt
- goimports
- intrange
- usestdlibvars
severity: info

View File

@@ -6,7 +6,7 @@
[![Discord](https://dcbadge.vercel.app/api/server/ollama?style=flat&compact=true)](https://discord.gg/ollama)
Get up and running with large language models locally.
Get up and running with large language models.
### macOS
@@ -30,7 +30,7 @@ Example extra list add on this repo.
```
Please follow the [wiki](https://github.com/likelovewant/ollama-for-amd/wiki) guide to build or use the pre-release version.
Note: `gfx803, gfx1010` reported not working by the wiki method ,expected a future support
Note: `gfx803` reported partialy working by the wiki method ,expected a future support
@@ -301,6 +301,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
### Terminal
@@ -323,6 +326,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [ShellOracle](https://github.com/djcopley/ShellOracle)
- [tlm](https://github.com/yusufcanb/tlm)
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
- [gollama](https://github.com/sammcj/gollama)
### Database
@@ -340,11 +344,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa)
- [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example)
- [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java)
- [LangChainRust](https://github.com/Abraxas-365/langchain-rust) with [example](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
- [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html)
- [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
- [Ollama4j for Java](https://github.com/amithkoujalgi/ollama4j)
- [ModelFusion Typescript Library](https://modelfusion.dev/integration/model-provider/ollama)
- [OllamaKit for Swift](https://github.com/kevinhermawan/OllamaKit)
@@ -362,6 +368,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
- [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) with an [example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama)
- [LlamaScript](https://github.com/Project-Llama/llamascript)
### Mobile
- [Enchanted](https://github.com/AugustDev/enchanted)
@@ -394,7 +401,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
### Supported backends
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.

View File

@@ -23,11 +23,9 @@ import (
"net"
"net/http"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/version"
)
@@ -65,10 +63,7 @@ func checkError(resp *http.Response, body []byte) error {
// If the variable is not specified, a default ollama host and port will be
// used.
func ClientFromEnvironment() (*Client, error) {
ollamaHost, err := GetOllamaHost()
if err != nil {
return nil, err
}
ollamaHost := envconfig.Host
return &Client{
base: &url.URL{
@@ -79,52 +74,6 @@ func ClientFromEnvironment() (*Client, error) {
}, nil
}
type OllamaHost struct {
Scheme string
Host string
Port string
}
func GetOllamaHost() (OllamaHost, error) {
defaultPort := "11434"
hostVar := os.Getenv("OLLAMA_HOST")
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
scheme, hostport, ok := strings.Cut(hostVar, "://")
switch {
case !ok:
scheme, hostport = "http", hostVar
case scheme == "http":
defaultPort = "80"
case scheme == "https":
defaultPort = "443"
}
// trim trailing slashes
hostport = strings.TrimRight(hostport, "/")
host, port, err := net.SplitHostPort(hostport)
if err != nil {
host, port = "127.0.0.1", defaultPort
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
host = ip.String()
} else if hostport != "" {
host = hostport
}
}
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
return OllamaHost{}, ErrInvalidHostPort
}
return OllamaHost{
Scheme: scheme,
Host: host,
Port: port,
}, nil
}
func NewClient(base *url.URL, http *http.Client) *Client {
return &Client{
base: base,
@@ -355,8 +304,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
}
// List running models.
func (c *Client) ListRunning(ctx context.Context) (*ListResponse, error) {
var lr ListResponse
func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
var lr ProcessResponse
if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
return nil, err
}

View File

@@ -1,11 +1,9 @@
package api
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/envconfig"
)
func TestClientFromEnvironment(t *testing.T) {
@@ -35,6 +33,7 @@ func TestClientFromEnvironment(t *testing.T) {
for k, v := range testCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
envconfig.LoadConfig()
client, err := ClientFromEnvironment()
if err != v.err {
@@ -46,40 +45,4 @@ func TestClientFromEnvironment(t *testing.T) {
}
})
}
hostTestCases := map[string]*testCase{
"empty": {value: "", expect: "127.0.0.1:11434"},
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
"only port": {value: ":1234", expect: ":1234"},
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
"hostname": {value: "example.com", expect: "example.com:11434"},
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
"zero port": {value: ":0", expect: ":0"},
"too large port": {value: ":66000", err: ErrInvalidHostPort},
"too small port": {value: ":-1", err: ErrInvalidHostPort},
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
}
for k, v := range hostTestCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
oh, err := GetOllamaHost()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
}
if err == nil {
host := net.JoinHostPort(oh.Host, oh.Port)
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
}
})
}
}

View File

@@ -2,7 +2,6 @@ package api
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"math"
@@ -282,19 +281,33 @@ type PushRequest struct {
// ListResponse is the response from [Client.List].
type ListResponse struct {
Models []ModelResponse `json:"models"`
Models []ListModelResponse `json:"models"`
}
// ModelResponse is a single model description in [ListResponse].
type ModelResponse struct {
// ProcessResponse is the response from [Client.Process].
type ProcessResponse struct {
Models []ProcessModelResponse `json:"models"`
}
// ListModelResponse is a single model description in [ListResponse].
type ListModelResponse struct {
Name string `json:"name"`
Model string `json:"model"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
SizeVRAM int64 `json:"size_vram,omitempty"`
}
// ProcessModelResponse is a single model description in [ProcessResponse].
type ProcessModelResponse struct {
Name string `json:"name"`
Model string `json:"model"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
SizeVRAM int64 `json:"size_vram"`
}
type TokenResponse struct {
@@ -306,7 +319,7 @@ type GenerateResponse struct {
// Model is the model name that generated the response.
Model string `json:"model"`
//CreatedAt is the timestamp of the response.
// CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"`
// Response is the textual response itself.
@@ -363,8 +376,6 @@ func (m *Metrics) Summary() {
}
}
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct

View File

@@ -72,13 +72,13 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
},
{
"positive duration",
time.Duration(42 * time.Second),
time.Duration(42 * time.Second),
42 * time.Second,
42 * time.Second,
},
{
"another positive duration",
time.Duration(42 * time.Minute),
time.Duration(42 * time.Minute),
42 * time.Minute,
42 * time.Minute,
},
{
"zero duration",

View File

@@ -6,7 +6,7 @@ import (
"os"
"path/filepath"
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/envconfig"
)
func InitLogging() {

View File

@@ -69,7 +69,6 @@ func init() {
slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err))
}
}
} else if runtime.GOOS == "darwin" {
// TODO
AppName += ".app"

View File

@@ -15,7 +15,7 @@ import (
)
func getCLIFullPath(command string) string {
cmdPath := ""
var cmdPath string
appExe, err := os.Executable()
if err == nil {
cmdPath = filepath.Join(filepath.Dir(appExe), command)
@@ -65,7 +65,6 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) {
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
}
if err := os.MkdirAll(logDir, 0o755); err != nil {

View File

@@ -24,7 +24,8 @@ func terminate(cmd *exec.Cmd) error {
if err != nil {
return err
}
defer dll.Release() // nolint: errcheck
//nolint:errcheck
defer dll.Release()
pid := cmd.Process.Pid
@@ -73,7 +74,8 @@ func isProcessExited(pid int) (bool, error) {
if err != nil {
return false, fmt.Errorf("failed to open process: %v", err)
}
defer windows.CloseHandle(hProcess) // nolint: errcheck
//nolint:errcheck
defer windows.CloseHandle(hProcess)
var exitCode uint32
err = windows.GetExitCodeProcess(hProcess, &exitCode)

View File

@@ -78,7 +78,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
}
defer resp.Body.Close()
if resp.StatusCode == 204 {
if resp.StatusCode == http.StatusNoContent {
slog.Debug("check update response 204 (current version is up to date)")
return false, updateResp
}
@@ -87,7 +87,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
slog.Warn(fmt.Sprintf("failed to read body response: %s", err))
}
if resp.StatusCode != 200 {
if resp.StatusCode != http.StatusOK {
slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body)))
return false, updateResp
}
@@ -114,7 +114,7 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
if err != nil {
return fmt.Errorf("error checking update: %w", err)
}
if resp.StatusCode != 200 {
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
}
resp.Body.Close()

View File

@@ -4,5 +4,5 @@ write-host "Welcome to Ollama!"
write-host ""
write-host "Run your first model:"
write-host ""
write-host "`tollama run llama2"
write-host "`tollama run llama3"
write-host ""

View File

@@ -29,7 +29,6 @@ func GetID() string {
initStore()
}
return store.ID
}
func GetFirstTimeRun() bool {

View File

@@ -47,7 +47,6 @@ func nativeLoop() {
default:
pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
}
}
}
@@ -160,8 +159,8 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
lResult, _, _ = pDefWindowProc.Call(
uintptr(hWnd),
uintptr(message),
uintptr(wParam),
uintptr(lParam),
wParam,
lParam,
)
}
return

View File

@@ -186,7 +186,7 @@ func (t *winTray) initInstance() error {
t.muNID.Lock()
defer t.muNID.Unlock()
t.nid = &notifyIconData{
Wnd: windows.Handle(t.window),
Wnd: t.window,
ID: 100,
Flags: NIF_MESSAGE,
CallbackMessage: t.wmSystrayMessage,
@@ -197,7 +197,6 @@ func (t *winTray) initInstance() error {
}
func (t *winTray) createMenu() error {
menuHandle, _, err := pCreatePopupMenu.Call()
if menuHandle == 0 {
return err
@@ -246,7 +245,7 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
mi := menuItemInfo{
Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE,
Type: MFT_STRING,
ID: uint32(menuItemId),
ID: menuItemId,
TypeData: titlePtr,
Cch: uint32(len(title)),
}
@@ -302,11 +301,10 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
}
func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
mi := menuItemInfo{
Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE,
Type: MFT_SEPARATOR,
ID: uint32(menuItemId),
ID: menuItemId,
}
mi.Size = uint32(unsafe.Sizeof(mi))
@@ -426,7 +424,6 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) {
// Loads an image from file and shows it in tray.
// Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx
func (t *winTray) setIcon(src string) error {
h, err := t.loadIconFrom(src)
if err != nil {
return err
@@ -444,7 +441,6 @@ func (t *winTray) setIcon(src string) error {
// Loads an image from file to be shown in tray or menu item.
// LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx
func (t *winTray) loadIconFrom(src string) (windows.Handle, error) {
// Save and reuse handles of loaded images
t.muLoadedImages.RLock()
h, ok := t.loadedImages[src]

View File

@@ -20,6 +20,7 @@ import (
"path/filepath"
"regexp"
"runtime"
"slices"
"strings"
"syscall"
"time"
@@ -29,11 +30,11 @@ import (
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
@@ -745,7 +746,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
if wordWrap && termWidth >= 10 {
for _, ch := range content {
if state.lineLength+1 > termWidth-5 {
if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", state.wordBuffer, ch)
state.wordBuffer = ""
@@ -754,7 +754,11 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
}
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", runewidth.StringWidth(state.wordBuffer))
a := runewidth.StringWidth(state.wordBuffer)
if a > 0 {
fmt.Printf("\x1b[%dD", a)
}
fmt.Printf("\x1b[K\n")
fmt.Printf("%s%c", state.wordBuffer, ch)
chWidth := runewidth.RuneWidth(ch)
@@ -956,17 +960,11 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}
func RunServer(cmd *cobra.Command, _ []string) error {
// retrieve the OLLAMA_HOST environment variable
ollamaHost, err := api.GetOllamaHost()
if err != nil {
return err
}
if err := initializeKeypair(); err != nil {
return err
}
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port))
if err != nil {
return err
}
@@ -1025,24 +1023,6 @@ func initializeKeypair() error {
return nil
}
//nolint:unused
func waitForServer(ctx context.Context, client *api.Client) error {
// wait for the server to start
timeout := time.After(5 * time.Second)
tick := time.Tick(500 * time.Millisecond)
for {
select {
case <-timeout:
return errors.New("timed out waiting for server to start")
case <-tick:
if err := client.Heartbeat(ctx); err == nil {
return nil // server has started
}
}
}
}
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -1079,12 +1059,7 @@ func versionHandler(cmd *cobra.Command, _ []string) {
}
}
type EnvironmentVar struct {
Name string
Description string
}
func appendEnvDocs(cmd *cobra.Command, envs []EnvironmentVar) {
func appendEnvDocs(cmd *cobra.Command, envs []envconfig.EnvVar) {
if len(envs) == 0 {
return
}
@@ -1093,7 +1068,7 @@ func appendEnvDocs(cmd *cobra.Command, envs []EnvironmentVar) {
Environment Variables:
`
for _, e := range envs {
envUsage += fmt.Sprintf(" %-16s %s\n", e.Name, e.Description)
envUsage += fmt.Sprintf(" %-24s %s\n", e.Name, e.Description)
}
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
@@ -1172,15 +1147,6 @@ func NewCLI() *cobra.Command {
Args: cobra.ExactArgs(0),
RunE: RunServer,
}
serveCmd.SetUsageTemplate(serveCmd.UsageTemplate() + `
Environment Variables:
OLLAMA_HOST The host:port to bind to (default "127.0.0.1:11434")
OLLAMA_ORIGINS A comma separated list of allowed origins
OLLAMA_MODELS The path to the models directory (default "~/.ollama/models")
OLLAMA_KEEP_ALIVE The duration that models stay loaded in memory (default "5m")
OLLAMA_DEBUG Set to 1 to enable additional debug logging
`)
pullCmd := &cobra.Command{
Use: "pull MODEL",
@@ -1233,9 +1199,9 @@ Environment Variables:
RunE: DeleteHandler,
}
ollamaHostEnv := EnvironmentVar{"OLLAMA_HOST", "The host:port or base URL of the Ollama server (e.g. http://localhost:11434)"}
ollamaNoHistoryEnv := EnvironmentVar{"OLLAMA_NOHISTORY", "Disable readline history"}
envs := []EnvironmentVar{ollamaHostEnv}
envVars := envconfig.AsMap()
envs := []envconfig.EnvVar{envVars["OLLAMA_HOST"]}
for _, cmd := range []*cobra.Command{
createCmd,
@@ -1247,10 +1213,27 @@ Environment Variables:
psCmd,
copyCmd,
deleteCmd,
serveCmd,
} {
switch cmd {
case runCmd:
appendEnvDocs(cmd, []EnvironmentVar{ollamaHostEnv, ollamaNoHistoryEnv})
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
case serveCmd:
appendEnvDocs(cmd, []envconfig.EnvVar{
envVars["OLLAMA_DEBUG"],
envVars["OLLAMA_HOST"],
envVars["OLLAMA_KEEP_ALIVE"],
envVars["OLLAMA_MAX_LOADED_MODELS"],
envVars["OLLAMA_MAX_QUEUE"],
envVars["OLLAMA_MODELS"],
envVars["OLLAMA_NUM_PARALLEL"],
envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_MAX_VRAM"],
})
default:
appendEnvDocs(cmd, envs)
}

View File

@@ -8,13 +8,14 @@ import (
"os"
"path/filepath"
"regexp"
"slices"
"sort"
"strings"
"github.com/spf13/cobra"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
@@ -183,7 +184,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err
}
if os.Getenv("OLLAMA_NOHISTORY") != "" {
if envconfig.NoHistory {
scanner.HistoryDisable()
}

View File

@@ -6,6 +6,7 @@ import (
"text/template"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
)
@@ -85,11 +86,11 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
`
tmpl, err := template.New("").Parse(expectedModelfile)
assert.Nil(t, err)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, opts)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, buf.String(), mf)
opts.ParentModel = "horseshark"
@@ -107,10 +108,10 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
`
tmpl, err = template.New("").Parse(expectedModelfile)
assert.Nil(t, err)
require.NoError(t, err)
var parentBuf bytes.Buffer
err = tmpl.Execute(&parentBuf, opts)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, parentBuf.String(), mf)
}

27
cmd/start.go Normal file
View File

@@ -0,0 +1,27 @@
//go:build darwin || windows
package cmd
import (
"context"
"errors"
"time"
"github.com/ollama/ollama/api"
)
func waitForServer(ctx context.Context, client *api.Client) error {
// wait for the server to start
timeout := time.After(5 * time.Second)
tick := time.Tick(500 * time.Millisecond)
for {
select {
case <-timeout:
return errors.New("timed out waiting for server to start")
case <-tick:
if err := client.Heartbeat(ctx); err == nil {
return nil // server has started
}
}
}
}

View File

@@ -189,7 +189,7 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
if params.VocabSize > len(v.Tokens) {
missingTokens := params.VocabSize - len(v.Tokens)
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
for cnt := 0; cnt < missingTokens; cnt++ {
for cnt := range missingTokens {
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
v.Scores = append(v.Scores, -1)
v.Types = append(v.Types, tokenTypeUserDefined)

View File

@@ -35,7 +35,6 @@ func addOnes(data []float32, vectorSize int) ([]float32, error) {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -119,11 +119,12 @@ func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([
}
var heads int
if strings.HasSuffix(name, "attn_q.weight") {
switch {
case strings.HasSuffix(name, "attn_q.weight"):
heads = params.AttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight") {
case strings.HasSuffix(name, "attn_k.weight"):
heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
} else {
default:
return nil, fmt.Errorf("unknown tensor name: %s", name)
}

View File

@@ -120,7 +120,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
Name: name,
Kind: kind,
Offset: offset,
Shape: shape[:],
Shape: shape,
}
t.WriterTo = safetensorWriterTo{

View File

@@ -85,13 +85,10 @@ func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, e
sha256sum := sha256.New()
for _, pt := range t.PreTokenizer.PreTokenizers {
switch pt.Type {
case "Split":
if pt.Pattern.Regex != "" {
if pt.Type == "Split" && pt.Pattern.Regex != "" {
sha256sum.Write([]byte(pt.Pattern.Regex))
}
}
}
switch digest := fmt.Sprintf("%x", sha256sum.Sum(nil)); digest {
case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":

View File

@@ -88,7 +88,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
Name: ggufName,
Kind: kind,
Offset: offset, // calculate the offset
Shape: shape[:],
Shape: shape,
}
tensor.WriterTo = torchWriterTo{
@@ -104,7 +104,6 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
}
return tensors, nil
}
func getAltParams(dirpath string) (*Params, error) {

View File

@@ -12,6 +12,7 @@
- [Pull a Model](#pull-a-model)
- [Push a Model](#push-a-model)
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
## Conventions
@@ -249,7 +250,7 @@ curl http://localhost:11434/api/generate -d '{
#### Request (Reproducible outputs)
For reproducible outputs, set `temperature` to 0 and `seed` to a number:
For reproducible outputs, set `seed` to a number:
##### Request
@@ -258,8 +259,7 @@ curl http://localhost:11434/api/generate -d '{
"model": "mistral",
"prompt": "Why is the sky blue?",
"options": {
"seed": 123,
"temperature": 0
"seed": 123
}
}'
```
@@ -1035,3 +1035,47 @@ curl http://localhost:11434/api/embeddings -d '{
]
}
```
## List Running Models
```shell
GET /api/ps
```
List models that are currently loaded into memory.
#### Examples
### Request
```shell
curl http://localhost:11434/api/ps
```
#### Response
A single JSON object will be returned.
```json
{
"models": [
{
"name": "mistral:latest",
"model": "mistral:latest",
"size": 5137025024,
"digest": "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8",
"details": {
"parent_model": "",
"format": "gguf",
"family": "llama",
"families": [
"llama"
],
"parameter_size": "7.2B",
"quantization_level": "Q4_0"
},
"expires_at": "2024-06-04T14:38:31.83753-07:00",
"size_vram": 5137025024
}
]
}
```

View File

@@ -1,170 +1,99 @@
# Import a model
# Import
This guide walks through importing a GGUF, PyTorch or Safetensors model.
GGUF models and select Safetensors models can be imported directly into Ollama.
## Importing (GGUF)
## Import GGUF
### Step 1: Write a `Modelfile`
A binary GGUF file can be imported directly into Ollama through a Modelfile.
Start by creating a `Modelfile`. This file is the blueprint for your model, specifying weights, parameters, prompt templates and more.
```
FROM ./mistral-7b-v0.1.Q4_0.gguf
```dockerfile
FROM /path/to/file.gguf
```
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
## Import Safetensors
```
FROM ./mistral-7b-v0.1.Q4_0.gguf
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
If the model being imported is one of these architectures, it can be imported directly into Ollama through a Modelfile:
- LlamaForCausalLM
- MistralForCausalLM
- GemmaForCausalLM
```dockerfile
FROM /path/to/safetensors/directory
```
### Step 2: Create the Ollama model
For architectures not directly convertable by Ollama, see llama.cpp's [guide](https://github.com/ggerganov/llama.cpp/blob/master/README.md#prepare-and-quantize) on conversion. After conversion, see [Import GGUF](#import-gguf).
Finally, create a model from your `Modelfile`:
## Automatic Quantization
> [!NOTE]
> Automatic quantization requires v0.1.35 or higher.
Ollama is capable of quantizing FP16 or FP32 models to any of the supported quantizations with the `-q/--quantize` flag in `ollama create`.
```dockerfile
FROM /path/to/my/gemma/f16/model
```
ollama create example -f Modelfile
```
### Step 3: Run your model
Next, test the model with `ollama run`:
```
ollama run example "What is your favourite condiment?"
```
## Importing (PyTorch & Safetensors)
> Importing from PyTorch and Safetensors is a longer process than importing from GGUF. Improvements that make it easier are a work in progress.
### Setup
First, clone the `ollama/ollama` repo:
```
git clone git@github.com:ollama/ollama.git ollama
cd ollama
```
and then fetch its `llama.cpp` submodule:
```shell
git submodule init
git submodule update llm/llama.cpp
$ ollama create -q Q4_K_M mymodel
transferring model data
quantizing F16 model to Q4_K_M
creating new layer sha256:735e246cc1abfd06e9cdcf95504d6789a6cd1ad7577108a70d9902fef503c1bd
creating new layer sha256:0853f0ad24e5865173bbf9ffcc7b0f5d56b66fd690ab1009867e45e7d2c4db0f
writing manifest
success
```
Next, install the Python dependencies:
### Supported Quantizations
```
python3 -m venv llm/llama.cpp/.venv
source llm/llama.cpp/.venv/bin/activate
pip install -r llm/llama.cpp/requirements.txt
<details>
<summary>Legacy Quantization</summary>
- `Q4_0`
- `Q4_1`
- `Q5_0`
- `Q5_1`
- `Q8_0`
</details>
<details>
<summary>K-means Quantization</summary>`
- `Q3_K_S`
- `Q3_K_M`
- `Q3_K_L`
- `Q4_K_S`
- `Q4_K_M`
- `Q5_K_S`
- `Q5_K_M`
- `Q6_K`
</details>
> [!NOTE]
> Activation-aware Weight Quantization (i.e. IQ) are not currently supported for automatic quantization however you can still import the quantized model into Ollama, see [Import GGUF](#import-gguf).
## Template Detection
> [!NOTE]
> Template detection requires v0.1.42 or higher.
Ollama uses model metadata, specifically `tokenizer.chat_template`, to automatically create a template appropriate for the model you're importing.
```dockerfile
FROM /path/to/my/gemma/model
```
Then build the `quantize` tool:
```
make -C llm/llama.cpp quantize
```shell
$ ollama create mymodel
transferring model data
using autodetected template gemma-instruct
creating new layer sha256:baa2a0edc27d19cc6b7537578a9a7ba1a4e3214dc185ed5ae43692b319af7b84
creating new layer sha256:ba66c3309914dbef07e5149a648fd1877f030d337a4f240d444ea335008943cb
writing manifest
success
```
### Clone the HuggingFace repository (optional)
If the model is currently hosted in a HuggingFace repository, first clone that repository to download the raw model.
Install [Git LFS](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage), verify it's installed, and then clone the model's repository:
```
git lfs install
git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 model
```
### Convert the model
> Note: some model architectures require using specific convert scripts. For example, Qwen models require running `convert-hf-to-gguf.py` instead of `convert.py`
```
python llm/llama.cpp/convert.py ./model --outtype f16 --outfile converted.bin
```
### Quantize the model
```
llm/llama.cpp/quantize converted.bin quantized.bin q4_0
```
### Step 3: Write a `Modelfile`
Next, create a `Modelfile` for your model:
```
FROM quantized.bin
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
```
### Step 4: Create the Ollama model
Finally, create a model from your `Modelfile`:
```
ollama create example -f Modelfile
```
### Step 5: Run your model
Next, test the model with `ollama run`:
```
ollama run example "What is your favourite condiment?"
```
## Publishing your model (optional early alpha)
Publishing models is in early alpha. If you'd like to publish your model to share with others, follow these steps:
1. Create [an account](https://ollama.com/signup)
2. Copy your Ollama public key:
- macOS: `cat ~/.ollama/id_ed25519.pub | pbcopy`
- Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
- Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
Next, copy your model to your username's namespace:
```
ollama cp example <your username>/example
```
> Note: model names may only contain lowercase letters, digits, and the characters `.`, `-`, and `_`.
Then push the model:
```
ollama push <your username>/example
```
After publishing, your model will be available at `https://ollama.com/<your username>/example`.
## Quantization reference
The quantization options are as follow (from highest highest to lowest levels of quantization). Note: some architectures such as Falcon do not support K quants.
- `q2_K`
- `q3_K`
- `q3_K_S`
- `q3_K_M`
- `q3_K_L`
- `q4_0` (recommended)
- `q4_1`
- `q4_K`
- `q4_K_S`
- `q4_K_M`
- `q5_0`
- `q5_1`
- `q5_K`
- `q5_K_S`
- `q5_K_M`
- `q6_K`
- `q8_0`
- `f16`
Defining a template in the Modelfile will disable this feature which may be useful if you want to use a different template than the autodetected one.

View File

@@ -100,6 +100,16 @@ sudo curl -L https://ollama.com/download/ollama-linux-amd64 -o /usr/bin/ollama
sudo chmod +x /usr/bin/ollama
```
## Installing specific versions
Use `OLLAMA_VERSION` environment variable with the install script to install a specific version of Ollama, including pre-releases. You can find the version numbers in the [releases page](https://github.com/ollama/ollama/releases).
For example:
```
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.1.32 sh
```
## Viewing logs
To view logs of Ollama running as a startup service, run:

View File

@@ -76,6 +76,7 @@ Make sure you've set up the container runtime first as described in [docker.md](
Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
- Is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama wont be able to see your NVIDIA GPU.
- Is the uvm driver not loaded? `sudo nvidia-modprobe -u`
- Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
- Try rebooting

View File

@@ -45,7 +45,7 @@ all_splits = text_splitter.split_documents(data)
```
It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb`
We also need to pull embedding model: `ollama pull nomic-embed-text`
```python
from langchain.embeddings import OllamaEmbeddings
from langchain.vectorstores import Chroma
@@ -68,7 +68,8 @@ The next thing is to send the question and the relevant parts of the docs to the
```python
from langchain.chains import RetrievalQA
qachain=RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever())
qachain.invoke({"query": question})
res = qachain.invoke({"query": question})
print(res['result'])
```
The answer received from this chain was:

View File

@@ -1,8 +1,10 @@
package envconfig
import (
"errors"
"fmt"
"log/slog"
"net"
"os"
"path/filepath"
"runtime"
@@ -10,11 +12,27 @@ import (
"strings"
)
type OllamaHost struct {
Scheme string
Host string
Port string
}
func (o OllamaHost) String() string {
return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port)
}
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
var (
// Set via OLLAMA_ORIGINS in the environment
AllowOrigins []string
// Set via OLLAMA_DEBUG in the environment
Debug bool
// Experimental flash attention
FlashAttention bool
// Set via OLLAMA_KEEP_ALIVE in the environment
KeepAlive string
// Set via OLLAMA_LLM_LIBRARY in the environment
LLMLibrary string
// Set via OLLAMA_MAX_LOADED_MODELS in the environment
@@ -23,34 +41,54 @@ var (
MaxQueuedRequests int
// Set via OLLAMA_MAX_VRAM in the environment
MaxVRAM uint64
// Set via OLLAMA_NOHISTORY in the environment
NoHistory bool
// Set via OLLAMA_NOPRUNE in the environment
NoPrune bool
// Set via OLLAMA_NUM_PARALLEL in the environment
NumParallel int
// Set via OLLAMA_HOST in the environment
Host *OllamaHost
// Set via OLLAMA_RUNNERS_DIR in the environment
RunnersDir string
// Set via OLLAMA_TMPDIR in the environment
TmpDir string
// Experimental flash attention
FlashAttention bool
)
func AsMap() map[string]string {
return map[string]string{
"OLLAMA_ORIGINS": fmt.Sprintf("%v", AllowOrigins),
"OLLAMA_DEBUG": fmt.Sprintf("%v", Debug),
"OLLAMA_LLM_LIBRARY": fmt.Sprintf("%v", LLMLibrary),
"OLLAMA_MAX_LOADED_MODELS": fmt.Sprintf("%v", MaxRunners),
"OLLAMA_MAX_QUEUE": fmt.Sprintf("%v", MaxQueuedRequests),
"OLLAMA_MAX_VRAM": fmt.Sprintf("%v", MaxVRAM),
"OLLAMA_NOPRUNE": fmt.Sprintf("%v", NoPrune),
"OLLAMA_NUM_PARALLEL": fmt.Sprintf("%v", NumParallel),
"OLLAMA_RUNNERS_DIR": fmt.Sprintf("%v", RunnersDir),
"OLLAMA_TMPDIR": fmt.Sprintf("%v", TmpDir),
"OLLAMA_FLASH_ATTENTION": fmt.Sprintf("%v", FlashAttention),
type EnvVar struct {
Name string
Value any
Description string
}
func AsMap() map[string]EnvVar {
return map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", "", "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
}
}
func Values() map[string]string {
vals := make(map[string]string)
for k, v := range AsMap() {
vals[k] = fmt.Sprintf("%v", v.Value)
}
return vals
}
var defaultAllowOrigins = []string{
"localhost",
"127.0.0.1",
@@ -104,7 +142,7 @@ func LoadConfig() {
var paths []string
for _, root := range []string{filepath.Dir(appExe), cwd} {
paths = append(paths,
filepath.Join(root),
root,
filepath.Join(root, "windows-"+runtime.GOARCH),
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
)
@@ -147,6 +185,10 @@ func LoadConfig() {
}
}
if nohistory := clean("OLLAMA_NOHISTORY"); nohistory != "" {
NoHistory = true
}
if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
NoPrune = true
}
@@ -158,11 +200,17 @@ func LoadConfig() {
AllowOrigins = append(AllowOrigins,
fmt.Sprintf("http://%s", allowOrigin),
fmt.Sprintf("https://%s", allowOrigin),
fmt.Sprintf("http://%s:*", allowOrigin),
fmt.Sprintf("https://%s:*", allowOrigin),
fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")),
fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")),
)
}
AllowOrigins = append(AllowOrigins,
"app://*",
"file://*",
"tauri://*",
)
maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
if maxRunners != "" {
m, err := strconv.Atoi(maxRunners)
@@ -181,4 +229,56 @@ func LoadConfig() {
MaxQueuedRequests = p
}
}
KeepAlive = clean("OLLAMA_KEEP_ALIVE")
var err error
Host, err = getOllamaHost()
if err != nil {
slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port)
}
}
func getOllamaHost() (*OllamaHost, error) {
defaultPort := "11434"
hostVar := os.Getenv("OLLAMA_HOST")
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
scheme, hostport, ok := strings.Cut(hostVar, "://")
switch {
case !ok:
scheme, hostport = "http", hostVar
case scheme == "http":
defaultPort = "80"
case scheme == "https":
defaultPort = "443"
}
// trim trailing slashes
hostport = strings.TrimRight(hostport, "/")
host, port, err := net.SplitHostPort(hostport)
if err != nil {
host, port = "127.0.0.1", defaultPort
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
host = ip.String()
} else if hostport != "" {
host = hostport
}
}
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
return &OllamaHost{
Scheme: scheme,
Host: host,
Port: defaultPort,
}, ErrInvalidHostPort
}
return &OllamaHost{
Scheme: scheme,
Host: host,
Port: port,
}, nil
}

71
envconfig/config_test.go Normal file
View File

@@ -0,0 +1,71 @@
package envconfig
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConfig(t *testing.T) {
Debug = false // Reset whatever was loaded in init()
t.Setenv("OLLAMA_DEBUG", "")
LoadConfig()
require.False(t, Debug)
t.Setenv("OLLAMA_DEBUG", "false")
LoadConfig()
require.False(t, Debug)
t.Setenv("OLLAMA_DEBUG", "1")
LoadConfig()
require.True(t, Debug)
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig()
require.True(t, FlashAttention)
}
func TestClientFromEnvironment(t *testing.T) {
type testCase struct {
value string
expect string
err error
}
hostTestCases := map[string]*testCase{
"empty": {value: "", expect: "127.0.0.1:11434"},
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
"only port": {value: ":1234", expect: ":1234"},
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
"hostname": {value: "example.com", expect: "example.com:11434"},
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
"zero port": {value: ":0", expect: ":0"},
"too large port": {value: ":66000", err: ErrInvalidHostPort},
"too small port": {value: ":-1", err: ErrInvalidHostPort},
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
}
for k, v := range hostTestCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
LoadConfig()
oh, err := getOllamaHost()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
}
if err == nil {
host := net.JoinHostPort(oh.Host, oh.Port)
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
}
})
}
}

View File

@@ -77,13 +77,21 @@ LOADER_MAPPING = {
def load_single_document(file_path: str) -> List[Document]:
ext = "." + file_path.rsplit(".", 1)[-1]
if os.path.getsize(file_path) != 0:
filename, ext = os.path.splitext(file_path)
if ext in LOADER_MAPPING:
loader_class, loader_args = LOADER_MAPPING[ext]
try:
loader = loader_class(file_path, **loader_args)
if loader:
return loader.load()
except:
print(f"Corrupted file {file_path}. Ignoring it.")
else:
print(f"Unsupported file {file_path}. Ignoring it.")
else:
print(f"Empty file {file_path}. Ignoring it.")
raise ValueError(f"Unsupported file extension '{ext}'")
def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Document]:
"""
@@ -100,6 +108,7 @@ def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Docum
results = []
with tqdm(total=len(filtered_files), desc='Loading new documents', ncols=80) as pbar:
for i, docs in enumerate(pool.imap_unordered(load_single_document, filtered_files)):
if docs:
results.extend(docs)
pbar.update()

View File

@@ -12,3 +12,4 @@ pandoc==2.3
pypandoc==1.11
tqdm==4.66.1
sentence_transformers==2.2.2
numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability

View File

@@ -9,6 +9,7 @@ def chat(messages):
r = requests.post(
"http://0.0.0.0:11434/api/chat",
json={"model": model, "messages": messages, "stream": True},
stream=True
)
r.raise_for_status()
output = ""

View File

@@ -5,7 +5,6 @@ import (
)
func TestHumanNumber(t *testing.T) {
type testCase struct {
input uint64
expected string

1
go.mod
View File

@@ -16,6 +16,7 @@ require (
)
require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0

6
go.sum
View File

@@ -4,10 +4,14 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7
gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8=
github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo=
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6ICHXqG5hm0ZW5IHyeEJXoIJSOZeBLmWPNeIQ=
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
@@ -36,6 +40,8 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=

View File

@@ -7,7 +7,7 @@ import (
"os"
"path/filepath"
"slices"
"strings"
// "strings"
"github.com/ollama/ollama/format"
)
@@ -65,7 +65,7 @@ func AMDGetGPUInfo() []GpuInfo {
slog.Debug("detected hip devices", "count", count)
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
for i := 0; i < count; i++ {
for i := range count {
err = hl.HipSetDevice(i)
if err != nil {
slog.Warn("set device", "id", i, "error", err)
@@ -108,10 +108,10 @@ func AMDGetGPUInfo() []GpuInfo {
}
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
//if totalMemory < IGPUMemLimit {
// slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
// continue
//}
if totalMemory < IGPUMemLimit {
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
continue
}
// TODO revisit this once ROCm v6 is available on windows.
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable

View File

@@ -13,7 +13,7 @@ import (
"syscall"
"time"
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/envconfig"
)
var (
@@ -80,7 +80,7 @@ func cleanupTmpDirs() {
if err == nil {
pid, err := strconv.Atoi(string(raw))
if err == nil {
if proc, err := os.FindProcess(int(pid)); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
// Another running ollama, ignore this tmpdir
continue
}

View File

@@ -18,5 +18,4 @@ func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids = append(ids, info.ID)
}
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
}

View File

@@ -20,14 +20,15 @@ import (
"sync"
"unsafe"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/server/envconfig"
)
type handles struct {
deviceCount int
cudart *C.cudart_handle_t
nvcuda *C.nvcuda_handle_t
oneapi *C.oneapi_handle_t
}
const (
@@ -80,6 +81,15 @@ var NvcudaWindowsGlobs = []string{
"c:\\windows\\system*\\nvcuda.dll",
}
var OneapiWindowsGlobs = []string{
"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
}
var OneapiLinuxGlobs = []string{
"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
"/usr/lib*/libze_intel_gpu.so*",
}
// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK")
@@ -141,6 +151,7 @@ func initGPUHandles() *handles {
return gpuHandles
}
}
return gpuHandles
}
@@ -176,11 +187,12 @@ func GetGPUInfo() GpuInfoList {
resp := []GpuInfo{}
// NVIDIA first
for i := 0; i < gpuHandles.deviceCount; i++ {
for i := range gpuHandles.deviceCount {
// TODO once we support CPU compilation variants of GPU libraries refine this...
if cpuVariant == "" && runtime.GOARCH == "amd64" {
continue
}
if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil {
gpuInfo := GpuInfo{
Library: "cuda",
}
@@ -209,12 +221,13 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.MinimumMemory = cudaMinimumMemory
gpuInfo.DependencyPath = depPath
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
gpuInfo.DriverMajor = int(driverMajor)
gpuInfo.DriverMinor = int(driverMinor)
gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
resp = append(resp, gpuInfo)
}
}
// Then AMD
resp = append(resp, AMDGetGPUInfo()...)
@@ -348,6 +361,23 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
return 0, nil, ""
}
func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
var resp C.oneapi_init_resp_t
resp.oh.verbose = getVerboseState()
for _, libPath := range oneapiLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.oneapi_init(lib, &resp)
if resp.err != nil {
slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
} else {
return int(resp.num_devices), &resp.oh, libPath
}
}
return 0, nil, ""
}
func getVerboseState() C.uint16_t {
if envconfig.Debug {
return C.uint16_t(1)
@@ -368,6 +398,8 @@ func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
return cudaGetVisibleDevicesEnv(l)
case "rocm":
return rocmGetVisibleDevicesEnv(l)
case "oneapi":
return oneapiGetVisibleDevicesEnv(l)
default:
slog.Debug("no filter required for library " + l[0].Library)
return "", ""

View File

@@ -62,6 +62,7 @@ void cpu_check_ram(mem_info_t *resp);
#include "gpu_info_cudart.h"
#include "gpu_info_nvcuda.h"
#include "gpu_info_oneapi.h"
#endif // __GPU_INFO_H__
#endif // __APPLE__

214
gpu/gpu_info_oneapi.c Normal file
View File

@@ -0,0 +1,214 @@
#ifndef __APPLE__
#include "gpu_info_oneapi.h"
#include <string.h>
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
{
ze_result_t ret;
resp->err = NULL;
const int buflen = 256;
char buf[buflen + 1];
int i;
struct lookup
{
char *s;
void **p;
} l[] = {
{"zesInit", (void *)&resp->oh.zesInit},
{"zesDriverGet", (void *)&resp->oh.zesDriverGet},
{"zesDeviceGet", (void *)&resp->oh.zesDeviceGet},
{"zesDeviceGetProperties", (void *)&resp->oh.zesDeviceGetProperties},
{"zesDeviceEnumMemoryModules",
(void *)&resp->oh.zesDeviceEnumMemoryModules},
{"zesMemoryGetProperties", (void *)&resp->oh.zesMemoryGetProperties},
{"zesMemoryGetState", (void *)&resp->oh.zesMemoryGetState},
{NULL, NULL},
};
resp->oh.handle = LOAD_LIBRARY(oneapi_lib_path, RTLD_LAZY);
if (!resp->oh.handle)
{
char *msg = LOAD_ERR();
snprintf(buf, buflen,
"Unable to load %s library to query for Intel GPUs: %s\n",
oneapi_lib_path, msg);
free(msg);
resp->err = strdup(buf);
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->oh.verbose,
"wiring Level-Zero management library functions in %s\n",
oneapi_lib_path);
for (i = 0; l[i].s != NULL; i++)
{
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s);
if (!l[i].p)
{
resp->oh.handle = NULL;
char *msg = LOAD_ERR();
LOG(resp->oh.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->oh.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, msg);
free(msg);
resp->err = strdup(buf);
return;
}
}
ret = (*resp->oh.zesInit)(0);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesInit err: %d\n", ret);
UNLOAD_LIBRARY(resp->oh.handle);
resp->oh.handle = NULL;
snprintf(buf, buflen, "oneapi vram init failure: %d", ret);
resp->err = strdup(buf);
}
(*resp->oh.zesDriverGet)(&resp->num_devices, NULL);
return;
}
void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
{
ze_result_t ret;
resp->err = NULL;
uint64_t totalMem = 0;
uint64_t usedMem = 0;
const int buflen = 256;
char buf[buflen + 1];
int i, d, m;
if (h.handle == NULL)
{
resp->err = strdup("Level-Zero handle not initialized");
return;
}
uint32_t driversCount = 0;
ret = (*h.zesDriverGet)(&driversCount, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get driver count: %d", ret);
resp->err = strdup(buf);
return;
}
LOG(h.verbose, "discovered %d Level-Zero drivers\n", driversCount);
zes_driver_handle_t *allDrivers =
malloc(driversCount * sizeof(zes_driver_handle_t));
(*h.zesDriverGet)(&driversCount, allDrivers);
resp->total = 0;
resp->free = 0;
for (d = 0; d < driversCount; d++)
{
uint32_t deviceCount = 0;
ret = (*h.zesDeviceGet)(allDrivers[d], &deviceCount, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
return;
}
LOG(h.verbose, "discovered %d Level-Zero devices\n", deviceCount);
zes_device_handle_t *devices =
malloc(deviceCount * sizeof(zes_device_handle_t));
(*h.zesDeviceGet)(allDrivers[d], &deviceCount, devices);
for (i = 0; i < deviceCount; i++)
{
zes_device_ext_properties_t ext_props;
ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
ext_props.pNext = NULL;
zes_device_properties_t props;
props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
props.pNext = &ext_props;
ret = (*h.zesDeviceGetProperties)(devices[i], &props);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get device properties: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
free(devices);
return;
}
if (h.verbose)
{
// When in verbose mode, report more information about
// the card we discover.
LOG(h.verbose, "[%d] oneAPI device name: %s\n", i,
props.modelName);
LOG(h.verbose, "[%d] oneAPI brand: %s\n", i,
props.brandName);
LOG(h.verbose, "[%d] oneAPI vendor: %s\n", i,
props.vendorName);
LOG(h.verbose, "[%d] oneAPI S/N: %s\n", i,
props.serialNumber);
LOG(h.verbose, "[%d] oneAPI board number: %s\n", i,
props.boardNumber);
}
uint32_t memCount = 0;
ret = (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen,
"unable to enumerate Level-Zero memory modules: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
free(devices);
return;
}
LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
(*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, mems);
for (m = 0; m < memCount; m++)
{
zes_mem_state_t state;
state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
state.pNext = NULL;
ret = (*h.zesMemoryGetState)(mems[m], &state);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get memory state: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
free(devices);
free(mems);
return;
}
resp->total += state.size;
resp->free += state.free;
}
free(mems);
}
free(devices);
}
free(allDrivers);
}
#endif // __APPLE__

211
gpu/gpu_info_oneapi.h Normal file
View File

@@ -0,0 +1,211 @@
#ifndef __APPLE__
#ifndef __GPU_INFO_ONEAPI_H__
#define __GPU_INFO_ONEAPI_H__
#include "gpu_info.h"
#define ZE_MAX_DEVICE_NAME 256
#define ZE_MAX_DEVICE_UUID_SIZE 16
#define ZES_STRING_PROPERTY_SIZE 64
#define ZE_BIT(_i) (1 << _i)
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum ze_result_t
{
ZE_RESULT_SUCCESS = 0,
// Other values omitted for now...
} ze_result_t;
typedef uint8_t ze_bool_t;
typedef struct _zes_driver_handle_t *zes_driver_handle_t;
typedef struct _zes_device_handle_t *zes_device_handle_t;
typedef struct _zes_mem_handle_t *zes_mem_handle_t;
typedef enum _ze_structure_type_t
{
ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
} ze_structure_type_t;
typedef enum _zes_structure_type_t
{
ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1,
ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb,
ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e,
ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES = 0x2d,
ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
} zes_structure_type_t;
typedef enum _zes_mem_type_t
{
ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff
} zes_mem_type_t;
typedef enum _zes_mem_loc_t
{
ZES_MEM_LOC_SYSTEM = 0,
ZES_MEM_LOC_DEVICE = 1,
ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff
} zes_mem_loc_t;
typedef enum _zes_mem_health_t
{
ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff
} zes_mem_health_t;
typedef struct _ze_device_uuid_t
{
uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
} ze_device_uuid_t;
typedef struct _zes_uuid_t
{
uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
} zes_uuid_t;
typedef enum _ze_device_type_t
{
ZE_DEVICE_TYPE_GPU = 1,
ZE_DEVICE_TYPE_CPU = 2,
ZE_DEVICE_TYPE_FPGA = 3,
ZE_DEVICE_TYPE_MCA = 4,
ZE_DEVICE_TYPE_VPU = 5,
ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
} ze_device_type_t;
typedef enum _zes_device_type_t
{
ZES_DEVICE_TYPE_GPU = 1,
ZES_DEVICE_TYPE_CPU = 2,
ZES_DEVICE_TYPE_FPGA = 3,
ZES_DEVICE_TYPE_MCA = 4,
ZES_DEVICE_TYPE_VPU = 5,
ZES_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
} zes_device_type_t;
typedef uint32_t ze_device_property_flags_t;
typedef enum _ze_device_property_flag_t
{
ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
ZE_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3),
ZE_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
} ze_device_property_flag_t;
typedef uint32_t zes_device_property_flags_t;
typedef enum _zes_device_property_flag_t
{
ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
ZES_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3),
ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
} zes_device_property_flag_t;
typedef struct _ze_device_properties_t
{
ze_structure_type_t stype;
void *pNext;
ze_device_type_t type;
uint32_t vendorId;
uint32_t deviceId;
ze_device_property_flags_t flags;
uint32_t subdeviceId;
uint32_t coreClockRate;
uint64_t maxMemAllocSize;
uint32_t maxHardwareContexts;
uint32_t maxCommandQueuePriority;
uint32_t numThreadsPerEU;
uint32_t physicalEUSimdWidth;
uint32_t numEUsPerSubslice;
uint32_t numSubslicesPerSlice;
uint32_t numSlices;
uint64_t timerResolution;
uint32_t timestampValidBits;
uint32_t kernelTimestampValidBits;
ze_device_uuid_t uuid;
char name[ZE_MAX_DEVICE_NAME];
} ze_device_properties_t;
typedef struct _zes_device_properties_t
{
zes_structure_type_t stype;
void *pNext;
ze_device_properties_t core;
uint32_t numSubdevices;
char serialNumber[ZES_STRING_PROPERTY_SIZE];
char boardNumber[ZES_STRING_PROPERTY_SIZE];
char brandName[ZES_STRING_PROPERTY_SIZE];
char modelName[ZES_STRING_PROPERTY_SIZE];
char vendorName[ZES_STRING_PROPERTY_SIZE];
char driverVersion[ZES_STRING_PROPERTY_SIZE];
} zes_device_properties_t;
typedef struct _zes_device_ext_properties_t
{
zes_structure_type_t stype;
void *pNext;
zes_uuid_t uuid;
zes_device_type_t type;
zes_device_property_flags_t flags;
} zes_device_ext_properties_t;
typedef struct _zes_mem_properties_t
{
zes_structure_type_t stype;
void *pNext;
zes_mem_type_t type;
ze_bool_t onSubdevice;
uint32_t subdeviceId;
zes_mem_loc_t location;
uint64_t physicalSize;
int32_t busWidth;
int32_t numChannels;
} zes_mem_properties_t;
typedef struct _zes_mem_state_t
{
zes_structure_type_t stype;
const void *pNext;
zes_mem_health_t health;
uint64_t free;
uint64_t size;
} zes_mem_state_t;
typedef struct oneapi_handle
{
void *handle;
uint16_t verbose;
ze_result_t (*zesInit)(int);
ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
zes_device_handle_t *phDevices);
ze_result_t (*zesDeviceGetProperties)(zes_device_handle_t hDevice,
zes_device_properties_t *pProperties);
ze_result_t (*zesDeviceEnumMemoryModules)(zes_device_handle_t hDevice,
uint32_t *pCount,
zes_mem_handle_t *phMemory);
ze_result_t (*zesMemoryGetProperties)(zes_mem_handle_t hMemory,
zes_mem_properties_t *pProperties);
ze_result_t (*zesMemoryGetState)(zes_mem_handle_t hMemory,
zes_mem_state_t *pState);
} oneapi_handle_t;
typedef struct oneapi_init_resp
{
char *err; // If err is non-null handle is invalid
int num_devices;
oneapi_handle_t oh;
} oneapi_init_resp_t;
typedef struct oneapi_version_resp
{
ze_result_t status;
char *str; // Contains version or error string if status != 0
} oneapi_version_resp_t;
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
void oneapi_check_vram(oneapi_handle_t rh, mem_info_t *resp);
#endif // __GPU_INFO_INTEL_H__
#endif // __APPLE__

21
gpu/gpu_oneapi.go Normal file
View File

@@ -0,0 +1,21 @@
//go:build linux || windows
package gpu
import (
"log/slog"
"strings"
)
func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "oneapi" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
continue
}
ids = append(ids, info.ID)
}
return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
}

View File

@@ -5,11 +5,12 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBasicGetGPUInfo(t *testing.T) {
info := GetGPUInfo()
assert.Greater(t, len(info), 0)
assert.NotEmpty(t, len(info))
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
if info[0].Library != "cpu" {
assert.Greater(t, info[0].TotalMemory, uint64(0))
@@ -19,7 +20,7 @@ func TestBasicGetGPUInfo(t *testing.T) {
func TestCPUMemInfo(t *testing.T) {
info, err := GetCPUMem()
assert.NoError(t, err)
require.NoError(t, err)
switch runtime.GOOS {
case "darwin":
t.Skip("CPU memory not populated on darwin")

View File

@@ -140,7 +140,6 @@ struct server_slot {
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
bool infill = false;
bool embedding = false;
bool has_next_token = true;
bool truncated = false;
@@ -187,7 +186,6 @@ struct server_slot {
n_past = 0;
n_sent_text = 0;
n_sent_token_probs = 0;
infill = false;
ga_i = 0;
n_past_se = 0;
@@ -361,7 +359,6 @@ struct llama_server_context
// slots / clients
std::vector<server_slot> slots;
json default_generation_settings_for_props;
llama_server_queue queue_tasks;
llama_server_response queue_results;
@@ -485,9 +482,6 @@ struct llama_server_context
slots.push_back(slot);
}
default_generation_settings_for_props = get_formated_generation(slots.front());
default_generation_settings_for_props["seed"] = -1;
batch = llama_batch_init(n_ctx, 0, params.n_parallel);
}
@@ -586,7 +580,7 @@ struct llama_server_context
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
slot->params.seed = json_value(data, "seed", default_params.seed);
slot->sparams.seed = json_value(data, "seed", default_params.seed);
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
@@ -600,16 +594,6 @@ struct llama_server_context
slot->params.n_predict = slot->n_predict;
}
// infill
if (data.count("input_prefix") != 0)
{
slot->params.input_prefix = data["input_prefix"];
}
else
{
slot->params.input_prefix = "";
}
if (data.count("input_suffix") != 0)
{
slot->params.input_suffix = data["input_suffix"];
@@ -823,7 +807,6 @@ struct llama_server_context
llama_sampling_free(slot->ctx_sampling);
}
slot->ctx_sampling = llama_sampling_init(slot->sparams);
llama_set_rng_seed(ctx, slot->params.seed);
slot->command = LOAD_PROMPT;
all_slots_are_idle = false;
@@ -847,7 +830,7 @@ struct llama_server_context
system_tokens.clear();
if (!system_prompt.empty()) {
system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
system_tokens = ::llama_tokenize(ctx, system_prompt, true);
llama_batch_clear(batch);
@@ -897,15 +880,6 @@ struct llama_server_context
system_need_update = true;
}
void system_prompt_process(const json &sys_props) {
system_prompt = sys_props.value("prompt", "");
name_user = sys_props.value("anti_prompt", "");
name_assistant = sys_props.value("assistant_name", "");
system_prompt_notify();
}
static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
const stop_type type, server_slot &slot)
{
@@ -1263,13 +1237,12 @@ struct llama_server_context
queue_results.send(res);
}
void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
void request_completion(int task_id, json data, bool embedding, int multitask_id)
{
task_server task;
task.id = task_id;
task.target_id = 0;
task.data = std::move(data);
task.infill_mode = infill;
task.embedding_mode = embedding;
task.type = TASK_TYPE_COMPLETION;
task.multitask_id = multitask_id;
@@ -1415,8 +1388,8 @@ struct llama_server_context
json subtask_data = multiprompt_task.data;
subtask_data["prompt"] = subtask_data["prompt"][i];
// subtasks inherit everything else (infill mode, embedding mode, etc.)
request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
// subtasks inherit everything else (embedding mode, etc.)
request_completion(subtask_ids[i], subtask_data, multiprompt_task.embedding_mode, multitask_id);
}
}
@@ -1434,26 +1407,8 @@ struct llama_server_context
break;
}
if (task.data.contains("system_prompt"))
{
if (!all_slots_are_idle) {
send_error(task, "system prompt can only be updated when all slots are idle");
break;
}
system_prompt_process(task.data["system_prompt"]);
// reset cache_tokens for all slots
for (server_slot &slot : slots)
{
slot.cache_tokens.clear();
slot.n_past = 0;
slot.n_past_se = 0;
}
}
slot->reset();
slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode;
slot->task_id = task.id;
slot->multitask_id = task.multitask_id;
@@ -1679,8 +1634,7 @@ struct llama_server_context
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();
// empty prompt passed -> release the slot and send empty response
// note: infill mode allows empty prompt
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill)
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt)
{
slot.release();
slot.print_timings();
@@ -1697,33 +1651,7 @@ struct llama_server_context
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_genereration = 0;
if (slot.infill)
{
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1)
{
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
const int space_token = 29871; // TODO: this should not be hardcoded
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens;
}
else
{
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
}
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
slot.n_prompt_tokens = prompt_tokens.size();
@@ -2130,8 +2058,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf("\n");
}
static void server_params_parse(int argc, char **argv, server_params &sparams,
gpt_params &params, llama_server_context& llama)
static void server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
{
gpt_params default_params;
server_params default_sparams;
@@ -2546,27 +2473,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_predict = std::stoi(argv[i]);
}
else if (arg == "-spf" || arg == "--system-prompt-file")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
std::string systm_content;
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(systm_content)
);
llama.system_prompt_process(json::parse(systm_content));
}
else if (arg == "-ctk" || arg == "--cache-type-k") {
params.cache_type_k = argv[++i];
}
@@ -2818,7 +2724,7 @@ int main(int argc, char **argv) {
// struct that contains llama context and inference
llama_server_context llama;
server_params_parse(argc, argv, sparams, params, llama);
server_params_parse(argc, argv, sparams, params);
if (params.model_alias == "unknown")
{
@@ -3150,7 +3056,7 @@ int main(int argc, char **argv) {
json data = json::parse(req.body);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);
llama.request_completion(task_id, data, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
@@ -3272,7 +3178,7 @@ int main(int argc, char **argv) {
// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
// get the result
task_result result = llama.queue_results.recv(task_id);

View File

@@ -32,7 +32,7 @@ case "${GOARCH}" in
echo "Building static library"
build
if [ -z "$OLLAMA_SKIP_CPU_GENERATE" ]; then
#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
@@ -68,6 +68,7 @@ case "${GOARCH}" in
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
fi
;;
"arm64")
@@ -79,6 +80,7 @@ case "${GOARCH}" in
echo "Building static library"
build
if [ -z "$OLLAMA_SKIP_METAL_GENERATE" ]; then
init_vars
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/metal"
@@ -86,6 +88,7 @@ case "${GOARCH}" in
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
fi
;;
*)
echo "GOARCH must be set"

View File

@@ -215,6 +215,36 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
fi
if [ -z "${ONEAPI_ROOT}" ]; then
# Try the default location in case it exists
ONEAPI_ROOT=/opt/intel/oneapi
fi
if [ -z "${OLLAMA_SKIP_ONEAPI_GENERATE}" -a -d "${ONEAPI_ROOT}" ]; then
echo "OneAPI libraries detected - building dynamic OneAPI library"
init_vars
source ${ONEAPI_ROOT}/setvars.sh --force # set up environment variables for oneAPI
CC=icx
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL=ON -DLLAMA_SYCL_F16=OFF"
BUILD_DIR="../build/linux/${ARCH}/oneapi"
EXTRA_LIBS="-fsycl -Wl,-rpath,${ONEAPI_ROOT}/compiler/latest/lib,-rpath,${ONEAPI_ROOT}/mkl/latest/lib,-rpath,${ONEAPI_ROOT}/tbb/latest/lib,-rpath,${ONEAPI_ROOT}/compiler/latest/opt/oclfpga/linux64/lib -lOpenCL -lmkl_core -lmkl_sycl_blas -lmkl_intel_ilp64 -lmkl_tbb_thread -ltbb"
DEBUG_FLAGS="" # icx compiles with -O0 if we pass -g, so we must remove it
build
# copy oneAPI dependencies
for dep in $(ldd "${BUILD_DIR}/bin/ollama_llama_server" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e sycl -e mkl -e tbb); do
cp "${dep}" "${BUILD_DIR}/bin/"
done
cp "${ONEAPI_ROOT}/compiler/latest/lib/libOpenCL.so" "${BUILD_DIR}/bin/"
cp "${ONEAPI_ROOT}/compiler/latest/lib/libimf.so" "${BUILD_DIR}/bin/"
cp "${ONEAPI_ROOT}/compiler/latest/lib/libintlc.so.5" "${BUILD_DIR}/bin/"
cp "${ONEAPI_ROOT}/compiler/latest/lib/libirng.so" "${BUILD_DIR}/bin/"
cp "${ONEAPI_ROOT}/compiler/latest/lib/libpi_level_zero.so" "${BUILD_DIR}/bin/"
cp "${ONEAPI_ROOT}/compiler/latest/lib/libsvml.so" "${BUILD_DIR}/bin/"
cp "${ONEAPI_ROOT}/compiler/latest/lib/libur_loader.so.0" "${BUILD_DIR}/bin/"
compress
fi
if [ -z "${ROCM_PATH}" ]; then
# Try the default location in case it exists
ROCM_PATH=/opt/rocm

View File

@@ -12,6 +12,7 @@ function amdGPUs {
"gfx900"
"gfx902"
"gfx904"
"gfx90c"
"gfx906:xnack-"
"gfx908:xnack-"
"gfx90a:xnack+"
@@ -25,6 +26,7 @@ function amdGPUs {
"gfx1030"
"gfx1031"
"gfx1032"
"gfx1033"
"gfx1034"
"gfx1035"
"gfx1036"
@@ -299,6 +301,49 @@ function build_cuda() {
}
}
function build_oneapi() {
if ((-not "${env:OLLAMA_SKIP_ONEAPI_GENERATE}") -and ("${env:ONEAPI_ROOT}")) {
# Get oneAPI version
$script:ONEAPI_VERSION = icpx --version
$script:ONEAPI_VERSION = [regex]::Match($script:ONEAPI_VERSION, '(?<=oneAPI DPC\+\+/C\+\+ Compiler )(?<version>\d+\.\d+\.\d+)').Value
if ($null -ne $script:ONEAPI_VERSION) {
$script:ONEAPI_VARIANT = "_v" + $script:ONEAPI_VERSION
}
init_vars
$script:buildDir = "../build/windows/${script:ARCH}/oneapi$script:ONEAPI_VARIANT"
$script:distDir ="$script:DIST_BASE\oneapi$script:ONEAPI_VARIANT"
$script:cmakeDefs += @(
"-G", "MinGW Makefiles",
"-DLLAMA_SYCL=ON",
"-DCMAKE_C_COMPILER=icx",
"-DCMAKE_CXX_COMPILER=icx",
"-DCMAKE_BUILD_TYPE=Release"
)
Write-Host "Building oneAPI"
build
# Ninja doesn't prefix with config name
if ($null -ne $script:DUMPBIN) {
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | Select-String ".dll"
}
sign
install
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libirngmd.dll" "${script:distDir}"
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libmmd.dll" "${script:distDir}"
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_level_zero.dll" "${script:distDir}"
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_unified_runtime.dll" "${script:distDir}"
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_win_proxy_loader.dll" "${script:distDir}"
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\svml_dispmd.dll" "${script:distDir}"
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\sycl7.dll" "${script:distDir}"
cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_core.2.dll" "${script:distDir}"
cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_sycl_blas.4.dll" "${script:distDir}"
cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_tbb_thread.2.dll" "${script:distDir}"
} else {
Write-Host "Skipping oneAPI generation step"
}
}
function build_rocm() {
if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
@@ -366,6 +411,7 @@ if ($($args.count) -eq 0) {
build_cpu_avx
build_cpu_avx2
build_cuda
build_oneapi
build_rocm
}

View File

@@ -81,6 +81,11 @@ func (kv KV) ContextLength() uint64 {
return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
}
func (kv KV) ChatTemplate() string {
s, _ := kv["tokenizer.chat_template"].(string)
return s
}
type Tensors []*Tensor
func (ts Tensors) Layers() map[string]Layer {
@@ -125,9 +130,9 @@ type Tensor struct {
func (t Tensor) blockSize() uint64 {
switch t.Kind {
case 0, 1, 24, 25, 26, 27, 28, 31: // F32, F16, I8, I16, I32, I64, F64, BF16
case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
return 1
case 2, 3, 8, 9, 20: // Q4_0, Q4_1, Q8_0, Q8_1, IQ4_NL
case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
return 32
default: // All others
return 256

View File

@@ -592,8 +592,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
dims := 0
for cnt := 0; cnt < len(tensor.Shape); cnt++ {
var dims int
for cnt := range len(tensor.Shape) {
if tensor.Shape[cnt] > 0 {
dims++
}
@@ -603,8 +603,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
for i := 0; i < dims; i++ {
if err := binary.Write(ws, llm.ByteOrder, uint64(tensor.Shape[dims-1-i])); err != nil {
for i := range dims {
if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil {
return err
}
}
@@ -618,22 +618,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
}
}
offset, err := ws.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
var alignment int64 = 32
padding := llm.padding(offset, alignment)
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
return err
}
for _, tensor := range tensors {
if _, err := tensor.WriteTo(ws); err != nil {
return err
}
offset, err := ws.Seek(0, io.SeekCurrent)
if err != nil {
return err
@@ -643,6 +629,10 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
return err
}
if _, err := tensor.WriteTo(ws); err != nil {
return err
}
}
return nil

View File

@@ -5,9 +5,9 @@ import (
"log/slog"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/server/envconfig"
)
// This algorithm looks for a complete fit to determine if we need to unload other models
@@ -103,7 +103,7 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
}
var layerCount int
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
for i := range int(ggml.KV().BlockCount()) {
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
memoryLayer := blk.size()

View File

@@ -1,35 +1,32 @@
From d02a06f3f45a09255ace8684a66590e06ce44605 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Thu, 23 May 2024 11:33:20 -0700
Subject: [PATCH] default pretokenizer on unrecognized type
---
llama.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/llama.cpp b/llama.cpp
index 15c66077..af1aede3 100644
index 40d2ec2c..74f3ee9c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4504,9 +4504,6 @@ static void llm_load_vocab(
LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
LLAMA_LOG_WARN("%s: \n", __func__);
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
- } else if (
- tokenizer_pre == "default") {
@@ -4642,16 +4642,7 @@ static void llm_load_vocab(
// for now, only BPE models have pre-tokenizers
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
- if (tokenizer_pre.empty()) {
- LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
- LLAMA_LOG_WARN("%s: \n", __func__);
- LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
- LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__);
- LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__);
- LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
- LLAMA_LOG_WARN("%s: \n", __func__);
- vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
- } else if (
+ if (
tokenizer_pre == "default") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if (
tokenizer_pre == "llama3" ||
tokenizer_pre == "llama-v3" ||
@@ -4553,7 +4550,7 @@ static void llm_load_vocab(
tokenizer_pre == "dbrx") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
@@ -4703,7 +4694,8 @@ static void llm_load_vocab(
tokenizer_pre == "smaug-bpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
} else {
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
}
} else {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
--
2.45.1

13
llm/patches/06-qwen2.diff Normal file
View File

@@ -0,0 +1,13 @@
diff --git a/llama.cpp b/llama.cpp
index 40d2ec2c..f34eb79a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

View File

@@ -10,9 +10,9 @@ import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/gpu"

View File

@@ -24,9 +24,9 @@ import (
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/server/envconfig"
)
type LlamaServer interface {
@@ -85,7 +85,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
var systemMemory uint64
gpuCount := len(gpus)
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
cpuRunner = serverForCpu()
@@ -104,21 +103,22 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
var layers int
layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
switch {
case gpus[0].Library == "metal" && estimatedVRAM > systemMemory:
// disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system
opts.NumGPU = 0
} else if gpus[0].Library != "metal" && layers == 0 {
case gpus[0].Library != "metal" && layers == 0:
// Don't bother loading into the GPU if no layers can fit
cpuRunner = serverForCpu()
gpuCount = 0
} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
case opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu":
opts.NumGPU = layers
}
}
// Loop through potential servers
finalErr := fmt.Errorf("no suitable llama servers found")
finalErr := errors.New("no suitable llama servers found")
if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
@@ -189,35 +189,38 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--memory-f32")
}
if opts.UseMLock {
params = append(params, "--mlock")
flashAttnEnabled := envconfig.FlashAttention
for _, g := range gpus {
// only cuda (compute capability 7+) and metal support flash attention
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
}
// mmap has issues with partial offloading on metal
if g.Library == "metal" &&
uint64(opts.NumGPU) > 0 &&
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
opts.UseMMap = false
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
}
if !opts.UseMMap {
params = append(params, "--no-mmap")
}
if opts.UseMLock {
params = append(params, "--mlock")
}
if opts.UseNUMA {
params = append(params, "--numa")
}
flashAttnEnabled := envconfig.FlashAttention
// partial offloading does not support flash attention
if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
flashAttnEnabled = false
}
// only cuda (compute capability 7+) and metal support flash attention
for _, g := range gpus {
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
}
numParallel := envconfig.NumParallel
// TODO (jmorganca): multimodal models don't support parallel yet
@@ -229,7 +232,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
for i := 0; i < len(servers); i++ {
for i := range len(servers) {
dir := availableServers[servers[i]]
if dir == "" {
// Shouldn't happen
@@ -243,7 +246,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
gpuCount = 0
}
// Find an availableServers port, retry on each iterration in case the failure was a port conflict race
// Find an availableServers port, retry on each iteration in case the failure was a port conflict race
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
var l *net.TCPListener
@@ -281,7 +284,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
server := filepath.Join(dir, "ollama_llama_server")
if runtime.GOOS == "windows" {
server = server + ".exe"
server += ".exe"
}
// Detect tmp cleaners wiping out the file
@@ -312,7 +315,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
s.cmd.Stdout = os.Stdout
s.cmd.Stderr = s.status
visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv()
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add the path and visible devices variable with our adjusted version
@@ -456,7 +459,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
resp, err := http.DefaultClient.Do(req)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return ServerStatusNotResponding, fmt.Errorf("server not responding")
return ServerStatusNotResponding, errors.New("server not responding")
}
return ServerStatusError, fmt.Errorf("health resp: %w", err)
}
@@ -519,16 +522,18 @@ func (s *llmServer) Ping(ctx context.Context) error {
func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
start := time.Now()
stallDuration := 60 * time.Second
stallTimer := time.Now().Add(stallDuration) // give up if we stall for
stallDuration := 5 * time.Minute // If no progress happens
finalLoadDuration := 5 * time.Minute // After we hit 100%, give the runner more time to come online
stallTimer := time.Now().Add(stallDuration) // give up if we stall
slog.Info("waiting for llama runner to start responding")
var lastStatus ServerStatus = -1
fullyLoaded := false
for {
select {
case <-ctx.Done():
slog.Info("context expired before server started")
slog.Warn("client connection closed before server finished loading, aborting load")
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
case err := <-s.done:
msg := ""
@@ -572,6 +577,10 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
if priorProgress != s.loadProgress {
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
stallTimer = time.Now().Add(stallDuration)
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
stallTimer = time.Now().Add(finalLoadDuration)
fullyLoaded = true
}
time.Sleep(time.Millisecond * 250)
continue
@@ -597,7 +606,7 @@ array ::=
string ::=
"\"" (
[^"\\] |
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
@@ -756,7 +765,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
var c completion
if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
switch {

View File

@@ -245,7 +245,6 @@ func (w *writer) writeResponse(data []byte) (int, error) {
d, err := json.Marshal(toChunk(w.id, chatResponse))
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")

View File

@@ -3,12 +3,15 @@ package parser
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"log/slog"
"strconv"
"strings"
"unicode"
"unicode/utf16"
"unicode/utf8"
)
type File struct {
@@ -69,33 +72,31 @@ func ParseFile(r io.Reader) (*File, error) {
var b bytes.Buffer
var role string
var lineCount int
var linePos int
var utf16 bool
var f File
br := bufio.NewReader(r)
for {
r, _, err := br.ReadRune()
if errors.Is(err, io.EOF) {
break
} else if err != nil {
var sc scannerDecoder = utf8ScannerDecoder{}
if bom, err := br.Peek(2); err != nil {
slog.Warn("error reading byte-order mark", "error", err)
} else if bytes.Equal(bom, []byte{0xFE, 0xFF}) {
sc = utf16ScannerDecoder{binary.LittleEndian}
//nolint:errcheck
br.Discard(2)
} else if bytes.Equal(bom, []byte{0xFF, 0xFE}) {
sc = utf16ScannerDecoder{binary.BigEndian}
//nolint:errcheck
br.Discard(2)
}
scanner := bufio.NewScanner(br)
scanner.Split(sc.ScanBytes)
for scanner.Scan() {
r, err := sc.DecodeRune(scanner.Bytes())
if err != nil {
return nil, err
}
// the utf16 byte order mark will be read as "unreadable" by ReadRune()
if isUnreadable(r) && lineCount == 0 && linePos == 0 {
utf16 = true
continue
}
// skip the second byte if we're reading utf16
if utf16 && r == 0 {
continue
}
next, r, err := parseRuneForState(r, curr)
if errors.Is(err, io.ErrUnexpectedEOF) {
return nil, fmt.Errorf("%w: %s", err, b.String())
@@ -103,13 +104,6 @@ func ParseFile(r io.Reader) (*File, error) {
return nil, err
}
if isNewline(r) {
lineCount++
linePos = 0
} else {
linePos++
}
// process the state transition, some transitions need to be intercepted and redirected
if next != curr {
switch curr {
@@ -309,10 +303,6 @@ func isNewline(r rune) bool {
return r == '\r' || r == '\n'
}
func isUnreadable(r rune) bool {
return r == unicode.ReplacementChar
}
func isValidMessageRole(role string) bool {
return role == "system" || role == "user" || role == "assistant"
}
@@ -325,3 +315,39 @@ func isValidCommand(cmd string) bool {
return false
}
}
type scannerDecoder interface {
ScanBytes(data []byte, atEOF bool) (advance int, token []byte, err error)
DecodeRune([]byte) (rune, error)
}
type utf8ScannerDecoder struct{}
func (utf8ScannerDecoder) ScanBytes(data []byte, atEOF bool) (advance int, token []byte, err error) {
return scanBytesN(data, 1, atEOF)
}
func (utf8ScannerDecoder) DecodeRune(data []byte) (rune, error) {
r, _ := utf8.DecodeRune(data)
return r, nil
}
type utf16ScannerDecoder struct {
binary.ByteOrder
}
func (utf16ScannerDecoder) ScanBytes(data []byte, atEOF bool) (advance int, token []byte, err error) {
return scanBytesN(data, 2, atEOF)
}
func (e utf16ScannerDecoder) DecodeRune(data []byte) (rune, error) {
return utf16.Decode([]uint16{e.ByteOrder.Uint16(data)})[0], nil
}
func scanBytesN(data []byte, n int, atEOF bool) (int, []byte, error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
return n, data[:n], nil
}

View File

@@ -10,6 +10,7 @@ import (
"unicode/utf16"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseFileFile(t *testing.T) {
@@ -25,7 +26,7 @@ TEMPLATE template1
reader := strings.NewReader(input)
modelfile, err := ParseFile(reader)
assert.NoError(t, err)
require.NoError(t, err)
expectedCommands := []Command{
{Name: "model", Args: "model1"},
@@ -88,7 +89,7 @@ func TestParseFileFrom(t *testing.T) {
for _, c := range cases {
t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.input))
assert.ErrorIs(t, err, c.err)
require.ErrorIs(t, err, c.err)
if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands)
}
@@ -105,7 +106,7 @@ PARAMETER param1
reader := strings.NewReader(input)
_, err := ParseFile(reader)
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
}
func TestParseFileBadCommand(t *testing.T) {
@@ -114,8 +115,7 @@ FROM foo
BADCOMMAND param1 value1
`
_, err := ParseFile(strings.NewReader(input))
assert.ErrorIs(t, err, errInvalidCommand)
require.ErrorIs(t, err, errInvalidCommand)
}
func TestParseFileMessages(t *testing.T) {
@@ -201,7 +201,7 @@ MESSAGE system`,
for _, c := range cases {
t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.input))
assert.ErrorIs(t, err, c.err)
require.ErrorIs(t, err, c.err)
if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands)
}
@@ -355,7 +355,7 @@ TEMPLATE """
for _, c := range cases {
t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.multiline))
assert.ErrorIs(t, err, c.err)
require.ErrorIs(t, err, c.err)
if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands)
}
@@ -413,7 +413,7 @@ func TestParseFileParameters(t *testing.T) {
fmt.Fprintln(&b, "FROM foo")
fmt.Fprintln(&b, "PARAMETER", k)
modelfile, err := ParseFile(&b)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, []Command{
{Name: "model", Args: "foo"},
@@ -442,7 +442,7 @@ FROM foo
for _, c := range cases {
t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.input))
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, c.expected, modelfile.Commands)
})
}
@@ -501,15 +501,14 @@ SYSTEM ""
for _, c := range cases {
t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c))
assert.NoError(t, err)
require.NoError(t, err)
modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, modelfile, modelfile2)
})
}
}
func TestParseFileUTF16ParseFile(t *testing.T) {
@@ -522,10 +521,10 @@ SYSTEM You are a utf16 file.
utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
buf := new(bytes.Buffer)
err := binary.Write(buf, binary.LittleEndian, utf16File)
assert.NoError(t, err)
require.NoError(t, err)
actual, err := ParseFile(buf)
assert.NoError(t, err)
require.NoError(t, err)
expected := []Command{
{Name: "model", Args: "bob"},
@@ -539,9 +538,9 @@ SYSTEM You are a utf16 file.
// simulate a utf16 be file
buf = new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, utf16File)
assert.NoError(t, err)
require.NoError(t, err)
actual, err = ParseFile(buf)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, expected, actual.Commands)
}

View File

@@ -59,7 +59,7 @@ func (p *Progress) StopAndClear() bool {
stopped := p.stop()
if stopped {
// clear all progress lines
for i := 0; i < p.pos; i++ {
for i := range p.pos {
if i > 0 {
fmt.Fprint(p.w, "\033[A")
}
@@ -85,7 +85,7 @@ func (p *Progress) render() {
defer fmt.Fprint(p.w, "\033[?25h")
// clear already rendered progress lines
for i := 0; i < p.pos; i++ {
for i := range p.pos {
if i > 0 {
fmt.Fprint(p.w, "\033[A")
}

View File

@@ -5,12 +5,16 @@ import (
"os"
"github.com/emirpasic/gods/lists/arraylist"
"github.com/mattn/go-runewidth"
"golang.org/x/term"
)
type Buffer struct {
DisplayPos int
Pos int
Buf *arraylist.List
//LineHasSpace is an arraylist of bools to keep track of whether a line has a space at the end
LineHasSpace *arraylist.List
Prompt *Prompt
LineWidth int
Width int
@@ -27,8 +31,10 @@ func NewBuffer(prompt *Prompt) (*Buffer, error) {
lwidth := width - len(prompt.prompt())
b := &Buffer{
DisplayPos: 0,
Pos: 0,
Buf: arraylist.New(),
LineHasSpace: arraylist.New(),
Prompt: prompt,
Width: width,
Height: height,
@@ -38,14 +44,43 @@ func NewBuffer(prompt *Prompt) (*Buffer, error) {
return b, nil
}
func (b *Buffer) GetLineSpacing(line int) bool {
hasSpace, _ := b.LineHasSpace.Get(line)
if hasSpace == nil {
return false
}
return hasSpace.(bool)
}
func (b *Buffer) MoveLeft() {
if b.Pos > 0 {
if b.Pos%b.LineWidth == 0 {
//asserts that we retrieve a rune
if e, ok := b.Buf.Get(b.Pos - 1); ok {
if r, ok := e.(rune); ok {
rLength := runewidth.RuneWidth(r)
if b.DisplayPos%b.LineWidth == 0 {
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
} else {
if rLength == 2 {
fmt.Print(CursorLeft)
}
line := b.DisplayPos/b.LineWidth - 1
hasSpace := b.GetLineSpacing(line)
if hasSpace {
b.DisplayPos -= 1
fmt.Print(CursorLeft)
}
} else {
fmt.Print(cursorLeftN(rLength))
}
b.Pos -= 1
b.DisplayPos -= rLength
}
}
}
}
@@ -71,18 +106,32 @@ func (b *Buffer) MoveLeftWord() {
}
func (b *Buffer) MoveRight() {
if b.Pos < b.Size() {
if b.Pos < b.Buf.Size() {
if e, ok := b.Buf.Get(b.Pos); ok {
if r, ok := e.(rune); ok {
rLength := runewidth.RuneWidth(r)
b.Pos += 1
if b.Pos%b.LineWidth == 0 {
hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)
b.DisplayPos += rLength
if b.DisplayPos%b.LineWidth == 0 {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength))
b.DisplayPos += 1
} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
b.DisplayPos += 1
} else {
fmt.Print(CursorRight)
fmt.Print(cursorRightN(rLength))
}
}
}
}
}
func (b *Buffer) MoveRightWord() {
if b.Pos < b.Size() {
if b.Pos < b.Buf.Size() {
for {
b.MoveRight()
v, _ := b.Buf.Get(b.Pos)
@@ -90,7 +139,7 @@ func (b *Buffer) MoveRightWord() {
break
}
if b.Pos == b.Size() {
if b.Pos == b.Buf.Size() {
break
}
}
@@ -99,89 +148,200 @@ func (b *Buffer) MoveRightWord() {
func (b *Buffer) MoveToStart() {
if b.Pos > 0 {
currLine := b.Pos / b.LineWidth
currLine := b.DisplayPos / b.LineWidth
if currLine > 0 {
for cnt := 0; cnt < currLine; cnt++ {
for range currLine {
fmt.Print(CursorUp)
}
}
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())))
b.Pos = 0
b.DisplayPos = 0
}
}
func (b *Buffer) MoveToEnd() {
if b.Pos < b.Size() {
currLine := b.Pos / b.LineWidth
totalLines := b.Size() / b.LineWidth
if b.Pos < b.Buf.Size() {
currLine := b.DisplayPos / b.LineWidth
totalLines := b.DisplaySize() / b.LineWidth
if currLine < totalLines {
for cnt := 0; cnt < totalLines-currLine; cnt++ {
for range totalLines - currLine {
fmt.Print(CursorDown)
}
remainder := b.Size() % b.LineWidth
remainder := b.DisplaySize() % b.LineWidth
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())+remainder))
} else {
fmt.Print(cursorRightN(b.Size() - b.Pos))
fmt.Print(cursorRightN(b.DisplaySize() - b.DisplayPos))
}
b.Pos = b.Size()
b.Pos = b.Buf.Size()
b.DisplayPos = b.DisplaySize()
}
}
func (b *Buffer) Size() int {
return b.Buf.Size()
func (b *Buffer) DisplaySize() int {
sum := 0
for i := range b.Buf.Size() {
if e, ok := b.Buf.Get(i); ok {
if r, ok := e.(rune); ok {
sum += runewidth.RuneWidth(r)
}
}
}
return sum
}
func (b *Buffer) Add(r rune) {
if b.Pos == b.Buf.Size() {
b.AddChar(r, false)
} else {
b.AddChar(r, true)
}
}
func (b *Buffer) AddChar(r rune, insert bool) {
rLength := runewidth.RuneWidth(r)
b.DisplayPos += rLength
if b.Pos > 0 {
if b.DisplayPos%b.LineWidth == 0 {
fmt.Printf("%c", r)
b.Buf.Add(r)
b.Pos += 1
if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
fmt.Printf("\n%s", b.Prompt.AltPrompt)
if insert {
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth-1, false)
} else {
b.LineHasSpace.Add(false)
}
// this case occurs when a double-width rune crosses the line boundary
} else if b.DisplayPos%b.LineWidth < (b.DisplayPos-rLength)%b.LineWidth {
if insert {
fmt.Print(ClearToEOL)
}
fmt.Printf("\n%s", b.Prompt.AltPrompt)
b.DisplayPos += 1
fmt.Printf("%c", r)
if insert {
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth-1, true)
} else {
b.LineHasSpace.Add(true)
}
} else {
fmt.Printf("%c", r)
b.Buf.Insert(b.Pos, r)
b.Pos += 1
if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
fmt.Printf("\n%s", b.Prompt.AltPrompt)
}
} else {
fmt.Printf("%c", r)
}
if insert {
b.Buf.Insert(b.Pos, r)
} else {
b.Buf.Add(r)
}
b.Pos += 1
if insert {
b.drawRemaining()
}
}
func (b *Buffer) countRemainingLineWidth(place int) int {
var sum int
counter := -1
var prevLen int
for place <= b.LineWidth {
counter += 1
sum += prevLen
if e, ok := b.Buf.Get(b.Pos + counter); ok {
if r, ok := e.(rune); ok {
place += runewidth.RuneWidth(r)
prevLen = len(string(r))
}
} else {
break
}
}
return sum
}
func (b *Buffer) drawRemaining() {
var place int
remainingText := b.StringN(b.Pos)
if b.Pos > 0 {
place = b.Pos % b.LineWidth
place = b.DisplayPos % b.LineWidth
}
fmt.Print(CursorHide)
// render the rest of the current line
currLine := remainingText[:min(b.LineWidth-place, len(remainingText))]
currLineLength := b.countRemainingLineWidth(place)
currLine := remainingText[:min(currLineLength, len(remainingText))]
currLineSpace := runewidth.StringWidth(currLine)
remLength := runewidth.StringWidth(remainingText)
if len(currLine) > 0 {
fmt.Printf(ClearToEOL + currLine)
fmt.Print(cursorLeftN(len(currLine)))
fmt.Print(cursorLeftN(currLineSpace))
} else {
fmt.Print(ClearToEOL)
}
if currLineSpace != b.LineWidth-place && currLineSpace != remLength {
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth, true)
} else if currLineSpace != b.LineWidth-place {
b.LineHasSpace.Remove(b.DisplayPos / b.LineWidth)
} else {
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth, false)
}
if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText {
fmt.Print(cursorRightN(currLineSpace))
fmt.Printf("\n%s", b.Prompt.AltPrompt)
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width-currLineSpace))
}
// render the other lines
if len(remainingText) > len(currLine) {
remaining := []rune(remainingText[len(currLine):])
if remLength > currLineSpace {
remaining := (remainingText[len(currLine):])
var totalLines int
for i, c := range remaining {
if i%b.LineWidth == 0 {
var displayLength int
var lineLength int = currLineSpace
for _, c := range remaining {
if displayLength == 0 || (displayLength+runewidth.RuneWidth(c))%b.LineWidth < displayLength%b.LineWidth {
fmt.Printf("\n%s", b.Prompt.AltPrompt)
totalLines += 1
if displayLength != 0 {
if lineLength == b.LineWidth {
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth+totalLines-1, false)
} else {
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth+totalLines-1, true)
}
}
lineLength = 0
}
displayLength += runewidth.RuneWidth(c)
lineLength += runewidth.RuneWidth(c)
fmt.Printf("%c", c)
}
fmt.Print(ClearToEOL)
fmt.Print(cursorUpN(totalLines))
fmt.Printf(CursorBOL + cursorRightN(b.Width-len(currLine)))
fmt.Printf(CursorBOL + cursorRightN(b.Width-currLineSpace))
hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)
if hasSpace && b.DisplayPos%b.LineWidth != b.LineWidth-1 {
fmt.Print(CursorLeft)
}
}
fmt.Print(CursorShow)
@@ -189,46 +349,81 @@ func (b *Buffer) drawRemaining() {
func (b *Buffer) Remove() {
if b.Buf.Size() > 0 && b.Pos > 0 {
if b.Pos%b.LineWidth == 0 {
if e, ok := b.Buf.Get(b.Pos - 1); ok {
if r, ok := e.(rune); ok {
rLength := runewidth.RuneWidth(r)
hasSpace := b.GetLineSpacing(b.DisplayPos/b.LineWidth - 1)
if b.DisplayPos%b.LineWidth == 0 {
// if the user backspaces over the word boundary, do this magic to clear the line
// and move to the end of the previous line
fmt.Printf(CursorBOL + ClearToEOL)
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width) + " " + CursorLeft)
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth {
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
}
if hasSpace {
b.DisplayPos -= 1
fmt.Print(CursorLeft)
}
if rLength == 2 {
fmt.Print(CursorLeft + " " + cursorLeftN(2))
} else {
fmt.Printf(CursorLeft + " " + CursorLeft)
fmt.Print(" " + CursorLeft)
}
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
fmt.Printf(CursorBOL + ClearToEOL)
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
if b.Pos == b.Buf.Size() {
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
}
b.DisplayPos -= 1
} else {
fmt.Print(cursorLeftN(rLength))
for range rLength {
fmt.Print(" ")
}
fmt.Print(cursorLeftN(rLength))
}
var eraseExtraLine bool
if (b.Size()-1)%b.LineWidth == 0 {
if (b.DisplaySize()-1)%b.LineWidth == 0 || (rLength == 2 && ((b.DisplaySize()-2)%b.LineWidth == 0)) || b.DisplaySize()%b.LineWidth == 0 {
eraseExtraLine = true
}
b.Pos -= 1
b.DisplayPos -= rLength
b.Buf.Remove(b.Pos)
if b.Pos < b.Size() {
if b.Pos < b.Buf.Size() {
b.drawRemaining()
// this erases a line which is left over when backspacing in the middle of a line and there
// are trailing characters which go over the line width boundary
if eraseExtraLine {
remainingLines := (b.Size() - b.Pos) / b.LineWidth
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
place := b.Pos % b.LineWidth
place := b.DisplayPos % b.LineWidth
fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.prompt())))
}
}
}
}
}
}
func (b *Buffer) Delete() {
if b.Size() > 0 && b.Pos < b.Size() {
if b.Buf.Size() > 0 && b.Pos < b.Buf.Size() {
b.Buf.Remove(b.Pos)
b.drawRemaining()
if b.Size()%b.LineWidth == 0 {
if b.Pos != b.Size() {
remainingLines := (b.Size() - b.Pos) / b.LineWidth
if b.DisplaySize()%b.LineWidth == 0 {
if b.DisplayPos != b.DisplaySize() {
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
place := b.Pos % b.LineWidth
place := b.DisplayPos % b.LineWidth
fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.prompt())))
}
}
@@ -244,9 +439,9 @@ func (b *Buffer) DeleteBefore() {
}
func (b *Buffer) DeleteRemaining() {
if b.Size() > 0 && b.Pos < b.Size() {
charsToDel := b.Size() - b.Pos
for cnt := 0; cnt < charsToDel; cnt++ {
if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() {
charsToDel := b.Buf.Size() - b.Pos
for range charsToDel {
b.Delete()
}
}
@@ -281,14 +476,16 @@ func (b *Buffer) ClearScreen() {
ph := b.Prompt.placeholder()
fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
} else {
currPos := b.Pos
currPos := b.DisplayPos
currIndex := b.Pos
b.Pos = 0
b.DisplayPos = 0
b.drawRemaining()
fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.prompt())))
if currPos > 0 {
targetLine := currPos / b.LineWidth
if targetLine > 0 {
for cnt := 0; cnt < targetLine; cnt++ {
for range targetLine {
fmt.Print(CursorDown)
}
}
@@ -300,7 +497,8 @@ func (b *Buffer) ClearScreen() {
fmt.Printf(CursorBOL + b.Prompt.AltPrompt)
}
}
b.Pos = currPos
b.Pos = currIndex
b.DisplayPos = currPos
}
}
@@ -309,9 +507,20 @@ func (b *Buffer) IsEmpty() bool {
}
func (b *Buffer) Replace(r []rune) {
b.DisplayPos = 0
b.Pos = 0
lineNums := b.DisplaySize() / b.LineWidth
b.Buf.Clear()
fmt.Printf(ClearLine + CursorBOL + b.Prompt.prompt())
fmt.Printf(CursorBOL + ClearToEOL)
for range lineNums {
fmt.Print(CursorUp + CursorBOL + ClearToEOL)
}
fmt.Printf(CursorBOL + b.Prompt.prompt())
for _, c := range r {
b.Add(c)
}
@@ -328,7 +537,7 @@ func (b *Buffer) StringN(n int) string {
func (b *Buffer) StringNM(n, m int) string {
var s string
if m == 0 {
m = b.Size()
m = b.Buf.Size()
}
for cnt := n; cnt < m; cnt++ {
c, _ := b.Buf.Get(cnt)

View File

@@ -91,7 +91,7 @@ func (h *History) Add(l []rune) {
func (h *History) Compact() {
s := h.Buf.Size()
if s > h.Limit {
for cnt := 0; cnt < s-h.Limit; cnt++ {
for range s - h.Limit {
h.Buf.Remove(0)
}
}
@@ -139,7 +139,7 @@ func (h *History) Save() error {
defer f.Close()
buf := bufio.NewWriter(f)
for cnt := 0; cnt < h.Size(); cnt++ {
for cnt := range h.Size() {
v, _ := h.Buf.Get(cnt)
line, _ := v.([]rune)
if _, err := buf.WriteString(string(line) + "\n"); err != nil {

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"io"
"os"
"syscall"
)
type Prompt struct {
@@ -63,7 +62,7 @@ func New(prompt Prompt) (*Instance, error) {
func (i *Instance) Readline() (string, error) {
if !i.Terminal.rawmode {
fd := int(syscall.Stdin)
fd := os.Stdin.Fd()
termios, err := SetRawMode(fd)
if err != nil {
return "", err
@@ -80,8 +79,8 @@ func (i *Instance) Readline() (string, error) {
fmt.Print(prompt)
defer func() {
fd := int(syscall.Stdin)
// nolint: errcheck
fd := os.Stdin.Fd()
//nolint:errcheck
UnsetRawMode(fd, i.Terminal.termios)
i.Terminal.rawmode = false
}()
@@ -136,7 +135,7 @@ func (i *Instance) Readline() (string, error) {
buf.MoveRight()
case CharBracketedPaste:
var code string
for cnt := 0; cnt < 3; cnt++ {
for range 3 {
r, err = i.Terminal.Read()
if err != nil {
return "", io.EOF
@@ -150,7 +149,7 @@ func (i *Instance) Readline() (string, error) {
i.Pasting = false
}
case KeyDel:
if buf.Size() > 0 {
if buf.DisplaySize() > 0 {
buf.Delete()
}
metaDel = true
@@ -198,11 +197,11 @@ func (i *Instance) Readline() (string, error) {
buf.Remove()
case CharTab:
// todo: convert back to real tabs
for cnt := 0; cnt < 8; cnt++ {
for range 8 {
buf.Add(' ')
}
case CharDelete:
if buf.Size() > 0 {
if buf.DisplaySize() > 0 {
buf.Delete()
} else {
return "", io.EOF
@@ -216,7 +215,7 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlW:
buf.DeleteWord()
case CharCtrlZ:
fd := int(syscall.Stdin)
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter, CharCtrlJ:
output := buf.String()
@@ -248,7 +247,7 @@ func (i *Instance) HistoryDisable() {
}
func NewTerminal() (*Terminal, error) {
fd := int(syscall.Stdin)
fd := os.Stdin.Fd()
termios, err := SetRawMode(fd)
if err != nil {
return nil, err

View File

@@ -6,7 +6,7 @@ import (
"syscall"
)
func handleCharCtrlZ(fd int, termios any) (string, error) {
func handleCharCtrlZ(fd uintptr, termios any) (string, error) {
t := termios.(*Termios)
if err := UnsetRawMode(fd, t); err != nil {
return "", err

View File

@@ -1,6 +1,6 @@
package readline
func handleCharCtrlZ(fd int, state any) (string, error) {
func handleCharCtrlZ(fd uintptr, state any) (string, error) {
// not supported
return "", nil
}

View File

@@ -8,7 +8,7 @@ import (
type Termios syscall.Termios
func SetRawMode(fd int) (*Termios, error) {
func SetRawMode(fd uintptr) (*Termios, error) {
termios, err := getTermios(fd)
if err != nil {
return nil, err
@@ -25,13 +25,13 @@ func SetRawMode(fd int) (*Termios, error) {
return termios, setTermios(fd, &newTermios)
}
func UnsetRawMode(fd int, termios any) error {
func UnsetRawMode(fd uintptr, termios any) error {
t := termios.(*Termios)
return setTermios(fd, t)
}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
func IsTerminal(fd uintptr) bool {
_, err := getTermios(fd)
return err == nil
}

View File

@@ -7,17 +7,17 @@ import (
"unsafe"
)
func getTermios(fd int) (*Termios, error) {
func getTermios(fd uintptr) (*Termios, error) {
termios := new(Termios)
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
if err != 0 {
return nil, err
}
return termios, nil
}
func setTermios(fd int, termios *Termios) error {
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
func setTermios(fd uintptr, termios *Termios) error {
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
if err != 0 {
return err
}

View File

@@ -10,17 +10,17 @@ import (
const tcgets = 0x5401
const tcsets = 0x5402
func getTermios(fd int) (*Termios, error) {
func getTermios(fd uintptr) (*Termios, error) {
termios := new(Termios)
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
if err != 0 {
return nil, err
}
return termios, nil
}
func setTermios(fd int, termios *Termios) error {
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
func setTermios(fd uintptr, termios *Termios) error {
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
if err != 0 {
return err
}

View File

@@ -9,13 +9,13 @@ type State struct {
}
// IsTerminal checks if the given file descriptor is associated with a terminal
func IsTerminal(fd int) bool {
func IsTerminal(fd uintptr) bool {
var st uint32
err := windows.GetConsoleMode(windows.Handle(fd), &st)
return err == nil
}
func SetRawMode(fd int) (*State, error) {
func SetRawMode(fd uintptr) (*State, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
@@ -32,7 +32,7 @@ func SetRawMode(fd int) (*State, error) {
return &State{st}, nil
}
func UnsetRawMode(fd int, state any) error {
func UnsetRawMode(fd uintptr, state any) error {
s := state.(*State)
return windows.SetConsoleMode(windows.Handle(fd), s.mode)
}

View File

@@ -33,9 +33,11 @@ case "$ARCH" in
*) error "Unsupported architecture: $ARCH" ;;
esac
IS_WSL2=false
KERN=$(uname -r)
case "$KERN" in
*icrosoft*WSL2 | *icrosoft*wsl2) ;;
*icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
*icrosoft) error "Microsoft WSL1 is not currently supported. Please upgrade to WSL2 with 'wsl --set-version <distro> 2'" ;;
*) ;;
esac
@@ -131,6 +133,17 @@ if available systemctl; then
configure_systemd
fi
# WSL2 only supports GPUs via nvidia passthrough
# so check for nvidia-smi to determine if GPU is available
if [ "$IS_WSL2" = true ]; then
if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
status "Nvidia GPU detected."
fi
install_success
exit 0
fi
# Install GPU dependencies on Linux
if ! available lspci && ! available lshw; then
warning "Unable to detect NVIDIA/AMD GPU. Install lspci or lshw to automatically detect and install GPU dependencies."
exit 0
@@ -181,7 +194,7 @@ if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
curl --fail --show-error --location --progress-bar "https://ollama.com/download/ollama-linux-amd64-rocm.tgz${VER_PARAM}" \
| $SUDO tar zx --owner ollama --group ollama -C /usr/share/ollama/lib/rocm .
install_success
status "AMD GPU dependencies installed."
status "AMD GPU ready."
exit 0
fi
@@ -274,7 +287,7 @@ if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\
esac
fi
if ! lsmod | grep -q nvidia; then
if ! lsmod | grep -q nvidia || ! lsmod | grep -q nvidia_uvm; then
KERNEL_RELEASE="$(uname -r)"
case $OS_NAME in
rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;;
@@ -295,7 +308,19 @@ if ! lsmod | grep -q nvidia; then
fi
$SUDO modprobe nvidia
$SUDO modprobe nvidia_uvm
fi
# make sure the NVIDIA modules are loaded on boot with nvidia-persistenced
if command -v nvidia-persistenced > /dev/null 2>&1; then
$SUDO touch /etc/modules-load.d/nvidia.conf
MODULES="nvidia nvidia-uvm"
for MODULE in $MODULES; do
if ! grep -qxF "$MODULE" /etc/modules-load.d/nvidia.conf; then
echo "$MODULE" | sudo tee -a /etc/modules-load.d/nvidia.conf > /dev/null
fi
done
fi
status "NVIDIA CUDA drivers installed."
status "NVIDIA GPU ready."
install_success

View File

@@ -221,7 +221,7 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
}
defer resp.Body.Close()
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size)
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed)
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress
b.Completed.Add(-n)
@@ -340,17 +340,17 @@ type downloadOpts struct {
}
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error {
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
fp, err := GetBlobsPath(opts.digest)
if err != nil {
return err
return false, err
}
fi, err := os.Stat(fp)
switch {
case errors.Is(err, os.ErrNotExist):
case err != nil:
return err
return false, err
default:
opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", opts.digest[7:19]),
@@ -359,7 +359,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
Completed: fi.Size(),
})
return nil
return true, nil
}
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
@@ -369,12 +369,12 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return err
return false, err
}
// nolint: contextcheck
//nolint:contextcheck
go download.Run(context.Background(), requestURL, opts.regOpts)
}
return download.Wait(ctx, opts.fn)
return false, download.Wait(ctx, opts.fn)
}

View File

@@ -1,23 +0,0 @@
package envconfig
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestConfig(t *testing.T) {
Debug = false // Reset whatever was loaded in init()
t.Setenv("OLLAMA_DEBUG", "")
LoadConfig()
require.False(t, Debug)
t.Setenv("OLLAMA_DEBUG", "false")
LoadConfig()
require.False(t, Debug)
t.Setenv("OLLAMA_DEBUG", "1")
LoadConfig()
require.True(t, Debug)
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig()
require.True(t, FlashAttention)
}

View File

@@ -18,17 +18,16 @@ import (
"os"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -315,7 +314,7 @@ func realpath(rel, from string) string {
return abspath
}
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) {
func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
@@ -333,7 +332,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
switch c.Name {
case "model", "adapter":
var baseLayers []*layerWithGGML
var baseLayers []*layerGGML
if name := model.ParseName(c.Args); name.IsValid() {
baseLayers, err = parseFromModel(ctx, name, fn)
if err != nil {
@@ -440,19 +439,27 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
layers = append(layers, baseLayer.Layer)
}
case "license", "template", "system":
if c.Name != "license" {
// replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
if layer.MediaType != mediatype {
return false
}
if err := layer.Remove(); err != nil {
return false
}
return true
})
}
blob := strings.NewReader(c.Args)
layer, err := NewLayer(blob, mediatype)
if err != nil {
return err
}
if c.Name != "license" {
// replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
return layer.MediaType == mediatype
})
}
layers = append(layers, layer)
case "message":
role, content, ok := strings.Cut(c.Args, ": ")
@@ -571,26 +578,15 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
}
}
unref := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range manifest.Layers {
if !slices.Contains(digests, layer.Digest) {
unref[layer.Digest] = struct{}{}
}
}
if manifest.Config.Digest != layer.Digest {
unref[manifest.Config.Digest] = struct{}{}
}
}
old, _ := ParseNamedManifest(name)
fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, layer, layers); err != nil {
return err
}
if !envconfig.NoPrune {
if err := deleteUnusedLayers(nil, unref); err != nil {
if !envconfig.NoPrune && old != nil {
if err := old.RemoveLayers(); err != nil {
return err
}
}
@@ -662,7 +658,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{})
// save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp)
if err != nil {
// nolint: nilerr
//nolint:nilerr
return nil
}
@@ -771,37 +767,6 @@ func PruneDirectory(path string) error {
return nil
}
func DeleteModel(name string) error {
mp := ParseModelPath(name)
manifest, _, err := GetManifest(mp)
if err != nil {
return err
}
deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
}
deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
err = os.Remove(fp)
if err != nil {
slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
return err
}
return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
@@ -888,23 +853,27 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)
skipVerify := make(map[string]bool)
for _, layer := range layers {
if err := downloadBlob(
ctx,
downloadOpts{
cacheHit, err := downloadBlob(ctx, downloadOpts{
mp: mp,
digest: layer.Digest,
regOpts: regOpts,
fn: fn,
}); err != nil {
})
if err != nil {
return err
}
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if skipVerify[layer.Digest] {
continue
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
@@ -1019,7 +988,7 @@ func getTokenSubject(token string) string {
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
anonymous := true // access will default to anonymous if no user is found associated with the public key
for i := 0; i < 2; i++ {
for range 2 {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
if !errors.Is(err, context.Canceled) {

View File

@@ -88,3 +88,26 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
return os.Open(blob)
}
func (l *Layer) Remove() error {
ms, err := Manifests()
if err != nil {
return err
}
for _, m := range ms {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == l.Digest {
// something is using this layer
return nil
}
}
}
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return err
}
return os.Remove(blob)
}

View File

@@ -1,11 +1,12 @@
package server
import (
"bytes"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
@@ -14,7 +15,10 @@ import (
type Manifest struct {
ManifestV2
Digest string `json:"-"`
filepath string
fi os.FileInfo
digest string
}
func (m *Manifest) Size() (size int64) {
@@ -25,9 +29,34 @@ func (m *Manifest) Size() (size int64) {
return
}
func ParseNamedManifest(name model.Name) (*Manifest, error) {
if !name.IsFullyQualified() {
return nil, model.Unqualified(name)
func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
}
manifests, err := GetManifestPath()
if err != nil {
return err
}
return PruneDirectory(manifests)
}
func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
}
}
return nil
}
func ParseNamedManifest(n model.Name) (*Manifest, error) {
if !n.IsFullyQualified() {
return nil, model.Unqualified(n)
}
manifests, err := GetManifestPath()
@@ -35,45 +64,101 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) {
return nil, err
}
var manifest ManifestV2
manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
p := filepath.Join(manifests, n.Filepath())
var m ManifestV2
f, err := os.Open(p)
if err != nil {
return nil, err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return nil, err
}
sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
return nil, err
}
return &Manifest{
ManifestV2: manifest,
Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
ManifestV2: m,
filepath: p,
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
}, nil
}
func WriteManifest(name string, config *Layer, layers []*Layer) error {
manifest := ManifestV2{
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
manifests, err := GetManifestPath()
if err != nil {
return err
}
p := filepath.Join(manifests, name.Filepath())
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
return err
}
f, err := os.Create(p)
if err != nil {
return err
}
defer f.Close()
m := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
return err
}
modelpath := ParseModelPath(name)
manifestPath, err := modelpath.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
return err
}
return os.WriteFile(manifestPath, b.Bytes(), 0o644)
return json.NewEncoder(f).Encode(m)
}
func Manifests() (map[model.Name]*Manifest, error) {
manifests, err := GetManifestPath()
if err != nil {
return nil, err
}
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
if err != nil {
return nil, err
}
ms := make(map[model.Name]*Manifest)
for _, match := range matches {
fi, err := os.Stat(match)
if err != nil {
return nil, err
}
if !fi.IsDir() {
rel, err := filepath.Rel(manifests, match)
if err != nil {
slog.Warn("bad filepath", "path", match, "error", err)
continue
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest name", "path", rel, "error", err)
continue
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
continue
}
ms[n] = m
}
}
return ms, nil
}

150
server/manifest_test.go Normal file
View File

@@ -0,0 +1,150 @@
package server
import (
"encoding/json"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/types/model"
)
func createManifest(t *testing.T, path, name string) {
t.Helper()
p := filepath.Join(path, "manifests", name)
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
t.Fatal(err)
}
f, err := os.Create(p)
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
t.Fatal(err)
}
}
func TestManifests(t *testing.T) {
cases := map[string]struct {
ps []string
wantValidCount int
wantInvalidCount int
}{
"empty": {},
"single": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag"),
},
wantValidCount: 1,
},
"multiple": {
ps: []string{
filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
},
wantValidCount: 15,
},
"hidden": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag"),
filepath.Join("host", "namespace", "model", ".hidden"),
},
wantValidCount: 1,
wantInvalidCount: 1,
},
"subdir": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag", "one"),
filepath.Join("host", "namespace", "model", "tag", "another", "one"),
},
wantInvalidCount: 2,
},
"upper tag": {
ps: []string{
filepath.Join("host", "namespace", "model", "TAG"),
},
wantValidCount: 1,
},
"upper model": {
ps: []string{
filepath.Join("host", "namespace", "MODEL", "tag"),
},
wantValidCount: 1,
},
"upper namespace": {
ps: []string{
filepath.Join("host", "NAMESPACE", "model", "tag"),
},
wantValidCount: 1,
},
"upper host": {
ps: []string{
filepath.Join("HOST", "namespace", "model", "tag"),
},
wantValidCount: 1,
},
}
for n, wants := range cases {
t.Run(n, func(t *testing.T) {
d := t.TempDir()
t.Setenv("OLLAMA_MODELS", d)
for _, p := range wants.ps {
createManifest(t, d, p)
}
ms, err := Manifests()
if err != nil {
t.Fatal(err)
}
var ns []model.Name
for k := range ms {
ns = append(ns, k)
}
var gotValidCount, gotInvalidCount int
for _, p := range wants.ps {
n := model.ParseNameFromFilepath(p)
if n.IsValid() {
gotValidCount++
} else {
gotInvalidCount++
}
if !n.IsValid() && slices.Contains(ns, n) {
t.Errorf("unexpected invalid name: %s", p)
} else if n.IsValid() && !slices.Contains(ns, n) {
t.Errorf("missing valid name: %s", p)
}
}
if gotValidCount != wants.wantValidCount {
t.Errorf("got valid count %d, want %d", gotValidCount, wants.wantValidCount)
}
if gotInvalidCount != wants.wantInvalidCount {
t.Errorf("got invalid count %d, want %d", gotInvalidCount, wants.wantInvalidCount)
}
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
@@ -14,27 +15,26 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/templates"
"github.com/ollama/ollama/types/model"
)
var intermediateBlobs map[string]string = make(map[string]string)
type layerWithGGML struct {
type layerGGML struct {
*Layer
*llm.GGML
}
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
modelpath := ParseModelPath(name.String())
manifest, _, err := GetManifest(modelpath)
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
m, err := ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}
modelpath = ParseModelPath(name.String())
manifest, _, err = GetManifest(modelpath)
m, err = ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -42,8 +42,8 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
return nil, err
}
for _, layer := range manifest.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
for _, layer := range m.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil {
return nil, err
}
@@ -68,17 +68,16 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
return nil, err
}
layers = append(layers, &layerWithGGML{layer, ggml})
layers = append(layers, &layerGGML{layer, ggml})
default:
layers = append(layers, &layerWithGGML{layer, nil})
layers = append(layers, &layerGGML{layer, nil})
}
}
return layers, nil
}
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
stat, err := file.Stat()
if err != nil {
return nil, err
@@ -182,13 +181,13 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
return nil, err
}
layers = append(layers, &layerWithGGML{layer, ggml})
layers = append(layers, &layerGGML{layer, ggml})
intermediateBlobs[digest] = layer.Digest
return layers, nil
return detectChatTemplate(layers)
}
func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
sr := io.NewSectionReader(file, 0, 512)
contentType, err := detectContentType(sr)
if err != nil {
@@ -230,10 +229,30 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
return nil, err
}
layers = append(layers, &layerWithGGML{layer, ggml})
layers = append(layers, &layerGGML{layer, ggml})
offset = n
}
return detectChatTemplate(layers)
}
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := templates.NamedTemplate(s); err != nil {
slog.Debug("template detection", "error", err)
} else {
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, &layerGGML{tmpl, nil})
}
}
}
return layers, nil
}

View File

@@ -6,12 +6,13 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist
dir, err := os.MkdirTemp("", "ollama-test")
assert.Nil(t, err)
require.NoError(t, err)
defer os.RemoveAll(dir)
tests := []struct {
@@ -63,7 +64,7 @@ func TestGetBlobsPath(t *testing.T) {
got, err := GetBlobsPath(tc.digest)
assert.ErrorIs(t, tc.err, err, tc.name)
require.ErrorIs(t, tc.err, err, tc.name)
assert.Equal(t, tc.expected, got, tc.name)
})
}

View File

@@ -16,6 +16,7 @@ import (
"os"
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"syscall"
@@ -23,14 +24,13 @@ import (
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -77,7 +77,6 @@ func isSupportedImageType(image []byte) bool {
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
err := c.ShouldBindJSON(&req)
@@ -315,10 +314,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
func getDefaultSessionDuration() time.Duration {
if t, exists := os.LookupEnv("OLLAMA_KEEP_ALIVE"); exists {
v, err := strconv.Atoi(t)
if envconfig.KeepAlive != "" {
v, err := strconv.Atoi(envconfig.KeepAlive)
if err != nil {
d, err := time.ParseDuration(t)
d, err := time.ParseDuration(envconfig.KeepAlive)
if err != nil {
return defaultSessionDuration
}
@@ -421,13 +420,14 @@ func (s *Server) PullModelHandler(c *gin.Context) {
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
name := model.ParseName(cmp.Or(req.Model, req.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return
}
if err := checkNameExists(name); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -445,7 +445,7 @@ func (s *Server) PullModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil {
if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -507,9 +507,24 @@ func (s *Server) PushModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func checkNameExists(name model.Name) error {
names, err := Manifests()
if err != nil {
return err
}
for n := range names {
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
return fmt.Errorf("a model with that name already exists")
}
}
return nil
}
func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
var r api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
@@ -517,30 +532,35 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
return
}
name := model.ParseName(cmp.Or(req.Model, req.Name))
name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
return
}
if req.Path == "" && req.Modelfile == "" {
if err := checkNameExists(name); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
}
var r io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
f, err := os.Open(req.Path)
var sr io.Reader = strings.NewReader(r.Modelfile)
if r.Path != "" && r.Modelfile == "" {
f, err := os.Open(r.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
defer f.Close()
r = f
sr = f
}
modelfile, err := parser.ParseFile(r)
f, err := parser.ParseFile(sr)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -556,17 +576,13 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
quantization := req.Quantization
if req.Quantize != "" {
quantization = req.Quantize
}
if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil {
quantization := cmp.Or(r.Quantize, r.Quantization)
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
waitForStream(c, ch)
return
}
@@ -575,48 +591,36 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
}
func (s *Server) DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.DeleteRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}
if err := DeleteModel(model); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
manifestsPath, err := GetManifestPath()
m, err := ParseNamedManifest(n)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := PruneDirectory(manifestsPath); err != nil {
if err := m.Remove(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, nil)
if err := m.RemoveLayers(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
func (s *Server) ShowModelHandler(c *gin.Context) {
@@ -720,75 +724,45 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
func (s *Server) ListModelsHandler(c *gin.Context) {
manifests, err := GetManifestPath()
ms, err := Manifests()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
models := []api.ModelResponse{}
if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
if !info.IsDir() {
rel, err := filepath.Rel(manifests, path)
if err != nil {
return err
}
if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
return err
} else if hidden {
return nil
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest filepath", "path", rel)
return nil
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
return nil
}
models := []api.ListModelResponse{}
for n, m := range ms {
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest config filepath", "name", n, "error", err)
return nil
slog.Warn("bad manifest filepath", "name", n, "error", err)
continue
}
defer f.Close()
var c ConfigV2
if err := json.NewDecoder(f).Decode(&c); err != nil {
var cf ConfigV2
if err := json.NewDecoder(f).Decode(&cf); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
return nil
continue
}
// tag should never be masked
models = append(models, api.ModelResponse{
models = append(models, api.ListModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.Digest,
ModifiedAt: info.ModTime(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{
Format: c.ModelFormat,
Family: c.ModelFamily,
Families: c.ModelFamilies,
ParameterSize: c.ModelType,
QuantizationLevel: c.FileType,
Format: cf.ModelFormat,
Family: cf.ModelFamily,
Families: cf.ModelFamilies,
ParameterSize: cf.ModelType,
QuantizationLevel: cf.FileType,
},
})
}
return nil
}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
// most recently modified first
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
})
@@ -818,6 +792,11 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
return
}
if err := checkNameExists(dst); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
} else if err != nil {
@@ -963,7 +942,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}
if allowedHost(host) {
if c.Request.Method == "OPTIONS" {
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusNoContent)
return
}
@@ -981,6 +960,10 @@ func (s *Server) GenerateRoutes() http.Handler {
config.AllowWildcard = true
config.AllowBrowserExtensions = true
config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
for _, prop := range openAIProperties {
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
}
config.AllowOrigins = envconfig.AllowOrigins
r := gin.Default()
@@ -1025,7 +1008,7 @@ func Serve(ln net.Listener) error {
level = slog.LevelDebug
}
slog.Info("server config", "env", envconfig.AsMap())
slog.Info("server config", "env", envconfig.Values())
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: true,
@@ -1160,7 +1143,7 @@ func streamResponse(c *gin.Context, ch chan any) {
}
func (s *Server) ProcessHandler(c *gin.Context) {
models := []api.ModelResponse{}
models := []api.ProcessModelResponse{}
for _, v := range s.sched.loaded {
model := v.model
@@ -1172,7 +1155,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
QuantizationLevel: model.Config.FileType,
}
mr := api.ModelResponse{
mr := api.ProcessModelResponse{
Model: model.ShortName,
Name: model.ShortName,
Size: int64(v.estimatedTotal),
@@ -1192,7 +1175,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
models = append(models, mr)
}
c.JSON(http.StatusOK, api.ListResponse{Models: models})
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
@@ -1327,7 +1310,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch)
fn := func(r llm.CompletionResponse) {
resp := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),

View File

@@ -0,0 +1,560 @@
package server
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
var stream bool = false
func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := llm.NewGGUFV3(binary.LittleEndian).Encode(f, kv, ti); err != nil {
t.Fatal(err)
}
return f.Name()
}
type responseRecorder struct {
*httptest.ResponseRecorder
http.CloseNotifier
}
func NewRecorder() *responseRecorder {
return &responseRecorder{
ResponseRecorder: httptest.NewRecorder(),
}
}
func (t *responseRecorder) CloseNotify() <-chan bool {
return make(chan bool)
}
func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
t.Helper()
w := NewRecorder()
c, _ := gin.CreateTestContext(w)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(body); err != nil {
t.Fatal(err)
}
c.Request = &http.Request{
Body: io.NopCloser(&b),
}
fn(c)
return w.ResponseRecorder
}
func checkFileExists(t *testing.T, p string, expect []string) {
t.Helper()
actual, err := filepath.Glob(p)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(actual, expect) {
t.Fatalf("expected slices to be equal %v", actual)
}
}
func TestCreateFromBin(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
})
}
func TestCreateFromModel(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: "FROM test",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
})
}
func TestCreateRemovesLayers(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-b507b9c2f6ca642bffcd06665ea7c91f235fd32daeefdf875a0f938db05fb315"),
filepath.Join(p, "blobs", "sha256-bc80b03733773e0728011b2f4adf34c458b400e1aad48cb28d61170f3a2ad2d6"),
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
}
func TestCreateUnsetsSystem(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8585df945d1069bc78b79bd10bb73ba07fbc29b0f5479a31a601c0d12731416e"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-f29e82a8284dbdf5910b1555580ff60b04238b8da9d5e51159ada67a4d0d5851"),
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-67d4b8d106af2a5b100a46e9bdc038c71eef2a35c9abac784092654212f97cf5"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"),
})
bts, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"))
if err != nil {
t.Fatal(err)
}
if string(bts) != "" {
t.Fatalf("expected empty string, actual %s", string(bts))
}
}
func TestCreateMergeParameters(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
})
// in order to merge parameters, the second model must be created FROM the first
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
filepath.Join(p, "blobs", "sha256-4cd9d4ba6b734d9b4cbd1e5caa60374c00722e993fce5e1e2d15a33698f71187"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"),
})
actual, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"))
if err != nil {
t.Fatal(err)
}
expect, err := json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"USER:", "ASSISTANT:"}})
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) {
t.Errorf("expected %s, actual %s", string(expect), string(actual))
}
// slices are replaced
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7\nPARAMETER stop <|endoftext|>",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"),
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
filepath.Join(p, "blobs", "sha256-257aa726584f24970a4f240765e75a7169bfbe7f4966c1f04513d6b6c860583a"),
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
})
actual, err = os.ReadFile(filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"))
if err != nil {
t.Fatal(err)
}
expect, err = json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"<|endoftext|>"}})
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) {
t.Errorf("expected %s, actual %s", string(expect), string(actual))
}
}
func TestCreateReplacesMessages(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: "FROM test\nMESSAGE assistant \"You're a test, Harry.\"\nMESSAGE user \"I-I'm a what?\"\nMESSAGE assistant \"A test. And a thumping good one at that, I'd wager.\"",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
filepath.Join(p, "blobs", "sha256-4f48b25fe9969564c82f58eb1cedbdff6484cc0baf474bc6c2a9b37c8da3362a"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"),
filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
})
type message struct {
Role string `json:"role"`
Content string `json:"content"`
}
f, err := os.Open(filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
var actual []message
if err := json.NewDecoder(f).Decode(&actual); err != nil {
t.Fatal(err)
}
expect := []message{
{Role: "assistant", Content: "You're a test, Harry."},
{Role: "user", Content: "I-I'm a what?"},
{Role: "assistant", Content: "A test. And a thumping good one at that, I'd wager."},
}
if !slices.Equal(actual, expect) {
t.Errorf("expected %s, actual %s", expect, actual)
}
}
func TestCreateTemplateSystem(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2b5e330885117c82f3fd75169ea323e141070a2947c11ddb9f79ee0b01c589c1"),
filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
template, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"))
if err != nil {
t.Fatal(err)
}
if string(template) != "{{ .System }} {{ .Prompt }}" {
t.Errorf("expected \"{{ .System }} {{ .Prompt }}\", actual %s", template)
}
system, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"))
if err != nil {
t.Fatal(err)
}
if string(system) != "Say bye!" {
t.Errorf("expected \"Say bye!\", actual %s", system)
}
}
func TestCreateLicenses(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"),
filepath.Join(p, "blobs", "sha256-79a39c37536ddee29cbadd5d5e2dcba8ed7f03e431f626ff38432c1c866bb7e2"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"),
})
mit, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"))
if err != nil {
t.Fatal(err)
}
if string(mit) != "MIT" {
t.Errorf("expected MIT, actual %s", mit)
}
apache, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"))
if err != nil {
t.Fatal(err)
}
if string(apache) != "Apache-2.0" {
t.Errorf("expected Apache-2.0, actual %s", apache)
}
}
func TestCreateDetectTemplate(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
t.Run("matched", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"tokenizer.chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
}, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
})
})
t.Run("unmatched", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
})
})
}

View File

@@ -0,0 +1,104 @@
package server
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"path/filepath"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func TestDelete(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test2"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
}
func TestDeleteDuplicateLayers(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
n := model.ParseName("test")
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(&ConfigV2{}); err != nil {
t.Fatal(err)
}
config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatal(err)
}
// create a manifest with duplicate layers
if err := WriteManifest(n, config, []*Layer{config}); err != nil {
t.Fatal(err)
}
w := createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"})
if w.Code != http.StatusOK {
t.Errorf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
}

View File

@@ -0,0 +1,61 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"slices"
"testing"
"github.com/ollama/ollama/api"
)
func TestList(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
expectNames := []string{
"mistral:7b-instruct-q4_0",
"zephyr:7b-beta-q5_K_M",
"apple/OpenELM:latest",
"boreas:2b-code-v1.5-q6_K",
"notus:7b-v1-IQ2_S",
// TODO: host:port currently fails on windows (#4107)
// "localhost:5000/library/eurus:700b-v0.5-iq3_XXS",
"mynamespace/apeliotes:latest",
"myhost/mynamespace/lips:code",
}
var s Server
for _, n := range expectNames {
createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: n,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
})
}
w := createRequest(t, s.ListModelsHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
var resp api.ListResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if len(resp.Models) != len(expectNames) {
t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models))
}
actualNames := make([]string, len(resp.Models))
for i, m := range resp.Models {
actualNames[i] = m.Name
}
slices.Sort(actualNames)
slices.Sort(expectNames)
if !slices.Equal(actualNames, expectNames) {
t.Fatalf("expected slices to be equal %v", actualNames)
}
}

View File

@@ -15,12 +15,36 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
func createTestFile(t *testing.T, name string) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), name)
require.NoError(t, err)
defer f.Close()
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
require.NoError(t, err)
err = binary.Write(f, binary.LittleEndian, uint32(3))
require.NoError(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
require.NoError(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
require.NoError(t, err)
return f.Name()
}
func Test_Routes(t *testing.T) {
type testCase struct {
Name string
@@ -30,39 +54,19 @@ func Test_Routes(t *testing.T) {
Expected func(t *testing.T, resp *http.Response)
}
createTestFile := func(t *testing.T, name string) string {
createTestModel := func(t *testing.T, name string) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), name)
assert.Nil(t, err)
defer f.Close()
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint32(3))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
return f.Name()
}
createTestModel := func(t *testing.T, name string) {
fname := createTestFile(t, "ollama-model")
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
modelfile, err := parser.ParseFile(r)
assert.Nil(t, err)
require.NoError(t, err)
fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status)
}
err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
assert.Nil(t, err)
err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
require.NoError(t, err)
}
testCases := []testCase{
@@ -74,9 +78,9 @@ func Test_Routes(t *testing.T) {
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, contentType, "application/json; charset=utf-8")
assert.Equal(t, "application/json; charset=utf-8", contentType)
body, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
},
},
@@ -86,17 +90,17 @@ func Test_Routes(t *testing.T) {
Path: "/api/tags",
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, contentType, "application/json; charset=utf-8")
assert.Equal(t, "application/json; charset=utf-8", contentType)
body, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
require.NoError(t, err)
var modelList api.ListResponse
err = json.Unmarshal(body, &modelList)
assert.Nil(t, err)
require.NoError(t, err)
assert.NotNil(t, modelList.Models)
assert.Equal(t, 0, len(modelList.Models))
assert.Empty(t, len(modelList.Models))
},
},
{
@@ -108,16 +112,18 @@ func Test_Routes(t *testing.T) {
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, contentType, "application/json; charset=utf-8")
assert.Equal(t, "application/json; charset=utf-8", contentType)
body, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
require.NoError(t, err)
assert.NotContains(t, string(body), "expires_at")
var modelList api.ListResponse
err = json.Unmarshal(body, &modelList)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, 1, len(modelList.Models))
assert.Equal(t, modelList.Models[0].Name, "test-model:latest")
assert.Len(t, modelList.Models, 1)
assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
},
},
{
@@ -134,7 +140,7 @@ func Test_Routes(t *testing.T) {
Stream: &stream,
}
jsonData, err := json.Marshal(createReq)
assert.Nil(t, err)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
@@ -142,11 +148,11 @@ func Test_Routes(t *testing.T) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
_, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
assert.Equal(t, resp.StatusCode, 200)
require.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
model, err := GetModel("t-bone")
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "t-bone:latest", model.ShortName)
},
},
@@ -161,13 +167,13 @@ func Test_Routes(t *testing.T) {
Destination: "beefsteak",
}
jsonData, err := json.Marshal(copyReq)
assert.Nil(t, err)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
model, err := GetModel("beefsteak")
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "beefsteak:latest", model.ShortName)
},
},
@@ -179,18 +185,18 @@ func Test_Routes(t *testing.T) {
createTestModel(t, "show-model")
showReq := api.ShowRequest{Model: "show-model"}
jsonData, err := json.Marshal(showReq)
assert.Nil(t, err)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, contentType, "application/json; charset=utf-8")
assert.Equal(t, "application/json; charset=utf-8", contentType)
body, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
require.NoError(t, err)
var showResp api.ShowResponse
err = json.Unmarshal(body, &showResp)
assert.Nil(t, err)
require.NoError(t, err)
var params []string
paramsSplit := strings.Split(showResp.Parameters, "\n")
@@ -209,26 +215,26 @@ func Test_Routes(t *testing.T) {
},
}
t.Setenv("OLLAMA_MODELS", t.TempDir())
s := &Server{}
router := s.GenerateRoutes()
httpSrv := httptest.NewServer(router)
t.Cleanup(httpSrv.Close)
t.Setenv("OLLAMA_MODELS", t.TempDir())
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
u := httpSrv.URL + tc.Path
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
assert.Nil(t, err)
require.NoError(t, err)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp, err := httpSrv.Client().Do(req)
assert.Nil(t, err)
require.NoError(t, err)
defer resp.Body.Close()
if tc.Expected != nil {
@@ -237,3 +243,82 @@ func Test_Routes(t *testing.T) {
})
}
}
func TestCase(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
cases := []string{
"mistral",
"llama3:latest",
"library/phi3:q4_0",
"registry.ollama.ai/library/gemma:q5_K_M",
// TODO: host:port currently fails on windows (#4107)
// "localhost:5000/alice/bob:latest",
}
var s Server
for _, tt := range cases {
t.Run(tt, func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: tt,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200 got %d", w.Code)
}
expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
if err != nil {
t.Fatal(err)
}
t.Run("create", func(t *testing.T) {
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: strings.ToUpper(tt),
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
t.Run("pull", func(t *testing.T) {
w := createRequest(t, s.PullModelHandler, api.PullRequest{
Name: strings.ToUpper(tt),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
t.Run("copy", func(t *testing.T) {
w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
Source: tt,
Destination: strings.ToUpper(tt),
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
})
}
}

View File

@@ -7,17 +7,17 @@ import (
"log/slog"
"reflect"
"runtime"
"slices"
"sort"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/server/envconfig"
"golang.org/x/exp/slices"
)
type LlmRequest struct {
@@ -66,7 +66,7 @@ func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options,
opts.NumCtx = 4
}
opts.NumCtx = opts.NumCtx * envconfig.NumParallel
opts.NumCtx *= envconfig.NumParallel
req := &LlmRequest{
ctx: c,
@@ -370,7 +370,6 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
r.refMu.Lock()
gpuIDs := make([]string, 0, len(r.gpus))
if r.llama != nil {
// TODO this should be broken down by GPU instead of assuming uniform spread
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
for _, gpu := range r.gpus {
@@ -529,7 +528,6 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
}
}()
return finished
}
type ByDuration []*runnerRef

View File

@@ -12,11 +12,10 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/server/envconfig"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -53,10 +52,10 @@ func TestLoad(t *testing.T) {
}
gpus := gpu.GpuInfoList{}
s.load(req, ggml, gpus)
require.Len(t, req.successCh, 0)
require.Empty(t, req.successCh)
require.Len(t, req.errCh, 1)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
require.Empty(t, s.loaded)
s.loadedMu.Unlock()
err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible")
@@ -113,7 +112,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
assert.Nil(t, err)
require.NoError(t, err)
defer f.Close()
gguf := llm.NewGGUFV3(binary.LittleEndian)
@@ -131,7 +130,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
}, []llm.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
})
assert.Nil(t, err)
require.NoError(t, err)
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
@@ -190,8 +189,8 @@ func TestRequests(t *testing.T) {
select {
case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1a.req.errCh, 0)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1a.req.errCh)
case <-ctx.Done():
t.Errorf("timeout")
}
@@ -203,8 +202,8 @@ func TestRequests(t *testing.T) {
select {
case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1b.req.errCh, 0)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1b.req.errCh)
case <-ctx.Done():
t.Errorf("timeout")
}
@@ -221,8 +220,8 @@ func TestRequests(t *testing.T) {
select {
case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario2a.req.errCh, 0)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario2a.req.errCh)
case <-ctx.Done():
t.Errorf("timeout")
}
@@ -237,8 +236,8 @@ func TestRequests(t *testing.T) {
select {
case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3a.req.errCh, 0)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3a.req.errCh)
case <-ctx.Done():
t.Errorf("timeout")
}
@@ -253,8 +252,8 @@ func TestRequests(t *testing.T) {
select {
case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, scenario3b.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3b.req.errCh, 0)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3b.req.errCh)
case <-ctx.Done():
t.Errorf("timeout")
}
@@ -269,8 +268,8 @@ func TestRequests(t *testing.T) {
select {
case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3c.req.errCh, 0)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3c.req.errCh)
case <-ctx.Done():
t.Errorf("timeout")
}
@@ -296,8 +295,8 @@ func TestRequests(t *testing.T) {
select {
case resp := <-scenario3d.req.successCh:
require.Equal(t, resp.llama, scenario3d.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3d.req.errCh, 0)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3d.req.errCh)
case <-ctx.Done():
t.Errorf("timeout")
}
@@ -332,7 +331,7 @@ func TestGetRunner(t *testing.T) {
slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Len(t, successCh1b, 0)
require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1)
err := <-errCh1b
require.Contains(t, err.Error(), "server busy")
@@ -340,8 +339,8 @@ func TestGetRunner(t *testing.T) {
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
require.Empty(t, s.pendingReqCh)
require.Empty(t, errCh1a)
case <-ctx.Done():
t.Errorf("timeout")
}
@@ -355,9 +354,9 @@ func TestGetRunner(t *testing.T) {
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
// Starts in pending channel, then should be quickly processsed to return an error
time.Sleep(5 * time.Millisecond)
require.Len(t, successCh1c, 0)
require.Empty(t, successCh1c)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
require.Empty(t, s.loaded)
s.loadedMu.Unlock()
require.Len(t, errCh1c, 1)
err = <-errCh1c
@@ -386,8 +385,8 @@ func TestPrematureExpired(t *testing.T) {
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
require.Empty(t, s.pendingReqCh)
require.Empty(t, errCh1a)
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
@@ -401,9 +400,9 @@ func TestPrematureExpired(t *testing.T) {
time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1)
time.Sleep(10 * time.Millisecond)
require.Len(t, s.finishedReqCh, 0)
require.Empty(t, s.finishedReqCh)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
require.Empty(t, s.loaded)
s.loadedMu.Unlock()
// also shouldn't happen in real life
@@ -487,7 +486,6 @@ func TestFindRunnerToUnload(t *testing.T) {
r2.refCount = 1
resp = s.findRunnerToUnload()
require.Equal(t, r1, resp)
}
func TestNeedsReload(t *testing.T) {

View File

@@ -146,7 +146,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
case requestURL := <-b.nextURL:
g.Go(func() error {
var err error
for try := 0; try < maxRetries; try++ {
for try := range maxRetries {
err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
switch {
case errors.Is(err, context.Canceled):
@@ -190,7 +190,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", "0")
for try := 0; try < maxRetries; try++ {
for try := range maxRetries {
var resp *http.Response
resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
if errors.Is(err, context.Canceled) {
@@ -253,7 +253,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
}
// retry uploading to the redirect URL
for try := 0; try < maxRetries; try++ {
for try := range maxRetries {
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
switch {
case errors.Is(err, context.Canceled):
@@ -391,7 +391,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryO
return err
}
// nolint: contextcheck
//nolint:contextcheck
go upload.Run(context.Background(), opts)
}

1
templates/alfred.gotmpl Normal file
View File

@@ -0,0 +1 @@
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>

7
templates/alpaca.gotmpl Normal file
View File

@@ -0,0 +1,7 @@
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
{{ .Prompt }}
{{ end }}### Response:
{{ .Response }}

6
templates/chatml.gotmpl Normal file
View File

@@ -0,0 +1,6 @@
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>

5
templates/chatqa.gotmpl Normal file
View File

@@ -0,0 +1,5 @@
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}

View File

@@ -0,0 +1,8 @@
{{ if .System }} Source: system
{{ .System }} <step>{{ end }} Source: user
{{ .Prompt }} <step> Source: assistant
Destination: user
{{ .Response }}<step>

View File

@@ -0,0 +1,3 @@
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}

View File

@@ -0,0 +1,4 @@
<start_of_turn>user
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>

View File

@@ -0,0 +1,9 @@
{{ if .System }}
System:
{{ .System }}
{{ end }}{{ if .Prompt }}Question:
{{ .Prompt }}
{{ end }}Answer:
{{ .Response }}

138
templates/index.json Normal file
View File

@@ -0,0 +1,138 @@
[
{
"template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}",
"name": "chatml"
},
{
"template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
"name": "zephyr"
},
{
"template": "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}",
"name": "chatml"
},
{
"template": "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
"name": "openchat"
},
{
"template": "{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
"name": "zephyr"
},
{
"template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"name": "mistral-instruct"
},
{
"template": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'### Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response\n'}}",
"name": "starcoder2-instruct"
},
{
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
"name": "llama2-chat"
},
{
"template": "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '<s>' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' <step> ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}",
"name": "codellama-70b-instruct"
},
{
"template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"name": "mistral-instruct"
},
{
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'system' %}\n{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|im_start|>assistant' }}\n{% endif %}\n{% endfor %}",
"name": "chatml"
},
{
"template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif 'system' not in messages[0]['role'] %}{% set loop_messages = messages %}{% set system_message = 'You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks \u2014 remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER\\'S QUERY.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% if system_message != false %}{{ '<|im_start|>system\n' + system_message | trim + '<|im_end|>\n'}}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% else %}{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% endif %}{% if (add_generation_prompt == true and loop.last) %}{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}{% endif %}{% endfor %}",
"name": "chatml"
},
{
"template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
"name": "alpaca"
},
{
"template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
"name": "chatqa"
},
{
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"name": "gemma-instruct"
},
{
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"name": "llama3-instruct"
},
{
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'Question:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'system' %}\n{{ 'System:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Answer:\n' + message['content'] + '\n\n' }}{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Answer:\n' }}{% endif %}{% endfor %}",
"name": "granite-instruct"
},
{
"template": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'@@ Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'@@ Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'@@ Response\n'}}",
"name": "magicoder"
},
{
"template": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_user>' + message['content'].strip() + '<end_message>' }}{% elif message['role'] == 'system' %}{{ '<start_system>' + message['content'].strip() + '<end_message>' }}{% elif message['role'] == 'assistant' %}{{ '<start_assistant>' + message['content'] + '<end_message>' }}{% else %}{{ raise_exception('Only system, user and assistant roles are supported.') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<start_assistant>' }}{% endif %}{% endfor %}",
"name": "alfred"
},
{
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
"name": "llama2-chat"
},
{
"template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"name": "phi-3"
},
{
"template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"name": "phi-3"
},
{
"template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
"name": "phi-3"
},
{
"template": "{{ bos_token }}{%- if messages[0]['role'] == 'system' -%}{% set loop_messages = messages[1:] %}{%- else -%}{% set loop_messages = messages %}{% endif %}System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context.\n\n{% for message in loop_messages %}{%- if message['role'] == 'user' -%}User: {{ message['content'].strip() + '\n\n' }}{%- else -%}Assistant: {{ message['content'].strip() + '\n\n' }}{%- endif %}{% if loop.last and message['role'] == 'user' %}Assistant:{% endif %}{% endfor %}",
"name": "chatqa"
},
{
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'User: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'System: ' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'Falcon:\n' + message['content']}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Falcon:' }}\n{% endif %}\n{% endfor %}",
"name": "falcon-instruct"
},
{
"template": "{% for message in messages %}{% if not loop.first %}{{ '\n' }}{% endif %}{% if message['role'] == 'system' %}{{ 'System: ' }}{% elif message['role'] == 'user' %}{{ 'User: ' }}{% elif message['role'] == 'assistant' %}{{ 'Falcon: ' }}{% endif %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '\n' + 'Falcon:' }}{% endif %}",
"name": "falcon-instruct"
},
{
"template": "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}",
"name": "solar-instruct"
}
]

View File

@@ -0,0 +1,3 @@
[INST] <<SYS>>{{ .System }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }}

Some files were not shown because too many files have changed in this diff Show More