Compare commits

...

63 Commits

Author SHA1 Message Date
likelovewant
2490a69f7b Merge branch 'ollama:main' into main 2024-06-03 14:15:05 +08:00
Jeffrey Morgan
d4a86102fd update welcome prompt in windows to llama3 (#4779) 2024-06-01 21:05:51 -07:00
Jeffrey Morgan
476fb8e892 Limit GPU lib search for now (#4777)
* fix oneapi errors on windows 10
2024-06-01 19:24:33 -07:00
Michael Yang
829ff87bd1 revert tokenize ffi (#4761)
* Revert "use `int32_t` for call to tokenize (#4738)"

This reverts commit 763bb65dbb.

* Revert "vocab only"

This reverts commit bf54c845e9.

* Revert "use ffi for tokenizing/detokenizing"

This reverts commit 26a00a0410.
2024-05-31 18:54:21 -07:00
Josh
f6b622c4b3 Merge pull request #4733 from ollama/jyan/isvalidname
added IsValidNamespace function
2024-05-31 14:08:45 -07:00
Josh Yan
2e4da8eec2 added tests for IsValidNamespace 2024-05-31 11:48:07 -07:00
likelovewant
16ce79eb3b Merge branch 'ollama:main' into main 2024-05-31 18:43:24 +08:00
Jeffrey Morgan
763bb65dbb use int32_t for call to tokenize (#4738)
* use `int32_t` for call to tokenize

* variable naming

* cleanup

* fix crash
2024-05-30 21:43:30 -07:00
Jeffrey Morgan
7ca9605f54 speed up tests by only building static lib (#4740) 2024-05-30 21:43:15 -07:00
Michael Yang
eb2c443a79 Merge pull request #4736 from ollama/mxyng/vocab-only
vocab only for tokenize
2024-05-30 17:21:00 -07:00
Michael Yang
278e25ea44 Merge pull request #4737 from ollama/mxyng/less-generate
only generate on relevant changes
2024-05-30 17:17:50 -07:00
Jeffrey Morgan
a50a87a7b8 partial offloading: allow flash attention and disable mmap (#4734)
* partial offloading: allow flash attention and disable mmap

* allow mmap with num_gpu=0
2024-05-30 16:58:01 -07:00
Michael Yang
98085015d5 only generate on relevant changes 2024-05-30 16:54:11 -07:00
Michael Yang
bf54c845e9 vocab only 2024-05-30 16:49:28 -07:00
Josh Yan
c365f195a8 directly use isvalidpart 2024-05-30 16:40:04 -07:00
Josh
e91d0ef737 Merge pull request #4728 from ollama/jyan/japanese
fixed japanese characters deleted at end of line
2024-05-30 16:25:12 -07:00
Jeffrey Morgan
22f5c12ced Update llama.cpp submodule to 5921b8f0 (#4731)
* update llama.cpp submodule to `5921b8f089d3b7bda86aac5a66825df6a6c10603`

* add patch
2024-05-30 16:20:22 -07:00
Josh Yan
298c996e54 added IsValidNamespace function 2024-05-30 16:02:07 -07:00
Daniel Hiltgen
0fc0cfc6d2 Merge pull request #4594 from dhiltgen/doc_container_workarounds
Add isolated gpu test to troubleshooting
2024-05-30 13:10:54 -07:00
Josh Yan
914f68f021 replaced duplicate call with variable 2024-05-30 10:38:07 -07:00
Josh Yan
bd1d119ba9 fixed japanese characters deleted at end of line 2024-05-30 10:24:21 -07:00
Lei Jitang
a03be18189 Fix OLLAMA_LLM_LIBRARY with wrong map name and add more env vars to help message (#4663)
* envconfig/config.go: Fix wrong description of OLLAMA_LLM_LIBRARY

Signed-off-by: Lei Jitang <leijitang@outlook.com>

* serve: Add more env to help message of ollama serve

Add more environment variables to `ollama serve --help`
to let users know what can be configured.

Signed-off-by: Lei Jitang <leijitang@outlook.com>

---------

Signed-off-by: Lei Jitang <leijitang@outlook.com>
2024-05-30 09:36:51 -07:00
Michael Yang
96bc232b43 Merge pull request #4413 from ollama/mxyng/name-check
check if name exists before create/pull/copy
2024-05-29 12:06:58 -07:00
Michael Yang
bca7b12284 Merge pull request #3718 from ollama/mxyng/modelname-3
update delete handler to use model.Name
2024-05-29 12:02:07 -07:00
Michael Yang
32cb1960c1 Merge pull request #4380 from ollama/mxyng/tokenize
use tokenize/detokenize
2024-05-29 12:00:59 -07:00
Michael Yang
de781b37c8 rm unused infill 2024-05-29 11:26:47 -07:00
Michael Yang
3e21799377 rm unused system prompt 2024-05-29 11:26:47 -07:00
Michael Yang
26a00a0410 use ffi for tokenizing/detokenizing 2024-05-29 11:26:47 -07:00
likelovewant
cafde1f8ce Merge branch 'ollama:main' into main 2024-05-29 19:33:39 +08:00
Daniel Hiltgen
646371f56d Merge pull request #3278 from zhewang1-intc/rebase_ollama_main
Enabling ollama to run on Intel GPUs with SYCL backend
2024-05-28 16:30:50 -07:00
Jeffrey Morgan
1f5008544b Update install.sh 2024-05-28 15:01:22 -07:00
Jeffrey Morgan
45cbfc5aee fix wsl2 status check for nvidia cards (#4689) 2024-05-28 14:49:46 -07:00
Jeffrey Morgan
6d423b383b Improve install experience on WSL2 and Linux (#4653) 2024-05-28 14:41:50 -07:00
Josh
ad897080a2 working on integration of multi-byte and multi-width runes (#4549)
* integrated runewidth for display management - fixed cursor movement for multi-width char

* updated input and deletion of multi-byte chars

* fixed line history with some exceptions

* improved insert and add

* fixed issues with moving across lines

* end of line extra space tracking

* saved changes

* fixed end of line issues with empty spaces

* worked some more

* worked on end of line

* fixed failed test

* fixed minor inserting bug

* fixed movement hotkeys

* adjusted hotkeys

* removed comments

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* Update readline/buffer.go

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>

* deleted comments and duplicate code

* removed duplicate code

* added comments, refactored add function to use addChar

* added helper to retrieve lineSpacing, renamed lineFlags for clarity

* fixed remove()

---------

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2024-05-28 12:04:03 -07:00
Jeffrey Morgan
b7d316d98d fix nvidia detection in install script (#4683) 2024-05-28 09:59:36 -07:00
Daniel Hiltgen
d7339fad52 Merge pull request #4682 from dhiltgen/more_time
Give the final model loading more time
2024-05-28 09:36:02 -07:00
Daniel Hiltgen
92c81e8117 Give the final model loading more time
On some systems, 1 minute isn't sufficient to finish the load after it
hits 100%. This creates 2 distinct timers, although they're both set to
the same value for now so we can refine the timeouts further.
2024-05-28 09:08:10 -07:00
Tai
9db0996ed4 Add OllamaSpring Project to Readme (#4672)
* Add OllamaSpring Project to Readme

* Update README.md

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-05-27 19:58:26 -07:00
Orfeo Ciano
6f43898b17 Adds olpaka flutter client (#4647)
* Adds olpaka flutter client

* Update README.md

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-05-27 17:22:01 -07:00
Lei Jitang
7487229c34 llm/server.go: Fix 2 minor typos (#4661)
Signed-off-by: Lei Jitang <leijitang@outlook.com>
2024-05-27 17:21:10 -07:00
Rayan Mostovoi
8a8e7afa96 small fix on examples/python-simplechat/client.py to actually get a streamed response and get tokens printed as we receive them (#4671) 2024-05-27 17:19:20 -07:00
Jeffrey Morgan
c79f8c9c39 Ensure nvidia and nvidia_uvm kernel modules are loaded in install.sh script and at startup (#4652)
* ensure kernel modules are loaded in `install.sh` script and at startup

* indentation

* use `SUDO` variable

* restart if nouveau is detected

* consistent success message for AMD
2024-05-26 14:57:17 -07:00
Jeffrey Morgan
485016bfbb Update install.sh 2024-05-26 11:46:00 -07:00
likelovewant
2a80d6f743 Merge branch 'ollama:main' into main 2024-05-26 11:57:21 +08:00
Daniel Hiltgen
0165ba1651 Merge pull request #4638 from dhiltgen/better_error
Report better warning on client closed abort of load
2024-05-25 14:32:28 -07:00
Daniel Hiltgen
c4209d6d21 Report better warning on client closed abort of load
If the client closes the connection before we finish loading the model
we abort, so let's make the log message clearer about why, to help users
understand this failure mode
2024-05-25 09:23:28 -07:00
Michael Yang
6adca97f37 Merge pull request #4619 from noxer/patch-1
Fix download retry issue
2024-05-24 17:21:57 -07:00
Michael Yang
9a3c8003c8 Merge pull request #4624 from ollama/mxyng/fix-5
fix q5_0, q5_1
2024-05-24 16:11:21 -07:00
Michael Yang
d51f15257c Update llm/ggml.go
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2024-05-24 16:10:43 -07:00
Michael Yang
8f440d579a fix q5_0, q5_1 2024-05-24 16:01:46 -07:00
Patrick Devine
4cc3be3035 Move envconfig and consolidate env vars (#4608) 2024-05-24 14:57:15 -07:00
Tim Scheuermann
db2ffa79f1 Fix download retry issue 2024-05-24 20:30:42 +02:00
likelovewant
73c49d57e8 Update amd_windows.go
remove this; it will break the installer build
2024-05-24 20:06:28 +08:00
likelovewant
6b50b2f3bf Update gen_windows.ps1 2024-05-24 15:42:29 +08:00
Wang,Zhe
fd5971be0b support ollama run on Intel GPUs 2024-05-24 11:18:27 +08:00
Daniel Hiltgen
f77713bf1f Add isolated gpu test to troubleshooting 2024-05-23 09:33:25 -07:00
Michael Yang
85a57006d1 check if name exists before create/pull/copy 2024-05-14 14:58:58 -07:00
Michael Yang
c5e892cb3e update tests 2024-05-14 14:56:31 -07:00
Michael Yang
81fb06f530 more resilient Manifests 2024-05-14 14:08:24 -07:00
Michael Yang
a385382ff5 filepath.Join 2024-05-14 14:08:24 -07:00
Michael Yang
b8772a353f remove DeleteModel 2024-05-14 14:08:24 -07:00
Michael Yang
c2714fcbfd routes: use Manifests for ListHandler 2024-05-14 14:08:24 -07:00
Michael Yang
a2fc933fed update delete handler to use model.Name 2024-05-14 14:08:24 -07:00
43 changed files with 1891 additions and 529 deletions

View File

@@ -34,13 +34,13 @@ jobs:
       git diff-tree -r --no-commit-id --name-only \
         $(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
         ${{ github.event.pull_request.head.sha }} \
-        | xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
+        | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
     }
     {
-      echo GENERATE=$(changed llm/)
-      echo GENERATE_CUDA=$(changed llm/)
-      echo GENERATE_ROCM=$(changed llm/)
+      echo GENERATE=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
+      echo GENERATE_CUDA=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
+      echo GENERATE_ROCM=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
     } >>$GITHUB_OUTPUT
   generate:
@@ -287,6 +287,8 @@ jobs:
       GOARCH: ${{ matrix.arch }}
       CGO_ENABLED: '1'
       OLLAMA_CPU_TARGET: 'static'
+      OLLAMA_SKIP_CPU_GENERATE: '1'
+      OLLAMA_SKIP_METAL_GENERATE: '1'
     steps:
       - uses: actions/checkout@v4
         with:
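The rewritten `changed` helper above switches from a bare prefix test to glob matching with Python's `pathlib.Path.match`, so native-code rebuilds only trigger when llama.cpp-related paths change. As a rough illustration of the same filter in Go (a sketch, not project code; `path.Match` does not support `**`, so recursive globs are approximated with a prefix test here):

package main

import (
	"fmt"
	"path"
	"strings"
)

// changed reports whether any modified file matches one of the globs.
func changed(files []string, globs ...string) bool {
	for _, f := range files {
		for _, g := range globs {
			if strings.HasSuffix(g, "/**") {
				// approximate "**" with a directory-prefix test
				if strings.HasPrefix(f, strings.TrimSuffix(g, "**")) {
					return true
				}
			} else if ok, _ := path.Match(g, f); ok {
				return true
			}
		}
	}
	return false
}

func main() {
	files := []string{"llm/patches/05-default-pretokenizer.diff", "README.md"}
	fmt.Println(changed(files, "llm/llama.cpp", "llm/patches/**")) // true
	fmt.Println(changed(files, "docs/**"))                         // false
}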

View File

@@ -301,6 +301,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
 - [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
 - [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
+- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
+- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)

 ### Terminal

View File

@@ -6,7 +6,7 @@ import (
 	"os"
 	"path/filepath"

-	"github.com/ollama/ollama/server/envconfig"
+	"github.com/ollama/ollama/envconfig"
 )

 func InitLogging() {

View File

@@ -4,5 +4,5 @@ write-host "Welcome to Ollama!"
 write-host ""
 write-host "Run your first model:"
 write-host ""
-write-host "`tollama run llama2"
+write-host "`tollama run llama3"
 write-host ""

View File

@@ -34,6 +34,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
@@ -754,7 +755,11 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
 			}

 			// backtrack the length of the last word and clear to the end of the line
-			fmt.Printf("\x1b[%dD\x1b[K\n", runewidth.StringWidth(state.wordBuffer))
+			a := runewidth.StringWidth(state.wordBuffer)
+			if a > 0 {
+				fmt.Printf("\x1b[%dD", a)
+			}
+			fmt.Printf("\x1b[K\n")
 			fmt.Printf("%s%c", state.wordBuffer, ch)
 			chWidth := runewidth.RuneWidth(ch)
@@ -1079,12 +1084,7 @@ func versionHandler(cmd *cobra.Command, _ []string) {
 	}
 }

-type EnvironmentVar struct {
-	Name        string
-	Description string
-}
-
-func appendEnvDocs(cmd *cobra.Command, envs []EnvironmentVar) {
+func appendEnvDocs(cmd *cobra.Command, envs []envconfig.EnvVar) {
 	if len(envs) == 0 {
 		return
 	}
@@ -1093,7 +1093,7 @@ func appendEnvDocs(cmd *cobra.Command, envs []EnvironmentVar) {
 Environment Variables:
 `
 	for _, e := range envs {
-		envUsage += fmt.Sprintf("  %-16s %s\n", e.Name, e.Description)
+		envUsage += fmt.Sprintf("  %-24s %s\n", e.Name, e.Description)
 	}

 	cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
@@ -1172,15 +1172,6 @@ func NewCLI() *cobra.Command {
 		Args: cobra.ExactArgs(0),
 		RunE: RunServer,
 	}

-	serveCmd.SetUsageTemplate(serveCmd.UsageTemplate() + `
-Environment Variables:
-
-      OLLAMA_HOST         The host:port to bind to (default "127.0.0.1:11434")
-      OLLAMA_ORIGINS      A comma separated list of allowed origins
-      OLLAMA_MODELS       The path to the models directory (default "~/.ollama/models")
-      OLLAMA_KEEP_ALIVE   The duration that models stay loaded in memory (default "5m")
-      OLLAMA_DEBUG        Set to 1 to enable additional debug logging
-`)

 	pullCmd := &cobra.Command{
 		Use:  "pull MODEL",
@@ -1233,9 +1224,9 @@
 		RunE: DeleteHandler,
 	}

-	ollamaHostEnv := EnvironmentVar{"OLLAMA_HOST", "The host:port or base URL of the Ollama server (e.g. http://localhost:11434)"}
-	ollamaNoHistoryEnv := EnvironmentVar{"OLLAMA_NOHISTORY", "Disable readline history"}
-	envs := []EnvironmentVar{ollamaHostEnv}
+	envVars := envconfig.AsMap()
+
+	envs := []envconfig.EnvVar{envVars["OLLAMA_HOST"]}

 	for _, cmd := range []*cobra.Command{
 		createCmd,
@@ -1247,10 +1238,27 @@
 		psCmd,
 		copyCmd,
 		deleteCmd,
+		serveCmd,
 	} {
 		switch cmd {
 		case runCmd:
-			appendEnvDocs(cmd, []EnvironmentVar{ollamaHostEnv, ollamaNoHistoryEnv})
+			appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
+		case serveCmd:
+			appendEnvDocs(cmd, []envconfig.EnvVar{
+				envVars["OLLAMA_DEBUG"],
+				envVars["OLLAMA_HOST"],
+				envVars["OLLAMA_KEEP_ALIVE"],
+				envVars["OLLAMA_MAX_LOADED_MODELS"],
+				envVars["OLLAMA_MAX_QUEUE"],
+				envVars["OLLAMA_MODELS"],
+				envVars["OLLAMA_NUM_PARALLEL"],
+				envVars["OLLAMA_NOPRUNE"],
+				envVars["OLLAMA_ORIGINS"],
+				envVars["OLLAMA_TMPDIR"],
+				envVars["OLLAMA_FLASH_ATTENTION"],
+				envVars["OLLAMA_LLM_LIBRARY"],
+				envVars["OLLAMA_MAX_VRAM"],
+			})
 		default:
 			appendEnvDocs(cmd, envs)
 		}
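A note on the `displayResponse` hunk above: the ANSI cursor-back sequence `ESC[nD` treats a parameter of 0 as 1 in most terminals, so the old unconditional `fmt.Printf("\x1b[%dD\x1b[K\n", ...)` nudged the cursor one column left whenever the word buffer was empty, corrupting word wrap. Guarding on a positive width avoids that. A standalone sketch of the fixed behavior (assumes an ANSI-capable terminal):

package main

import "fmt"

// backtrack moves the cursor n columns left, then clears to end of line.
// ESC[0D is treated as ESC[1D by most terminals, so n == 0 must be skipped.
func backtrack(n int) {
	if n > 0 {
		fmt.Printf("\x1b[%dD", n)
	}
	fmt.Print("\x1b[K")
}

func main() {
	fmt.Print("hello world")
	backtrack(5) // move back over "world" and clear it
	fmt.Println("there")
}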

View File

@@ -15,6 +15,7 @@ import (
 	"golang.org/x/exp/slices"

 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/types/errtypes"
@@ -183,7 +184,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		return err
 	}

-	if os.Getenv("OLLAMA_NOHISTORY") != "" {
+	if envconfig.NoHistory {
 		scanner.HistoryDisable()
 	}

View File

@@ -76,6 +76,7 @@ Make sure you've set up the container runtime first as described in [docker.md](
 Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem

+- Is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama wont be able to see your NVIDIA GPU.
 - Is the uvm driver not loaded? `sudo nvidia-modprobe -u`
 - Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
 - Try rebooting

View File

@@ -15,6 +15,10 @@ var (
 	AllowOrigins []string
 	// Set via OLLAMA_DEBUG in the environment
 	Debug bool
+	// Experimental flash attention
+	FlashAttention bool
+	// Set via OLLAMA_KEEP_ALIVE in the environment
+	KeepAlive string
 	// Set via OLLAMA_LLM_LIBRARY in the environment
 	LLMLibrary string
 	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
@@ -23,6 +27,8 @@
 	MaxQueuedRequests int
 	// Set via OLLAMA_MAX_VRAM in the environment
 	MaxVRAM uint64
+	// Set via OLLAMA_NOHISTORY in the environment
+	NoHistory bool
 	// Set via OLLAMA_NOPRUNE in the environment
 	NoPrune bool
 	// Set via OLLAMA_NUM_PARALLEL in the environment
@@ -31,26 +37,42 @@
 	RunnersDir string
 	// Set via OLLAMA_TMPDIR in the environment
 	TmpDir string
-	// Experimental flash attention
-	FlashAttention bool
 )

-func AsMap() map[string]string {
-	return map[string]string{
-		"OLLAMA_ORIGINS":           fmt.Sprintf("%v", AllowOrigins),
-		"OLLAMA_DEBUG":             fmt.Sprintf("%v", Debug),
-		"OLLAMA_LLM_LIBRARY":       fmt.Sprintf("%v", LLMLibrary),
-		"OLLAMA_MAX_LOADED_MODELS": fmt.Sprintf("%v", MaxRunners),
-		"OLLAMA_MAX_QUEUE":         fmt.Sprintf("%v", MaxQueuedRequests),
-		"OLLAMA_MAX_VRAM":          fmt.Sprintf("%v", MaxVRAM),
-		"OLLAMA_NOPRUNE":           fmt.Sprintf("%v", NoPrune),
-		"OLLAMA_NUM_PARALLEL":      fmt.Sprintf("%v", NumParallel),
-		"OLLAMA_RUNNERS_DIR":       fmt.Sprintf("%v", RunnersDir),
-		"OLLAMA_TMPDIR":            fmt.Sprintf("%v", TmpDir),
-		"OLLAMA_FLASH_ATTENTION":   fmt.Sprintf("%v", FlashAttention),
+type EnvVar struct {
+	Name        string
+	Value       any
+	Description string
+}
+
+func AsMap() map[string]EnvVar {
+	return map[string]EnvVar{
+		"OLLAMA_DEBUG":             {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
+		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
+		"OLLAMA_HOST":              {"OLLAMA_HOST", "", "IP Address for the ollama server (default 127.0.0.1:11434)"},
+		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
+		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
+		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
+		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
+		"OLLAMA_MAX_VRAM":          {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
+		"OLLAMA_MODELS":            {"OLLAMA_MODELS", "", "The path to the models directory"},
+		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
+		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
+		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
+		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
+		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
+		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
 	}
 }
+
+func Values() map[string]string {
+	vals := make(map[string]string)
+	for k, v := range AsMap() {
+		vals[k] = fmt.Sprintf("%v", v.Value)
+	}
+	return vals
+}

 var defaultAllowOrigins = []string{
 	"localhost",
 	"127.0.0.1",
@@ -147,6 +169,10 @@ func LoadConfig() {
 		}
 	}

+	if nohistory := clean("OLLAMA_NOHISTORY"); nohistory != "" {
+		NoHistory = true
+	}
+
 	if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
 		NoPrune = true
 	}
@@ -181,4 +207,6 @@
 			MaxQueuedRequests = p
 		}
 	}
+
+	KeepAlive = clean("OLLAMA_KEEP_ALIVE")
 }

View File

@@ -9,6 +9,7 @@ def chat(messages):
     r = requests.post(
         "http://0.0.0.0:11434/api/chat",
         json={"model": model, "messages": messages, "stream": True},
+        stream=True
     )
     r.raise_for_status()
     output = ""

View File

@@ -108,10 +108,10 @@ func AMDGetGPUInfo() []GpuInfo {
 		}

 		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
-		//if totalMemory < IGPUMemLimit {
-		//	slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
-		//	continue
-		//}
+		if totalMemory < IGPUMemLimit {
+			slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
+			continue
+		}

 		// TODO revisit this once ROCm v6 is available on windows.
 		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable

View File

@@ -13,7 +13,7 @@ import (
 	"syscall"
 	"time"

-	"github.com/ollama/ollama/server/envconfig"
+	"github.com/ollama/ollama/envconfig"
 )

 var (

View File

@@ -20,14 +20,15 @@ import (
 	"sync"
 	"unsafe"

+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
-	"github.com/ollama/ollama/server/envconfig"
 )

 type handles struct {
 	deviceCount int
 	cudart      *C.cudart_handle_t
 	nvcuda      *C.nvcuda_handle_t
+	oneapi      *C.oneapi_handle_t
 }

 const (
@@ -80,6 +81,15 @@ var NvcudaWindowsGlobs = []string{
 	"c:\\windows\\system*\\nvcuda.dll",
 }

+var OneapiWindowsGlobs = []string{
+	"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
+}
+
+var OneapiLinuxGlobs = []string{
+	"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
+	"/usr/lib*/libze_intel_gpu.so*",
+}
+
 // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
 // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
@@ -141,6 +151,7 @@ func initGPUHandles() *handles {
 			return gpuHandles
 		}
 	}
+
 	return gpuHandles
 }
@@ -181,39 +192,41 @@ func GetGPUInfo() GpuInfoList {
 		if cpuVariant == "" && runtime.GOARCH == "amd64" {
 			continue
 		}
-		gpuInfo := GpuInfo{
-			Library: "cuda",
-		}
-		var driverMajor int
-		var driverMinor int
-		if gpuHandles.cudart != nil {
-			C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
-		} else {
-			C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
-			driverMajor = int(gpuHandles.nvcuda.driver_major)
-			driverMinor = int(gpuHandles.nvcuda.driver_minor)
-		}
-		if memInfo.err != nil {
-			slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
-			C.free(unsafe.Pointer(memInfo.err))
-			continue
-		}
-		if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
-			slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
-			continue
-		}
-		gpuInfo.TotalMemory = uint64(memInfo.total)
-		gpuInfo.FreeMemory = uint64(memInfo.free)
-		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
-		gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
-		gpuInfo.MinimumMemory = cudaMinimumMemory
-		gpuInfo.DependencyPath = depPath
-		gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
-		gpuInfo.DriverMajor = int(driverMajor)
-		gpuInfo.DriverMinor = int(driverMinor)
+		if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil {
+			gpuInfo := GpuInfo{
+				Library: "cuda",
+			}
+			var driverMajor int
+			var driverMinor int
+			if gpuHandles.cudart != nil {
+				C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
+			} else {
+				C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
+				driverMajor = int(gpuHandles.nvcuda.driver_major)
+				driverMinor = int(gpuHandles.nvcuda.driver_minor)
+			}
+			if memInfo.err != nil {
+				slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+				C.free(unsafe.Pointer(memInfo.err))
+				continue
+			}
+			if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
+				slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+				continue
+			}
+			gpuInfo.TotalMemory = uint64(memInfo.total)
+			gpuInfo.FreeMemory = uint64(memInfo.free)
+			gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+			gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
+			gpuInfo.MinimumMemory = cudaMinimumMemory
+			gpuInfo.DependencyPath = depPath
+			gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
+			gpuInfo.DriverMajor = int(driverMajor)
+			gpuInfo.DriverMinor = int(driverMinor)

-		// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
-		resp = append(resp, gpuInfo)
+			// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
+			resp = append(resp, gpuInfo)
+		}
 	}

 	// Then AMD
@@ -348,6 +361,23 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
 	return 0, nil, ""
 }

+func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
+	var resp C.oneapi_init_resp_t
+	resp.oh.verbose = getVerboseState()
+	for _, libPath := range oneapiLibPaths {
+		lib := C.CString(libPath)
+		defer C.free(unsafe.Pointer(lib))
+		C.oneapi_init(lib, &resp)
+		if resp.err != nil {
+			slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
+			C.free(unsafe.Pointer(resp.err))
+		} else {
+			return int(resp.num_devices), &resp.oh, libPath
+		}
+	}
+	return 0, nil, ""
+}
+
 func getVerboseState() C.uint16_t {
 	if envconfig.Debug {
 		return C.uint16_t(1)
@@ -368,6 +398,8 @@ func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
 		return cudaGetVisibleDevicesEnv(l)
 	case "rocm":
 		return rocmGetVisibleDevicesEnv(l)
+	case "oneapi":
+		return oneapiGetVisibleDevicesEnv(l)
 	default:
 		slog.Debug("no filter required for library " + l[0].Library)
 		return "", ""

View File

@@ -62,6 +62,7 @@ void cpu_check_ram(mem_info_t *resp);
 #include "gpu_info_cudart.h"
 #include "gpu_info_nvcuda.h"
+#include "gpu_info_oneapi.h"

 #endif  // __GPU_INFO_H__
 #endif  // __APPLE__

gpu/gpu_info_oneapi.c (new file, 214 lines)
View File

@@ -0,0 +1,214 @@
#ifndef __APPLE__
#include "gpu_info_oneapi.h"
#include <string.h>
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
{
ze_result_t ret;
resp->err = NULL;
const int buflen = 256;
char buf[buflen + 1];
int i;
struct lookup
{
char *s;
void **p;
} l[] = {
{"zesInit", (void *)&resp->oh.zesInit},
{"zesDriverGet", (void *)&resp->oh.zesDriverGet},
{"zesDeviceGet", (void *)&resp->oh.zesDeviceGet},
{"zesDeviceGetProperties", (void *)&resp->oh.zesDeviceGetProperties},
{"zesDeviceEnumMemoryModules",
(void *)&resp->oh.zesDeviceEnumMemoryModules},
{"zesMemoryGetProperties", (void *)&resp->oh.zesMemoryGetProperties},
{"zesMemoryGetState", (void *)&resp->oh.zesMemoryGetState},
{NULL, NULL},
};
resp->oh.handle = LOAD_LIBRARY(oneapi_lib_path, RTLD_LAZY);
if (!resp->oh.handle)
{
char *msg = LOAD_ERR();
snprintf(buf, buflen,
"Unable to load %s library to query for Intel GPUs: %s\n",
oneapi_lib_path, msg);
free(msg);
resp->err = strdup(buf);
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->oh.verbose,
"wiring Level-Zero management library functions in %s\n",
oneapi_lib_path);
for (i = 0; l[i].s != NULL; i++)
{
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s);
if (!l[i].p)
{
resp->oh.handle = NULL;
char *msg = LOAD_ERR();
LOG(resp->oh.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->oh.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, msg);
free(msg);
resp->err = strdup(buf);
return;
}
}
ret = (*resp->oh.zesInit)(0);
if (ret != ZE_RESULT_SUCCESS)
{
LOG(resp->oh.verbose, "zesInit err: %d\n", ret);
UNLOAD_LIBRARY(resp->oh.handle);
resp->oh.handle = NULL;
snprintf(buf, buflen, "oneapi vram init failure: %d", ret);
resp->err = strdup(buf);
}
(*resp->oh.zesDriverGet)(&resp->num_devices, NULL);
return;
}
void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
{
ze_result_t ret;
resp->err = NULL;
uint64_t totalMem = 0;
uint64_t usedMem = 0;
const int buflen = 256;
char buf[buflen + 1];
int i, d, m;
if (h.handle == NULL)
{
resp->err = strdup("Level-Zero handle not initialized");
return;
}
uint32_t driversCount = 0;
ret = (*h.zesDriverGet)(&driversCount, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get driver count: %d", ret);
resp->err = strdup(buf);
return;
}
LOG(h.verbose, "discovered %d Level-Zero drivers\n", driversCount);
zes_driver_handle_t *allDrivers =
malloc(driversCount * sizeof(zes_driver_handle_t));
(*h.zesDriverGet)(&driversCount, allDrivers);
resp->total = 0;
resp->free = 0;
for (d = 0; d < driversCount; d++)
{
uint32_t deviceCount = 0;
ret = (*h.zesDeviceGet)(allDrivers[d], &deviceCount, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
return;
}
LOG(h.verbose, "discovered %d Level-Zero devices\n", deviceCount);
zes_device_handle_t *devices =
malloc(deviceCount * sizeof(zes_device_handle_t));
(*h.zesDeviceGet)(allDrivers[d], &deviceCount, devices);
for (i = 0; i < deviceCount; i++)
{
zes_device_ext_properties_t ext_props;
ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
ext_props.pNext = NULL;
zes_device_properties_t props;
props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
props.pNext = &ext_props;
ret = (*h.zesDeviceGetProperties)(devices[i], &props);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get device properties: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
free(devices);
return;
}
if (h.verbose)
{
// When in verbose mode, report more information about
// the card we discover.
LOG(h.verbose, "[%d] oneAPI device name: %s\n", i,
props.modelName);
LOG(h.verbose, "[%d] oneAPI brand: %s\n", i,
props.brandName);
LOG(h.verbose, "[%d] oneAPI vendor: %s\n", i,
props.vendorName);
LOG(h.verbose, "[%d] oneAPI S/N: %s\n", i,
props.serialNumber);
LOG(h.verbose, "[%d] oneAPI board number: %s\n", i,
props.boardNumber);
}
uint32_t memCount = 0;
ret = (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, NULL);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen,
"unable to enumerate Level-Zero memory modules: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
free(devices);
return;
}
LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
(*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, mems);
for (m = 0; m < memCount; m++)
{
zes_mem_state_t state;
state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
state.pNext = NULL;
ret = (*h.zesMemoryGetState)(mems[m], &state);
if (ret != ZE_RESULT_SUCCESS)
{
snprintf(buf, buflen, "unable to get memory state: %d", ret);
resp->err = strdup(buf);
free(allDrivers);
free(devices);
free(mems);
return;
}
resp->total += state.size;
resp->free += state.free;
}
free(mems);
}
free(devices);
}
free(allDrivers);
}
#endif // __APPLE__

gpu/gpu_info_oneapi.h (new file, 211 lines)
View File

@@ -0,0 +1,211 @@
#ifndef __APPLE__
#ifndef __GPU_INFO_ONEAPI_H__
#define __GPU_INFO_ONEAPI_H__
#include "gpu_info.h"
#define ZE_MAX_DEVICE_NAME 256
#define ZE_MAX_DEVICE_UUID_SIZE 16
#define ZES_STRING_PROPERTY_SIZE 64
#define ZE_BIT(_i) (1 << _i)
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum ze_result_t
{
ZE_RESULT_SUCCESS = 0,
// Other values omitted for now...
} ze_result_t;
typedef uint8_t ze_bool_t;
typedef struct _zes_driver_handle_t *zes_driver_handle_t;
typedef struct _zes_device_handle_t *zes_device_handle_t;
typedef struct _zes_mem_handle_t *zes_mem_handle_t;
typedef enum _ze_structure_type_t
{
ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
} ze_structure_type_t;
typedef enum _zes_structure_type_t
{
ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1,
ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb,
ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e,
ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES = 0x2d,
ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
} zes_structure_type_t;
typedef enum _zes_mem_type_t
{
ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff
} zes_mem_type_t;
typedef enum _zes_mem_loc_t
{
ZES_MEM_LOC_SYSTEM = 0,
ZES_MEM_LOC_DEVICE = 1,
ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff
} zes_mem_loc_t;
typedef enum _zes_mem_health_t
{
ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff
} zes_mem_health_t;
typedef struct _ze_device_uuid_t
{
uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
} ze_device_uuid_t;
typedef struct _zes_uuid_t
{
uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
} zes_uuid_t;
typedef enum _ze_device_type_t
{
ZE_DEVICE_TYPE_GPU = 1,
ZE_DEVICE_TYPE_CPU = 2,
ZE_DEVICE_TYPE_FPGA = 3,
ZE_DEVICE_TYPE_MCA = 4,
ZE_DEVICE_TYPE_VPU = 5,
ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
} ze_device_type_t;
typedef enum _zes_device_type_t
{
ZES_DEVICE_TYPE_GPU = 1,
ZES_DEVICE_TYPE_CPU = 2,
ZES_DEVICE_TYPE_FPGA = 3,
ZES_DEVICE_TYPE_MCA = 4,
ZES_DEVICE_TYPE_VPU = 5,
ZES_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
} zes_device_type_t;
typedef uint32_t ze_device_property_flags_t;
typedef enum _ze_device_property_flag_t
{
ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
ZE_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3),
ZE_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
} ze_device_property_flag_t;
typedef uint32_t zes_device_property_flags_t;
typedef enum _zes_device_property_flag_t
{
ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
ZES_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3),
ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
} zes_device_property_flag_t;
typedef struct _ze_device_properties_t
{
ze_structure_type_t stype;
void *pNext;
ze_device_type_t type;
uint32_t vendorId;
uint32_t deviceId;
ze_device_property_flags_t flags;
uint32_t subdeviceId;
uint32_t coreClockRate;
uint64_t maxMemAllocSize;
uint32_t maxHardwareContexts;
uint32_t maxCommandQueuePriority;
uint32_t numThreadsPerEU;
uint32_t physicalEUSimdWidth;
uint32_t numEUsPerSubslice;
uint32_t numSubslicesPerSlice;
uint32_t numSlices;
uint64_t timerResolution;
uint32_t timestampValidBits;
uint32_t kernelTimestampValidBits;
ze_device_uuid_t uuid;
char name[ZE_MAX_DEVICE_NAME];
} ze_device_properties_t;
typedef struct _zes_device_properties_t
{
zes_structure_type_t stype;
void *pNext;
ze_device_properties_t core;
uint32_t numSubdevices;
char serialNumber[ZES_STRING_PROPERTY_SIZE];
char boardNumber[ZES_STRING_PROPERTY_SIZE];
char brandName[ZES_STRING_PROPERTY_SIZE];
char modelName[ZES_STRING_PROPERTY_SIZE];
char vendorName[ZES_STRING_PROPERTY_SIZE];
char driverVersion[ZES_STRING_PROPERTY_SIZE];
} zes_device_properties_t;
typedef struct _zes_device_ext_properties_t
{
zes_structure_type_t stype;
void *pNext;
zes_uuid_t uuid;
zes_device_type_t type;
zes_device_property_flags_t flags;
} zes_device_ext_properties_t;
typedef struct _zes_mem_properties_t
{
zes_structure_type_t stype;
void *pNext;
zes_mem_type_t type;
ze_bool_t onSubdevice;
uint32_t subdeviceId;
zes_mem_loc_t location;
uint64_t physicalSize;
int32_t busWidth;
int32_t numChannels;
} zes_mem_properties_t;
typedef struct _zes_mem_state_t
{
zes_structure_type_t stype;
const void *pNext;
zes_mem_health_t health;
uint64_t free;
uint64_t size;
} zes_mem_state_t;
typedef struct oneapi_handle
{
void *handle;
uint16_t verbose;
ze_result_t (*zesInit)(int);
ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
zes_device_handle_t *phDevices);
ze_result_t (*zesDeviceGetProperties)(zes_device_handle_t hDevice,
zes_device_properties_t *pProperties);
ze_result_t (*zesDeviceEnumMemoryModules)(zes_device_handle_t hDevice,
uint32_t *pCount,
zes_mem_handle_t *phMemory);
ze_result_t (*zesMemoryGetProperties)(zes_mem_handle_t hMemory,
zes_mem_properties_t *pProperties);
ze_result_t (*zesMemoryGetState)(zes_mem_handle_t hMemory,
zes_mem_state_t *pState);
} oneapi_handle_t;
typedef struct oneapi_init_resp
{
char *err; // If err is non-null handle is invalid
int num_devices;
oneapi_handle_t oh;
} oneapi_init_resp_t;
typedef struct oneapi_version_resp
{
ze_result_t status;
char *str; // Contains version or error string if status != 0
} oneapi_version_resp_t;
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
void oneapi_check_vram(oneapi_handle_t rh, mem_info_t *resp);
#endif // __GPU_INFO_INTEL_H__
#endif // __APPLE__

gpu/gpu_oneapi.go (new file, 21 lines)
View File

@@ -0,0 +1,21 @@
//go:build linux || windows
package gpu
import (
"log/slog"
"strings"
)
func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "oneapi" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
continue
}
ids = append(ids, info.ID)
}
return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
}
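For context, the key/value pair built above is intended to be injected into the runner subprocess's environment, the same way the CUDA and ROCm visibility variables are handled by the other `*GetVisibleDevicesEnv` helpers. A sketch of that wiring (hypothetical caller; the real plumbing lives in the server/scheduler code):

package main

import (
	"fmt"
	"os"
	"os/exec"

	"github.com/ollama/ollama/gpu"
)

func main() {
	gpus := gpu.GetGPUInfo()
	if len(gpus) == 0 || gpus[0].Library != "oneapi" {
		fmt.Println("no oneapi devices detected")
		return
	}

	// e.g. key = "ONEAPI_DEVICE_SELECTOR", val = "level_zero:0"
	key, val := gpus.GetVisibleDevicesEnv()

	// Restrict a child process to the selected Level-Zero devices.
	cmd := exec.Command("./ollama_llama_server", "--help")
	cmd.Env = append(os.Environ(), key+"="+val)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	_ = cmd.Run()
}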

View File

@@ -140,7 +140,6 @@ struct server_slot {
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;

-    bool infill = false;
     bool embedding = false;
     bool has_next_token = true;
     bool truncated = false;
@@ -187,7 +186,6 @@
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        infill = false;
         ga_i = 0;
         n_past_se = 0;
@@ -600,16 +598,6 @@ struct llama_server_context
             slot->params.n_predict = slot->n_predict;
         }

-        // infill
-        if (data.count("input_prefix") != 0)
-        {
-            slot->params.input_prefix = data["input_prefix"];
-        }
-        else
-        {
-            slot->params.input_prefix = "";
-        }
-
         if (data.count("input_suffix") != 0)
         {
             slot->params.input_suffix = data["input_suffix"];
@@ -897,15 +885,6 @@ struct llama_server_context
         system_need_update = true;
     }

-    void system_prompt_process(const json &sys_props) {
-        system_prompt   = sys_props.value("prompt", "");
-        name_user       = sys_props.value("anti_prompt", "");
-        name_assistant  = sys_props.value("assistant_name", "");
-
-        system_prompt_notify();
-    }
-
     static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
                                         const stop_type type, server_slot &slot)
     {
@@ -1263,13 +1242,12 @@ struct llama_server_context
         queue_results.send(res);
     }

-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void request_completion(int task_id, json data, bool embedding, int multitask_id)
     {
         task_server task;
         task.id = task_id;
         task.target_id = 0;
         task.data = std::move(data);
-        task.infill_mode = infill;
         task.embedding_mode = embedding;
         task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
@@ -1415,8 +1393,8 @@ struct llama_server_context
             json subtask_data = multiprompt_task.data;
             subtask_data["prompt"] = subtask_data["prompt"][i];

-            // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            // subtasks inherit everything else (embedding mode, etc.)
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.embedding_mode, multitask_id);
         }
     }
@@ -1434,26 +1412,8 @@ struct llama_server_context
                     break;
                 }

-                if (task.data.contains("system_prompt"))
-                {
-                    if (!all_slots_are_idle) {
-                        send_error(task, "system prompt can only be updated when all slots are idle");
-                        break;
-                    }
-                    system_prompt_process(task.data["system_prompt"]);
-
-                    // reset cache_tokens for all slots
-                    for (server_slot &slot : slots)
-                    {
-                        slot.cache_tokens.clear();
-                        slot.n_past    = 0;
-                        slot.n_past_se = 0;
-                    }
-                }
-
                 slot->reset();

-                slot->infill = task.infill_mode;
                 slot->embedding = task.embedding_mode;
                 slot->task_id = task.id;
                 slot->multitask_id = task.multitask_id;
@@ -1679,8 +1639,7 @@ struct llama_server_context
                 const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();

                 // empty prompt passed -> release the slot and send empty response
-                // note: infill mode allows empty prompt
-                if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill)
+                if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt)
                 {
                     slot.release();
                     slot.print_timings();
@@ -1697,33 +1656,7 @@ struct llama_server_context
                     slot.t_start_process_prompt = ggml_time_us();
                     slot.t_start_genereration = 0;

-                    if (slot.infill)
-                    {
-                        bool suff_rm_leading_spc = true;
-                        if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1)
-                        {
-                            params.input_suffix.erase(0, 1);
-                            suff_rm_leading_spc = false;
-                        }
-                        auto prefix_tokens = tokenize(slot.params.input_prefix, false);
-                        auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
-                        const int space_token = 29871; // TODO: this should not be hardcoded
-                        if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
-                            suffix_tokens.erase(suffix_tokens.begin());
-                        }
-
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
-                        prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
-                        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-                        prefix_tokens.push_back(llama_token_middle(model));
-                        prompt_tokens = prefix_tokens;
-                    }
-                    else
-                    {
-                        prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token);  // add BOS if there isn't system prompt
-                    }
+                    prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token);  // add BOS if there isn't system prompt

                     slot.n_prompt_tokens = prompt_tokens.size();
@@ -2130,8 +2063,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("\n");
 }

-static void server_params_parse(int argc, char **argv, server_params &sparams,
-                                gpt_params &params, llama_server_context& llama)
+static void server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
 {
     gpt_params default_params;
     server_params default_sparams;
@@ -2546,27 +2478,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_predict = std::stoi(argv[i]);
         }
-        else if (arg == "-spf" || arg == "--system-prompt-file")
-        {
-            if (++i >= argc)
-            {
-                invalid_param = true;
-                break;
-            }
-            std::ifstream file(argv[i]);
-            if (!file) {
-                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
-                invalid_param = true;
-                break;
-            }
-            std::string systm_content;
-            std::copy(
-                std::istreambuf_iterator<char>(file),
-                std::istreambuf_iterator<char>(),
-                std::back_inserter(systm_content)
-            );
-            llama.system_prompt_process(json::parse(systm_content));
-        }
         else if (arg == "-ctk" || arg == "--cache-type-k") {
             params.cache_type_k = argv[++i];
         }
@@ -2818,7 +2729,7 @@ int main(int argc, char **argv) {
     // struct that contains llama context and inference
     llama_server_context llama;

-    server_params_parse(argc, argv, sparams, params, llama);
+    server_params_parse(argc, argv, sparams, params);

     if (params.model_alias == "unknown")
     {
@@ -3150,7 +3061,7 @@ int main(int argc, char **argv) {
                 json data = json::parse(req.body);
                 const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, data, false, false, -1);
+                llama.request_completion(task_id, data, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.queue_results.recv(task_id);
@@ -3272,7 +3183,7 @@ int main(int argc, char **argv) {
                 // create and queue the task
                 const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);

                 // get the result
                 task_result result = llama.queue_results.recv(task_id);
View File

@@ -32,42 +32,43 @@ case "${GOARCH}" in
         echo "Building static library"
         build

-        #
-        # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
-        #
-        init_vars
-        CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
-        BUILD_DIR="../build/darwin/${ARCH}/cpu"
-        echo "Building LCD CPU"
-        build
-        sign ${BUILD_DIR}/bin/ollama_llama_server
-        compress
+        if [ -z "$OLLAMA_SKIP_CPU_GENERATE" ]; then
+            #
+            # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
+            #
+            init_vars
+            CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+            BUILD_DIR="../build/darwin/${ARCH}/cpu"
+            echo "Building LCD CPU"
+            build
+            sign ${BUILD_DIR}/bin/ollama_llama_server
+            compress

-        #
-        # ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
-        # Approximately 400% faster than LCD on same CPU
-        #
-        init_vars
-        CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
-        BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
-        echo "Building AVX CPU"
-        build
-        sign ${BUILD_DIR}/bin/ollama_llama_server
-        compress
+            #
+            # ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
+            # Approximately 400% faster than LCD on same CPU
+            #
+            init_vars
+            CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+            BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
+            echo "Building AVX CPU"
+            build
+            sign ${BUILD_DIR}/bin/ollama_llama_server
+            compress

-        #
-        # ~2013 CPU Dynamic library
-        # Approximately 10% faster than AVX on same CPU
-        #
-        init_vars
-        CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
-        BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
-        echo "Building AVX2 CPU"
-        EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
-        build
-        sign ${BUILD_DIR}/bin/ollama_llama_server
-        compress
+            #
+            # ~2013 CPU Dynamic library
+            # Approximately 10% faster than AVX on same CPU
+            #
+            init_vars
+            CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
+            BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
+            echo "Building AVX2 CPU"
+            EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
+            build
+            sign ${BUILD_DIR}/bin/ollama_llama_server
+            compress
+        fi
     ;;
 "arm64")
@@ -79,13 +80,15 @@ case "${GOARCH}" in
         echo "Building static library"
         build

-        init_vars
-        CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
-        BUILD_DIR="../build/darwin/${ARCH}/metal"
-        EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
-        build
-        sign ${BUILD_DIR}/bin/ollama_llama_server
-        compress
+        if [ -z "$OLLAMA_SKIP_METAL_GENERATE" ]; then
+            init_vars
+            CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
+            BUILD_DIR="../build/darwin/${ARCH}/metal"
+            EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
+            build
+            sign ${BUILD_DIR}/bin/ollama_llama_server
+            compress
+        fi
     ;;
 *)
     echo "GOARCH must be set"

View File

@@ -215,6 +215,36 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
 fi

+if [ -z "${ONEAPI_ROOT}" ]; then
+    # Try the default location in case it exists
+    ONEAPI_ROOT=/opt/intel/oneapi
+fi
+
+if [ -d "${ONEAPI_ROOT}" ]; then
+    echo "OneAPI libraries detected - building dynamic OneAPI library"
+    init_vars
+    source ${ONEAPI_ROOT}/setvars.sh --force # set up environment variables for oneAPI
+    CC=icx
+    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL=ON -DLLAMA_SYCL_F16=OFF"
+    BUILD_DIR="../build/linux/${ARCH}/oneapi"
+    EXTRA_LIBS="-fsycl -Wl,-rpath,${ONEAPI_ROOT}/compiler/latest/lib,-rpath,${ONEAPI_ROOT}/mkl/latest/lib,-rpath,${ONEAPI_ROOT}/tbb/latest/lib,-rpath,${ONEAPI_ROOT}/compiler/latest/opt/oclfpga/linux64/lib -lOpenCL -lmkl_core -lmkl_sycl_blas -lmkl_intel_ilp64 -lmkl_tbb_thread -ltbb"
+    DEBUG_FLAGS="" # icx compiles with -O0 if we pass -g, so we must remove it
+    build
+
+    # copy oneAPI dependencies
+    for dep in $(ldd "${BUILD_DIR}/bin/ollama_llama_server" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e sycl -e mkl -e tbb); do
+        cp "${dep}" "${BUILD_DIR}/bin/"
+    done
+    cp "${ONEAPI_ROOT}/compiler/latest/lib/libOpenCL.so" "${BUILD_DIR}/bin/"
+    cp "${ONEAPI_ROOT}/compiler/latest/lib/libimf.so" "${BUILD_DIR}/bin/"
+    cp "${ONEAPI_ROOT}/compiler/latest/lib/libintlc.so.5" "${BUILD_DIR}/bin/"
+    cp "${ONEAPI_ROOT}/compiler/latest/lib/libirng.so" "${BUILD_DIR}/bin/"
+    cp "${ONEAPI_ROOT}/compiler/latest/lib/libpi_level_zero.so" "${BUILD_DIR}/bin/"
+    cp "${ONEAPI_ROOT}/compiler/latest/lib/libsvml.so" "${BUILD_DIR}/bin/"
+    cp "${ONEAPI_ROOT}/compiler/latest/lib/libur_loader.so.0" "${BUILD_DIR}/bin/"
+    compress
+fi
+
 if [ -z "${ROCM_PATH}" ]; then
     # Try the default location in case it exists
     ROCM_PATH=/opt/rocm

View File

@@ -25,6 +25,7 @@ function amdGPUs {
"gfx1030" "gfx1030"
"gfx1031" "gfx1031"
"gfx1032" "gfx1032"
"gfx1033"
"gfx1034" "gfx1034"
"gfx1035" "gfx1035"
"gfx1036" "gfx1036"
@@ -299,6 +300,49 @@ function build_cuda() {
 }
 }
 
+function build_oneapi() {
+  if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${env:ONEAPI_ROOT}")) {
+    # Get oneAPI version
+    $script:ONEAPI_VERSION = icpx --version
+    $script:ONEAPI_VERSION = [regex]::Match($script:ONEAPI_VERSION, '(?<=oneAPI DPC\+\+/C\+\+ Compiler )(?<version>\d+\.\d+\.\d+)').Value
+    if ($null -ne $script:ONEAPI_VERSION) {
+      $script:ONEAPI_VARIANT = "_v" + $script:ONEAPI_VERSION
+    }
+    init_vars
+    $script:buildDir = "../build/windows/${script:ARCH}/oneapi$script:ONEAPI_VARIANT"
+    $script:distDir ="$script:DIST_BASE\oneapi$script:ONEAPI_VARIANT"
+    $script:cmakeDefs += @(
+      "-G", "MinGW Makefiles",
+      "-DLLAMA_SYCL=ON",
+      "-DCMAKE_C_COMPILER=icx",
+      "-DCMAKE_CXX_COMPILER=icx",
+      "-DCMAKE_BUILD_TYPE=Release"
+    )
+
+    Write-Host "Building oneAPI"
+    build
+    # Ninja doesn't prefix with config name
+    if ($null -ne $script:DUMPBIN) {
+      & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | Select-String ".dll"
+    }
+    sign
+    install
+
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libirngmd.dll" "${script:distDir}"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libmmd.dll" "${script:distDir}"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_level_zero.dll" "${script:distDir}"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_unified_runtime.dll" "${script:distDir}"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_win_proxy_loader.dll" "${script:distDir}"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\svml_dispmd.dll" "${script:distDir}"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\sycl7.dll" "${script:distDir}"
+    cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_core.2.dll" "${script:distDir}"
+    cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_sycl_blas.4.dll" "${script:distDir}"
+    cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_tbb_thread.2.dll" "${script:distDir}"
+  } else {
+    Write-Host "Skipping oneAPI generation step"
+  }
+}
 function build_rocm() {
     if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
         $script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
@@ -366,6 +410,7 @@ if ($($args.count) -eq 0) {
     build_cpu_avx
     build_cpu_avx2
     build_cuda
+    build_oneapi
     build_rocm
 }


@@ -125,9 +125,9 @@ type Tensor struct {
 func (t Tensor) blockSize() uint64 {
 	switch t.Kind {
-	case 0, 1, 24, 25, 26, 27, 28, 31: // F32, F16, I8, I16, I32, I64, F64, BF16
+	case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
 		return 1
-	case 2, 3, 8, 9, 20: // Q4_0, Q4_1, Q8_0, Q8_1, IQ4_NL
+	case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
 		return 32
 	default: // All others
 		return 256
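blockSize drives how tensor bytes are computed: elements are grouped into quantization blocks, so counts must round up to whole blocks. A small self-contained sketch using the corrected table (the helper names are illustrative, not the repo's API):

```go
package main

import "fmt"

// blockSize mirrors the corrected switch above: BF16 (kind 30) and the other
// scalar types hold one element per block, the Q4/Q5/Q8 and IQ4_NL
// quantizations use 32-element blocks, and the remaining kinds use 256.
func blockSize(kind uint64) uint64 {
	switch kind {
	case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
		return 1
	case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
		return 32
	default: // K-quants and other IQ types
		return 256
	}
}

func main() {
	// A 4096-element row of Q4_0 (kind 2) pads out to whole 32-element blocks.
	elems, kind := uint64(4096), uint64(2)
	bs := blockSize(kind)
	blocks := (elems + bs - 1) / bs
	fmt.Printf("%d elements -> %d blocks of %d\n", elems, blocks, bs)
}
```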


@@ -7,7 +7,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/server/envconfig" "github.com/ollama/ollama/envconfig"
) )
// This algorithm looks for a complete fit to determine if we need to unload other models // This algorithm looks for a complete fit to determine if we need to unload other models


@@ -1,35 +1,32 @@
From d02a06f3f45a09255ace8684a66590e06ce44605 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Thu, 23 May 2024 11:33:20 -0700
Subject: [PATCH] default pretokenizer on unrecognized type
---
llama.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
 diff --git a/llama.cpp b/llama.cpp
-index 15c66077..af1aede3 100644
+index 40d2ec2c..74f3ee9c 100644
 --- a/llama.cpp
 +++ b/llama.cpp
-@@ -4504,9 +4504,6 @@ static void llm_load_vocab(
-             LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
-             LLAMA_LOG_WARN("%s:                                             \n", __func__);
-             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
--        } else if (
--                tokenizer_pre == "default") {
+@@ -4642,16 +4642,7 @@ static void llm_load_vocab(
+     // for now, only BPE models have pre-tokenizers
+     if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
+-        if (tokenizer_pre.empty()) {
+-            LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
+-            LLAMA_LOG_WARN("%s:                                             \n", __func__);
+-            LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
+-            LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED!        \n", __func__);
+-            LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL             \n", __func__);
+-            LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
+-            LLAMA_LOG_WARN("%s:                                             \n", __func__);
+-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+-        } else if (
++        if (
+             tokenizer_pre == "default") {
+             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
          } else if (
-                 tokenizer_pre == "llama3" ||
-                 tokenizer_pre == "llama-v3" ||
-@@ -4553,7 +4550,7 @@ static void llm_load_vocab(
-             tokenizer_pre == "dbrx") {
-         vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
+@@ -4703,7 +4694,8 @@ static void llm_load_vocab(
+             tokenizer_pre == "smaug-bpe") {
+         vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
          } else {
 -            throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
++            LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
 +            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
          }
      } else {
          vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 --
 2.45.1


@@ -24,9 +24,9 @@ import (
"golang.org/x/sync/semaphore" "golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/server/envconfig"
) )
type LlamaServer interface { type LlamaServer interface {
@@ -189,35 +189,38 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--memory-f32") params = append(params, "--memory-f32")
} }
if opts.UseMLock { flashAttnEnabled := envconfig.FlashAttention
params = append(params, "--mlock")
for _, g := range gpus {
// only cuda (compute capability 7+) and metal support flash attention
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
}
// mmap has issues with partial offloading on metal
if g.Library == "metal" &&
uint64(opts.NumGPU) > 0 &&
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
opts.UseMMap = false
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
} }
if !opts.UseMMap { if !opts.UseMMap {
params = append(params, "--no-mmap") params = append(params, "--no-mmap")
} }
if opts.UseMLock {
params = append(params, "--mlock")
}
if opts.UseNUMA { if opts.UseNUMA {
params = append(params, "--numa") params = append(params, "--numa")
} }
flashAttnEnabled := envconfig.FlashAttention
// partial offloading does not support flash attention
if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
flashAttnEnabled = false
}
// only cuda (compute capability 7+) and metal support flash attention
for _, g := range gpus {
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
}
numParallel := envconfig.NumParallel numParallel := envconfig.NumParallel
// TODO (jmorganca): multimodal models don't support parallel yet // TODO (jmorganca): multimodal models don't support parallel yet
@@ -243,7 +246,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		gpuCount = 0
 	}
-	// Find an availableServers port, retry on each iterration in case the failure was a port conflict race
+	// Find an availableServers port, retry on each iteration in case the failure was a port conflict race
 	port := 0
 	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
 		var l *net.TCPListener
@@ -519,16 +522,18 @@ func (s *llmServer) Ping(ctx context.Context) error {
 func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
-	stallDuration := 60 * time.Second
-	stallTimer := time.Now().Add(stallDuration) // give up if we stall
+	stallDuration := 5 * time.Minute            // If no progress happens
+	finalLoadDuration := 5 * time.Minute        // After we hit 100%, give the runner more time to come online
+	stallTimer := time.Now().Add(stallDuration) // give up if we stall
 
 	slog.Info("waiting for llama runner to start responding")
 	var lastStatus ServerStatus = -1
+	fullyLoaded := false
 
 	for {
 		select {
 		case <-ctx.Done():
-			slog.Info("context expired before server started")
+			slog.Warn("client connection closed before server finished loading, aborting load")
 			return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
 		case err := <-s.done:
 			msg := ""
@@ -572,6 +577,10 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			if priorProgress != s.loadProgress {
 				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
 				stallTimer = time.Now().Add(stallDuration)
+			} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
+				slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
+				stallTimer = time.Now().Add(finalLoadDuration)
+				fullyLoaded = true
 			}
 			time.Sleep(time.Millisecond * 250)
 			continue
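The hunk above implements a reset-on-progress timeout with a separate grace period once loading reports 100%. A minimal, self-contained sketch of that pattern, with hypothetical names:

```go
package main

import (
	"fmt"
	"time"
)

// waitForReady polls progress() until done() reports ready, resetting the
// stall deadline whenever progress advances and switching to a final grace
// period once progress reaches 100%.
func waitForReady(progress func() float64, done func() bool) error {
	const stallDuration = 5 * time.Minute     // give up if no progress happens
	const finalLoadDuration = 5 * time.Minute // extra time after hitting 100%
	deadline := time.Now().Add(stallDuration)
	var last float64
	fullyLoaded := false
	for !done() {
		if time.Now().After(deadline) {
			return fmt.Errorf("timed out waiting for runner to start")
		}
		if p := progress(); p != last {
			last = p
			deadline = time.Now().Add(stallDuration) // progress: reset stall timer
		} else if !fullyLoaded && int(last*100) >= 100 {
			deadline = time.Now().Add(finalLoadDuration) // loaded: allow startup time
			fullyLoaded = true
		}
		time.Sleep(250 * time.Millisecond)
	}
	return nil
}

func main() {
	start := time.Now()
	progress := func() float64 {
		p := time.Since(start).Seconds() / 2
		if p > 1 {
			p = 1
		}
		return p
	}
	done := func() bool { return time.Since(start) > 2*time.Second }
	fmt.Println("err:", waitForReady(progress, done))
}
```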
@@ -756,7 +765,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			var c completion
 			if err := json.Unmarshal(evt, &c); err != nil {
-				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+				return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 			}
 			switch {
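For reference, the per-GPU gate that decides whether --flash-attn is passed reads cleanly in isolation; a sketch with a simplified GPU type (assumption: only the two fields consulted above matter):

```go
package main

import "fmt"

// gpuInfo carries just the fields the check above consults.
type gpuInfo struct {
	Library     string // "cuda", "metal", "rocm", ...
	DriverMajor int
}

// flashAttentionSupported mirrors the loop in NewLlamaServer: every GPU must
// be metal, or cuda with driver major version 7 or newer.
func flashAttentionSupported(gpus []gpuInfo) bool {
	for _, g := range gpus {
		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(flashAttentionSupported([]gpuInfo{{Library: "cuda", DriverMajor: 8}})) // true
	fmt.Println(flashAttentionSupported([]gpuInfo{{Library: "rocm"}}))                 // false
}
```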


@@ -5,16 +5,20 @@ import (
"os" "os"
"github.com/emirpasic/gods/lists/arraylist" "github.com/emirpasic/gods/lists/arraylist"
"github.com/mattn/go-runewidth"
"golang.org/x/term" "golang.org/x/term"
) )
 type Buffer struct {
+	DisplayPos int
 	Pos        int
 	Buf        *arraylist.List
+	// LineHasSpace is an arraylist of bools to keep track of whether a line has a space at the end
+	LineHasSpace *arraylist.List
 	Prompt       *Prompt
 	LineWidth    int
 	Width        int
 	Height       int
 }
 func NewBuffer(prompt *Prompt) (*Buffer, error) {
@@ -27,25 +31,57 @@ func NewBuffer(prompt *Prompt) (*Buffer, error) {
 	lwidth := width - len(prompt.prompt())
 
 	b := &Buffer{
+		DisplayPos:   0,
 		Pos:          0,
 		Buf:          arraylist.New(),
+		LineHasSpace: arraylist.New(),
 		Prompt:       prompt,
 		Width:        width,
 		Height:       height,
 		LineWidth:    lwidth,
 	}
 
 	return b, nil
 }
+func (b *Buffer) GetLineSpacing(line int) bool {
+	hasSpace, _ := b.LineHasSpace.Get(line)
+	if hasSpace == nil {
+		return false
+	}
+	return hasSpace.(bool)
+}
+
 func (b *Buffer) MoveLeft() {
 	if b.Pos > 0 {
-		if b.Pos%b.LineWidth == 0 {
-			fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
-		} else {
-			fmt.Print(CursorLeft)
+		//asserts that we retrieve a rune
+		if e, ok := b.Buf.Get(b.Pos - 1); ok {
+			if r, ok := e.(rune); ok {
+				rLength := runewidth.RuneWidth(r)
+				if b.DisplayPos%b.LineWidth == 0 {
+					fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
+					if rLength == 2 {
+						fmt.Print(CursorLeft)
+					}
+					line := b.DisplayPos/b.LineWidth - 1
+					hasSpace := b.GetLineSpacing(line)
+					if hasSpace {
+						b.DisplayPos -= 1
+						fmt.Print(CursorLeft)
+					}
+				} else {
+					fmt.Print(cursorLeftN(rLength))
+				}
+				b.Pos -= 1
+				b.DisplayPos -= rLength
+			}
 		}
-		b.Pos -= 1
 	}
 }
@@ -71,18 +107,35 @@ func (b *Buffer) MoveLeftWord() {
 }
 
 func (b *Buffer) MoveRight() {
-	if b.Pos < b.Size() {
-		b.Pos += 1
-		if b.Pos%b.LineWidth == 0 {
-			fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
-		} else {
-			fmt.Print(CursorRight)
+	if b.Pos < b.Buf.Size() {
+		if e, ok := b.Buf.Get(b.Pos); ok {
+			if r, ok := e.(rune); ok {
+				rLength := runewidth.RuneWidth(r)
+				b.Pos += 1
+				hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)
+				b.DisplayPos += rLength
+				if b.DisplayPos%b.LineWidth == 0 {
+					fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
+				} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
+					fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength))
+					b.DisplayPos += 1
+				} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
+					fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
+					b.DisplayPos += 1
+				} else {
+					fmt.Print(cursorRightN(rLength))
+				}
+			}
 		}
 	}
 }
 func (b *Buffer) MoveRightWord() {
-	if b.Pos < b.Size() {
+	if b.Pos < b.Buf.Size() {
 		for {
 			b.MoveRight()
 			v, _ := b.Buf.Get(b.Pos)
@@ -90,7 +143,7 @@ func (b *Buffer) MoveRightWord() {
 				break
 			}
-			if b.Pos == b.Size() {
+			if b.Pos == b.Buf.Size() {
 				break
 			}
 		}
@@ -99,7 +152,7 @@ func (b *Buffer) MoveRightWord() {
 func (b *Buffer) MoveToStart() {
 	if b.Pos > 0 {
-		currLine := b.Pos / b.LineWidth
+		currLine := b.DisplayPos / b.LineWidth
 		if currLine > 0 {
 			for cnt := 0; cnt < currLine; cnt++ {
 				fmt.Print(CursorUp)
@@ -107,81 +160,195 @@ func (b *Buffer) MoveToStart() {
 		}
 		fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())))
 		b.Pos = 0
+		b.DisplayPos = 0
 	}
 }
 func (b *Buffer) MoveToEnd() {
-	if b.Pos < b.Size() {
-		currLine := b.Pos / b.LineWidth
-		totalLines := b.Size() / b.LineWidth
+	if b.Pos < b.Buf.Size() {
+		currLine := b.DisplayPos / b.LineWidth
+		totalLines := b.DisplaySize() / b.LineWidth
 		if currLine < totalLines {
 			for cnt := 0; cnt < totalLines-currLine; cnt++ {
 				fmt.Print(CursorDown)
 			}
-			remainder := b.Size() % b.LineWidth
+			remainder := b.DisplaySize() % b.LineWidth
 			fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())+remainder))
 		} else {
-			fmt.Print(cursorRightN(b.Size() - b.Pos))
+			fmt.Print(cursorRightN(b.DisplaySize() - b.DisplayPos))
 		}
-		b.Pos = b.Size()
+		b.Pos = b.Buf.Size()
+		b.DisplayPos = b.DisplaySize()
 	}
 }
 
-func (b *Buffer) Size() int {
-	return b.Buf.Size()
+func (b *Buffer) DisplaySize() int {
+	sum := 0
+	for i := 0; i < b.Buf.Size(); i++ {
+		if e, ok := b.Buf.Get(i); ok {
+			if r, ok := e.(rune); ok {
+				sum += runewidth.RuneWidth(r)
+			}
+		}
+	}
+	return sum
 }
 func (b *Buffer) Add(r rune) {
 	if b.Pos == b.Buf.Size() {
-		fmt.Printf("%c", r)
-		b.Buf.Add(r)
-		b.Pos += 1
-		if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
-			fmt.Printf("\n%s", b.Prompt.AltPrompt)
-		}
+		b.AddChar(r, false)
 	} else {
-		fmt.Printf("%c", r)
-		b.Buf.Insert(b.Pos, r)
-		b.Pos += 1
-		if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
-			fmt.Printf("\n%s", b.Prompt.AltPrompt)
-		}
-		b.drawRemaining()
+		b.AddChar(r, true)
 	}
 }
 
+func (b *Buffer) AddChar(r rune, insert bool) {
+	rLength := runewidth.RuneWidth(r)
+	b.DisplayPos += rLength
+	if b.Pos > 0 {
+		if b.DisplayPos%b.LineWidth == 0 {
+			fmt.Printf("%c", r)
+			fmt.Printf("\n%s", b.Prompt.AltPrompt)
+			if insert {
+				b.LineHasSpace.Set(b.DisplayPos/b.LineWidth-1, false)
+			} else {
+				b.LineHasSpace.Add(false)
+			}
+			// this case occurs when a double-width rune crosses the line boundary
+		} else if b.DisplayPos%b.LineWidth < (b.DisplayPos-rLength)%b.LineWidth {
+			if insert {
+				fmt.Print(ClearToEOL)
+			}
+			fmt.Printf("\n%s", b.Prompt.AltPrompt)
+			b.DisplayPos += 1
+			fmt.Printf("%c", r)
+			if insert {
+				b.LineHasSpace.Set(b.DisplayPos/b.LineWidth-1, true)
+			} else {
+				b.LineHasSpace.Add(true)
+			}
+		} else {
+			fmt.Printf("%c", r)
+		}
+	} else {
+		fmt.Printf("%c", r)
+	}
+	if insert {
+		b.Buf.Insert(b.Pos, r)
+	} else {
+		b.Buf.Add(r)
+	}
+	b.Pos += 1
+	if insert {
+		b.drawRemaining()
+	}
+}
+
+func (b *Buffer) countRemainingLineWidth(place int) int {
+	var sum int
+	counter := -1
+	var prevLen int
+	for place <= b.LineWidth {
+		counter += 1
+		sum += prevLen
+		if e, ok := b.Buf.Get(b.Pos + counter); ok {
+			if r, ok := e.(rune); ok {
+				place += runewidth.RuneWidth(r)
+				prevLen = len(string(r))
+			}
+		} else {
+			break
+		}
+	}
+	return sum
+}
 func (b *Buffer) drawRemaining() {
 	var place int
 	remainingText := b.StringN(b.Pos)
 	if b.Pos > 0 {
-		place = b.Pos % b.LineWidth
+		place = b.DisplayPos % b.LineWidth
 	}
 	fmt.Print(CursorHide)
 
 	// render the rest of the current line
-	currLine := remainingText[:min(b.LineWidth-place, len(remainingText))]
+	currLineLength := b.countRemainingLineWidth(place)
+	currLine := remainingText[:min(currLineLength, len(remainingText))]
+	currLineSpace := runewidth.StringWidth(currLine)
+	remLength := runewidth.StringWidth(remainingText)
 	if len(currLine) > 0 {
 		fmt.Printf(ClearToEOL + currLine)
-		fmt.Print(cursorLeftN(len(currLine)))
+		fmt.Print(cursorLeftN(currLineSpace))
 	} else {
 		fmt.Print(ClearToEOL)
 	}
 
+	if currLineSpace != b.LineWidth-place && currLineSpace != remLength {
+		b.LineHasSpace.Set(b.DisplayPos/b.LineWidth, true)
+	} else if currLineSpace != b.LineWidth-place {
+		b.LineHasSpace.Remove(b.DisplayPos / b.LineWidth)
+	} else {
+		b.LineHasSpace.Set(b.DisplayPos/b.LineWidth, false)
+	}
+
+	if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText {
+		fmt.Print(cursorRightN(currLineSpace))
+		fmt.Printf("\n%s", b.Prompt.AltPrompt)
+		fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width-currLineSpace))
+	}
+
 	// render the other lines
-	if len(remainingText) > len(currLine) {
-		remaining := []rune(remainingText[len(currLine):])
+	if remLength > currLineSpace {
+		remaining := (remainingText[len(currLine):])
 		var totalLines int
-		for i, c := range remaining {
-			if i%b.LineWidth == 0 {
+		var displayLength int
+		var lineLength int = currLineSpace
+		for _, c := range remaining {
+			if displayLength == 0 || (displayLength+runewidth.RuneWidth(c))%b.LineWidth < displayLength%b.LineWidth {
 				fmt.Printf("\n%s", b.Prompt.AltPrompt)
 				totalLines += 1
+				if displayLength != 0 {
+					if lineLength == b.LineWidth {
+						b.LineHasSpace.Set(b.DisplayPos/b.LineWidth+totalLines-1, false)
+					} else {
+						b.LineHasSpace.Set(b.DisplayPos/b.LineWidth+totalLines-1, true)
+					}
+				}
+				lineLength = 0
 			}
+			displayLength += runewidth.RuneWidth(c)
+			lineLength += runewidth.RuneWidth(c)
 			fmt.Printf("%c", c)
 		}
 		fmt.Print(ClearToEOL)
 		fmt.Print(cursorUpN(totalLines))
-		fmt.Printf(CursorBOL + cursorRightN(b.Width-len(currLine)))
+		fmt.Printf(CursorBOL + cursorRightN(b.Width-currLineSpace))
+
+		hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)
+		if hasSpace && b.DisplayPos%b.LineWidth != b.LineWidth-1 {
+			fmt.Print(CursorLeft)
+		}
 	}
 	fmt.Print(CursorShow)
@@ -189,46 +356,84 @@ func (b *Buffer) drawRemaining() {
 func (b *Buffer) Remove() {
 	if b.Buf.Size() > 0 && b.Pos > 0 {
-		if b.Pos%b.LineWidth == 0 {
-			// if the user backspaces over the word boundary, do this magic to clear the line
-			// and move to the end of the previous line
-			fmt.Printf(CursorBOL + ClearToEOL)
-			fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width) + " " + CursorLeft)
-		} else {
-			fmt.Printf(CursorLeft + " " + CursorLeft)
-		}
-
-		var eraseExtraLine bool
-		if (b.Size()-1)%b.LineWidth == 0 {
-			eraseExtraLine = true
-		}
-
-		b.Pos -= 1
-		b.Buf.Remove(b.Pos)
-
-		if b.Pos < b.Size() {
-			b.drawRemaining()
-			// this erases a line which is left over when backspacing in the middle of a line and there
-			// are trailing characters which go over the line width boundary
-			if eraseExtraLine {
-				remainingLines := (b.Size() - b.Pos) / b.LineWidth
-				fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
-				place := b.Pos % b.LineWidth
-				fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.prompt())))
+		if e, ok := b.Buf.Get(b.Pos - 1); ok {
+			if r, ok := e.(rune); ok {
+				rLength := runewidth.RuneWidth(r)
+				hasSpace := b.GetLineSpacing(b.DisplayPos/b.LineWidth - 1)
+				if b.DisplayPos%b.LineWidth == 0 {
+					// if the user backspaces over the word boundary, do this magic to clear the line
+					// and move to the end of the previous line
+					fmt.Printf(CursorBOL + ClearToEOL)
+					fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
+					if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth {
+						b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
+					}
+					if hasSpace {
+						b.DisplayPos -= 1
+						fmt.Print(CursorLeft)
+					}
+					if rLength == 2 {
+						fmt.Print(CursorLeft + "  " + cursorLeftN(2))
+					} else {
+						fmt.Print(" " + CursorLeft)
+					}
+				} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
+					fmt.Printf(CursorBOL + ClearToEOL)
+					fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
+					if b.Pos == b.Buf.Size() {
+						b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
+					}
+					b.DisplayPos -= 1
+				} else {
+					fmt.Print(cursorLeftN(rLength))
+					for i := 0; i < rLength; i++ {
+						fmt.Print(" ")
+					}
+					fmt.Print(cursorLeftN(rLength))
+				}
+
+				var eraseExtraLine bool
+				if (b.DisplaySize()-1)%b.LineWidth == 0 || (rLength == 2 && ((b.DisplaySize()-2)%b.LineWidth == 0)) || b.DisplaySize()%b.LineWidth == 0 {
+					eraseExtraLine = true
+				}
+
+				b.Pos -= 1
+				b.DisplayPos -= rLength
+				b.Buf.Remove(b.Pos)
+
+				if b.Pos < b.Buf.Size() {
+					b.drawRemaining()
+					// this erases a line which is left over when backspacing in the middle of a line and there
+					// are trailing characters which go over the line width boundary
+					if eraseExtraLine {
+						remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
+						fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
+						place := b.DisplayPos % b.LineWidth
+						fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.prompt())))
+					}
+				}
 			}
 		}
 	}
 }
 func (b *Buffer) Delete() {
-	if b.Size() > 0 && b.Pos < b.Size() {
+	if b.Buf.Size() > 0 && b.Pos < b.Buf.Size() {
 		b.Buf.Remove(b.Pos)
 		b.drawRemaining()
-		if b.Size()%b.LineWidth == 0 {
-			if b.Pos != b.Size() {
-				remainingLines := (b.Size() - b.Pos) / b.LineWidth
+		if b.DisplaySize()%b.LineWidth == 0 {
+			if b.DisplayPos != b.DisplaySize() {
+				remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
 				fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
-				place := b.Pos % b.LineWidth
+				place := b.DisplayPos % b.LineWidth
 				fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.prompt())))
 			}
 		}
@@ -244,8 +449,8 @@ func (b *Buffer) DeleteBefore() {
 }
 
 func (b *Buffer) DeleteRemaining() {
-	if b.Size() > 0 && b.Pos < b.Size() {
-		charsToDel := b.Size() - b.Pos
+	if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() {
+		charsToDel := b.Buf.Size() - b.Pos
 		for cnt := 0; cnt < charsToDel; cnt++ {
 			b.Delete()
 		}
@@ -281,8 +486,10 @@ func (b *Buffer) ClearScreen() {
 		ph := b.Prompt.placeholder()
 		fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
 	} else {
-		currPos := b.Pos
+		currPos := b.DisplayPos
+		currIndex := b.Pos
 		b.Pos = 0
+		b.DisplayPos = 0
 		b.drawRemaining()
 		fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.prompt())))
 		if currPos > 0 {
@@ -300,7 +507,8 @@ func (b *Buffer) ClearScreen() {
 				fmt.Printf(CursorBOL + b.Prompt.AltPrompt)
 			}
 		}
-		b.Pos = currPos
+		b.Pos = currIndex
+		b.DisplayPos = currPos
 	}
 }
@@ -309,9 +517,20 @@ func (b *Buffer) IsEmpty() bool {
 }
 
 func (b *Buffer) Replace(r []rune) {
+	b.DisplayPos = 0
 	b.Pos = 0
+	lineNums := b.DisplaySize() / b.LineWidth
 	b.Buf.Clear()
-	fmt.Printf(ClearLine + CursorBOL + b.Prompt.prompt())
+	fmt.Printf(CursorBOL + ClearToEOL)
+	for i := 0; i < lineNums; i++ {
+		fmt.Print(CursorUp + CursorBOL + ClearToEOL)
+	}
+	fmt.Printf(CursorBOL + b.Prompt.prompt())
 	for _, c := range r {
 		b.Add(c)
 	}
@@ -328,7 +547,7 @@ func (b *Buffer) StringN(n int) string {
 func (b *Buffer) StringNM(n, m int) string {
 	var s string
 	if m == 0 {
-		m = b.Size()
+		m = b.Buf.Size()
 	}
 	for cnt := n; cnt < m; cnt++ {
 		c, _ := b.Buf.Get(cnt)
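The DisplayPos bookkeeping throughout this file exists because rune count and terminal columns diverge for CJK text; a quick demonstration with the same go-runewidth package:

```go
package main

import (
	"fmt"

	"github.com/mattn/go-runewidth"
)

func main() {
	// "日本語" is 3 runes but occupies 6 terminal columns, which is why the
	// buffer tracks DisplayPos (columns) separately from Pos (rune index).
	s := "日本語"
	fmt.Println(len([]rune(s)))           // 3 runes
	fmt.Println(runewidth.StringWidth(s)) // 6 columns
	for _, r := range s {
		fmt.Printf("%c -> width %d\n", r, runewidth.RuneWidth(r))
	}
}
```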


@@ -150,7 +150,7 @@ func (i *Instance) Readline() (string, error) {
 				i.Pasting = false
 			}
 		case KeyDel:
-			if buf.Size() > 0 {
+			if buf.DisplaySize() > 0 {
 				buf.Delete()
 			}
 			metaDel = true
@@ -202,7 +202,7 @@ func (i *Instance) Readline() (string, error) {
 				buf.Add(' ')
 			}
 		case CharDelete:
-			if buf.Size() > 0 {
+			if buf.DisplaySize() > 0 {
 				buf.Delete()
 			} else {
 				return "", io.EOF


@@ -33,9 +33,11 @@ case "$ARCH" in
*) error "Unsupported architecture: $ARCH" ;; *) error "Unsupported architecture: $ARCH" ;;
esac esac
IS_WSL2=false
KERN=$(uname -r) KERN=$(uname -r)
case "$KERN" in case "$KERN" in
*icrosoft*WSL2 | *icrosoft*wsl2) ;; *icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
*icrosoft) error "Microsoft WSL1 is not currently supported. Please upgrade to WSL2 with 'wsl --set-version <distro> 2'" ;; *icrosoft) error "Microsoft WSL1 is not currently supported. Please upgrade to WSL2 with 'wsl --set-version <distro> 2'" ;;
*) ;; *) ;;
esac esac
@@ -72,7 +74,7 @@ status "Installing ollama to $BINDIR..."
 $SUDO install -o0 -g0 -m755 -d $BINDIR
 $SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama
 
 install_success() {
     status 'The Ollama API is now available at 127.0.0.1:11434.'
     status 'Install complete. Run "ollama" from the command line.'
 }
@@ -131,6 +133,17 @@ if available systemctl; then
     configure_systemd
 fi
 
+# WSL2 only supports GPUs via nvidia passthrough
+# so check for nvidia-smi to determine if GPU is available
+if [ "$IS_WSL2" = true ]; then
+    if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
+        status "Nvidia GPU detected."
+    fi
+    install_success
+    exit 0
+fi
+
+# Install GPU dependencies on Linux
 if ! available lspci && ! available lshw; then
     warning "Unable to detect NVIDIA/AMD GPU. Install lspci or lshw to automatically detect and install GPU dependencies."
     exit 0
@@ -139,12 +152,12 @@ fi
 check_gpu() {
     # Look for devices based on vendor ID for NVIDIA and AMD
     case $1 in
         lspci)
             case $2 in
                 nvidia) available lspci && lspci -d '10de:' | grep -q 'NVIDIA' || return 1 ;;
                 amdgpu) available lspci && lspci -d '1002:' | grep -q 'AMD' || return 1 ;;
             esac ;;
         lshw)
             case $2 in
                 nvidia) available lshw && $SUDO lshw -c display -numeric | grep -q 'vendor: .* \[10DE\]' || return 1 ;;
                 amdgpu) available lshw && $SUDO lshw -c display -numeric | grep -q 'vendor: .* \[1002\]' || return 1 ;;
             esac ;;
@@ -181,7 +194,7 @@ if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
     curl --fail --show-error --location --progress-bar "https://ollama.com/download/ollama-linux-amd64-rocm.tgz${VER_PARAM}" \
         | $SUDO tar zx --owner ollama --group ollama -C /usr/share/ollama/lib/rocm .
     install_success
-    status "AMD GPU dependencies installed."
+    status "AMD GPU ready."
     exit 0
 fi
@@ -274,7 +287,7 @@ if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\
     esac
 fi
 
-if ! lsmod | grep -q nvidia; then
+if ! lsmod | grep -q nvidia || ! lsmod | grep -q nvidia_uvm; then
     KERNEL_RELEASE="$(uname -r)"
     case $OS_NAME in
         rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;;
@@ -295,7 +308,19 @@ if ! lsmod | grep -q nvidia; then
     fi
 
     $SUDO modprobe nvidia
+    $SUDO modprobe nvidia_uvm
 fi
 
+# make sure the NVIDIA modules are loaded on boot with nvidia-persistenced
+if command -v nvidia-persistenced > /dev/null 2>&1; then
+    $SUDO touch /etc/modules-load.d/nvidia.conf
+    MODULES="nvidia nvidia-uvm"
+    for MODULE in $MODULES; do
+        if ! grep -qxF "$MODULE" /etc/modules-load.d/nvidia.conf; then
+            echo "$MODULE" | sudo tee -a /etc/modules-load.d/nvidia.conf > /dev/null
+        fi
+    done
+fi
+
-status "NVIDIA CUDA drivers installed."
+status "NVIDIA GPU ready."
+install_success


@@ -221,7 +221,7 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
 	}
 	defer resp.Body.Close()
 
-	n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size)
+	n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed)
 	if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
 		// rollback progress
 		b.Completed.Add(-n)
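The one-token change above matters on resume: a partially completed chunk must copy only the remaining bytes, or progress is double-counted. The arithmetic in isolation (chunk type simplified for the sketch):

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// chunk mirrors the two fields the fix consults: the chunk's total size and
// how many bytes a previous attempt already completed.
type chunk struct {
	Size      int64
	Completed int64
}

func main() {
	c := chunk{Size: 10, Completed: 4}
	// The response body for a resumed chunk starts at offset Completed, so
	// only Size-Completed bytes remain to be copied.
	body := strings.NewReader("67890!")
	var w strings.Builder
	n, err := io.CopyN(&w, body, c.Size-c.Completed)
	fmt.Println(n, err, w.String()) // 6 <nil> 67890!
}
```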


@@ -28,7 +28,7 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@@ -771,37 +771,6 @@ func PruneDirectory(path string) error {
 	return nil
 }
 
-func DeleteModel(name string) error {
-	mp := ParseModelPath(name)
-	manifest, _, err := GetManifest(mp)
-	if err != nil {
-		return err
-	}
-
-	deleteMap := make(map[string]struct{})
-	for _, layer := range manifest.Layers {
-		deleteMap[layer.Digest] = struct{}{}
-	}
-	deleteMap[manifest.Config.Digest] = struct{}{}
-
-	err = deleteUnusedLayers(&mp, deleteMap)
-	if err != nil {
-		return err
-	}
-
-	fp, err := mp.GetManifestPath()
-	if err != nil {
-		return err
-	}
-	err = os.Remove(fp)
-	if err != nil {
-		slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
-		return err
-	}
-
-	return nil
-}
-
 func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 	fn(api.ProgressResponse{Status: "retrieving manifest"})


@@ -88,3 +88,26 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
 	return os.Open(blob)
 }
 
+func (l *Layer) Remove() error {
+	ms, err := Manifests()
+	if err != nil {
+		return err
+	}
+
+	for _, m := range ms {
+		for _, layer := range append(m.Layers, m.Config) {
+			if layer.Digest == l.Digest {
+				// something is using this layer
+				return nil
+			}
+		}
+	}
+
+	blob, err := GetBlobsPath(l.Digest)
+	if err != nil {
+		return err
+	}
+
+	return os.Remove(blob)
+}
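Layer.Remove deletes a blob only when no manifest still references its digest, a reference check rather than reference counting. The guard in isolation (hypothetical helper with simplified types, not the repo's API):

```go
package main

import "fmt"

// layerInUse reports whether any manifest still references the digest, the
// condition under which Layer.Remove above leaves the blob on disk.
func layerInUse(digest string, manifests map[string][]string) bool {
	for _, layers := range manifests {
		for _, d := range layers {
			if d == digest {
				return true
			}
		}
	}
	return false
}

func main() {
	ms := map[string][]string{
		"library/llama3:latest": {"sha256-aaa", "sha256-bbb"},
		"library/test:latest":   {"sha256-bbb"},
	}
	fmt.Println(layerInUse("sha256-bbb", ms)) // true: shared, must not delete
	fmt.Println(layerInUse("sha256-ccc", ms)) // false: safe to delete
}
```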


@@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"path/filepath" "path/filepath"
@@ -14,7 +15,10 @@ import (
 type Manifest struct {
 	ManifestV2
-	Digest string `json:"-"`
+
+	filepath string
+	fi       os.FileInfo
+	digest   string
 }
 
 func (m *Manifest) Size() (size int64) {
@@ -25,9 +29,28 @@ func (m *Manifest) Size() (size int64) {
 	return
 }
 
-func ParseNamedManifest(name model.Name) (*Manifest, error) {
-	if !name.IsFullyQualified() {
-		return nil, model.Unqualified(name)
+func (m *Manifest) Remove() error {
+	if err := os.Remove(m.filepath); err != nil {
+		return err
+	}
+
+	for _, layer := range append(m.Layers, m.Config) {
+		if err := layer.Remove(); err != nil {
+			return err
+		}
+	}
+
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return err
+	}
+
+	return PruneDirectory(manifests)
+}
+
+func ParseNamedManifest(n model.Name) (*Manifest, error) {
+	if !n.IsFullyQualified() {
+		return nil, model.Unqualified(n)
 	}
manifests, err := GetManifestPath() manifests, err := GetManifestPath()
@@ -35,20 +58,30 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) {
 		return nil, err
 	}
 
-	var manifest ManifestV2
-	manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
+	p := filepath.Join(manifests, n.Filepath())
+
+	var m ManifestV2
+	f, err := os.Open(p)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+
+	fi, err := f.Stat()
 	if err != nil {
 		return nil, err
 	}
 
 	sha256sum := sha256.New()
-	if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
+	if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
 		return nil, err
 	}
 
 	return &Manifest{
-		ManifestV2: manifest,
-		Digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
+		ManifestV2: m,
+		filepath:   p,
+		fi:         fi,
+		digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
 	}, nil
 }
@@ -77,3 +110,48 @@ func WriteManifest(name string, config *Layer, layers []*Layer) error {
 	return os.WriteFile(manifestPath, b.Bytes(), 0o644)
 }
 
+func Manifests() (map[model.Name]*Manifest, error) {
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO(mxyng): use something less brittle
+	matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
+	if err != nil {
+		return nil, err
+	}
+
+	ms := make(map[model.Name]*Manifest)
+	for _, match := range matches {
+		fi, err := os.Stat(match)
+		if err != nil {
+			return nil, err
+		}
+
+		if !fi.IsDir() {
+			rel, err := filepath.Rel(manifests, match)
+			if err != nil {
+				slog.Warn("bad filepath", "path", match, "error", err)
+				continue
+			}
+
+			n := model.ParseNameFromFilepath(rel)
+			if !n.IsValid() {
+				slog.Warn("bad manifest name", "path", rel, "error", err)
+				continue
+			}
+
+			m, err := ParseNamedManifest(n)
+			if err != nil {
+				slog.Warn("bad manifest", "name", n, "error", err)
+				continue
+			}
+
+			ms[n] = m
+		}
+	}
+
+	return ms, nil
+}

server/manifest_test.go (new file, 150 lines)

@@ -0,0 +1,150 @@
package server
import (
"encoding/json"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/types/model"
)
func createManifest(t *testing.T, path, name string) {
t.Helper()
p := filepath.Join(path, "manifests", name)
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
t.Fatal(err)
}
f, err := os.Create(p)
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
t.Fatal(err)
}
}
func TestManifests(t *testing.T) {
cases := map[string]struct {
ps []string
wantValidCount int
wantInvalidCount int
}{
"empty": {},
"single": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag"),
},
wantValidCount: 1,
},
"multiple": {
ps: []string{
filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
},
wantValidCount: 15,
},
"hidden": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag"),
filepath.Join("host", "namespace", "model", ".hidden"),
},
wantValidCount: 1,
wantInvalidCount: 1,
},
"subdir": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag", "one"),
filepath.Join("host", "namespace", "model", "tag", "another", "one"),
},
wantInvalidCount: 2,
},
"upper tag": {
ps: []string{
filepath.Join("host", "namespace", "model", "TAG"),
},
wantValidCount: 1,
},
"upper model": {
ps: []string{
filepath.Join("host", "namespace", "MODEL", "tag"),
},
wantValidCount: 1,
},
"upper namespace": {
ps: []string{
filepath.Join("host", "NAMESPACE", "model", "tag"),
},
wantValidCount: 1,
},
"upper host": {
ps: []string{
filepath.Join("HOST", "namespace", "model", "tag"),
},
wantValidCount: 1,
},
}
for n, wants := range cases {
t.Run(n, func(t *testing.T) {
d := t.TempDir()
t.Setenv("OLLAMA_MODELS", d)
for _, p := range wants.ps {
createManifest(t, d, p)
}
ms, err := Manifests()
if err != nil {
t.Fatal(err)
}
var ns []model.Name
for k := range ms {
ns = append(ns, k)
}
var gotValidCount, gotInvalidCount int
for _, p := range wants.ps {
n := model.ParseNameFromFilepath(p)
if n.IsValid() {
gotValidCount++
} else {
gotInvalidCount++
}
if !n.IsValid() && slices.Contains(ns, n) {
t.Errorf("unexpected invalid name: %s", p)
} else if n.IsValid() && !slices.Contains(ns, n) {
t.Errorf("missing valid name: %s", p)
}
}
if gotValidCount != wants.wantValidCount {
t.Errorf("got valid count %d, want %d", gotValidCount, wants.wantValidCount)
}
if gotInvalidCount != wants.wantInvalidCount {
t.Errorf("got invalid count %d, want %d", gotInvalidCount, wants.wantInvalidCount)
}
})
}
}


@@ -26,11 +26,11 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@@ -315,10 +315,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 }
 
 func getDefaultSessionDuration() time.Duration {
-	if t, exists := os.LookupEnv("OLLAMA_KEEP_ALIVE"); exists {
-		v, err := strconv.Atoi(t)
+	if envconfig.KeepAlive != "" {
+		v, err := strconv.Atoi(envconfig.KeepAlive)
 		if err != nil {
-			d, err := time.ParseDuration(t)
+			d, err := time.ParseDuration(envconfig.KeepAlive)
 			if err != nil {
 				return defaultSessionDuration
 			}
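OLLAMA_KEEP_ALIVE accepts either a bare integer (seconds) or a Go duration string, hence the Atoi-then-ParseDuration fallback above; the same fallback sketched standalone (simplified, ignoring the handler's handling of negative values):

```go
package main

import (
	"fmt"
	"strconv"
	"time"
)

// parseKeepAlive mirrors the fallback above: integers are seconds, anything
// else is tried as a Go duration string like "10m".
func parseKeepAlive(s string) (time.Duration, bool) {
	if v, err := strconv.Atoi(s); err == nil {
		return time.Duration(v) * time.Second, true
	}
	if d, err := time.ParseDuration(s); err == nil {
		return d, true
	}
	return 0, false
}

func main() {
	for _, s := range []string{"300", "10m", "bogus"} {
		d, ok := parseKeepAlive(s)
		fmt.Println(s, "->", d, ok)
	}
}
```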
@@ -421,13 +421,14 @@ func (s *Server) PullModelHandler(c *gin.Context) {
 		return
 	}
 
-	var model string
-	if req.Model != "" {
-		model = req.Model
-	} else if req.Name != "" {
-		model = req.Name
-	} else {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+	name := model.ParseName(cmp.Or(req.Model, req.Name))
+	if !name.IsValid() {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
+		return
+	}
+
+	if err := checkNameExists(name); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
@@ -445,7 +446,7 @@ func (s *Server) PullModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := PullModel(ctx, model, regOpts, fn); err != nil {
+		if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -507,6 +508,21 @@ func (s *Server) PushModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func checkNameExists(name model.Name) error {
+	names, err := Manifests()
+	if err != nil {
+		return err
+	}
+
+	for n := range names {
+		if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
+			return fmt.Errorf("a model with that name already exists")
+		}
+	}
+
+	return nil
+}
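checkNameExists flags only case-insensitive collisions with a different name, which is exactly what strings.EqualFold plus the inequality test buys:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Two filepaths that differ only in case refer to the same model on a
	// case-insensitive filesystem, so the server treats them as a conflict
	// unless they are byte-for-byte identical.
	a := "registry.ollama.ai/library/llama3/latest"
	b := "registry.ollama.ai/library/LLAMA3/latest"
	fmt.Println(strings.EqualFold(a, b)) // true  -> conflict
	fmt.Println(a == b)                  // false -> not the same name
}
```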
 func (s *Server) CreateModelHandler(c *gin.Context) {
 	var req api.CreateRequest
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@@ -523,6 +539,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		return
 	}
 
+	if err := checkNameExists(name); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
 	if req.Path == "" && req.Modelfile == "" {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
 		return
@@ -575,48 +596,31 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 }
 
 func (s *Server) DeleteModelHandler(c *gin.Context) {
-	var req api.DeleteRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	var r api.DeleteRequest
+	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	var model string
-	if req.Model != "" {
-		model = req.Model
-	} else if req.Name != "" {
-		model = req.Name
-	} else {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+	n := model.ParseName(cmp.Or(r.Model, r.Name))
+	if !n.IsValid() {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
 		return
 	}
 
-	if err := DeleteModel(model); err != nil {
-		if os.IsNotExist(err) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
-		} else {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		}
-		return
-	}
-
-	manifestsPath, err := GetManifestPath()
+	m, err := ParseNamedManifest(n)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
-	if err := PruneDirectory(manifestsPath); err != nil {
+	if err := m.Remove(); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-
-	c.JSON(http.StatusOK, nil)
 }
func (s *Server) ShowModelHandler(c *gin.Context) { func (s *Server) ShowModelHandler(c *gin.Context) {
@@ -720,72 +724,42 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 }
 
 func (s *Server) ListModelsHandler(c *gin.Context) {
-	manifests, err := GetManifestPath()
+	ms, err := Manifests()
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
 	models := []api.ModelResponse{}
-	if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
-		if !info.IsDir() {
-			rel, err := filepath.Rel(manifests, path)
-			if err != nil {
-				return err
-			}
-
-			if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
-				return err
-			} else if hidden {
-				return nil
-			}
-
-			n := model.ParseNameFromFilepath(rel)
-			if !n.IsValid() {
-				slog.Warn("bad manifest filepath", "path", rel)
-				return nil
-			}
-
-			m, err := ParseNamedManifest(n)
-			if err != nil {
-				slog.Warn("bad manifest", "name", n, "error", err)
-				return nil
-			}
-
-			f, err := m.Config.Open()
-			if err != nil {
-				slog.Warn("bad manifest config filepath", "name", n, "error", err)
-				return nil
-			}
-			defer f.Close()
-
-			var c ConfigV2
-			if err := json.NewDecoder(f).Decode(&c); err != nil {
-				slog.Warn("bad manifest config", "name", n, "error", err)
-				return nil
-			}
-
-			// tag should never be masked
-			models = append(models, api.ModelResponse{
-				Model:      n.DisplayShortest(),
-				Name:       n.DisplayShortest(),
-				Size:       m.Size(),
-				Digest:     m.Digest,
-				ModifiedAt: info.ModTime(),
-				Details: api.ModelDetails{
-					Format:            c.ModelFormat,
-					Family:            c.ModelFamily,
-					Families:          c.ModelFamilies,
-					ParameterSize:     c.ModelType,
-					QuantizationLevel: c.FileType,
-				},
-			})
-		}
-
-		return nil
-	}); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+	for n, m := range ms {
+		f, err := m.Config.Open()
+		if err != nil {
+			slog.Warn("bad manifest filepath", "name", n, "error", err)
+			continue
+		}
+		defer f.Close()
+
+		var cf ConfigV2
+		if err := json.NewDecoder(f).Decode(&cf); err != nil {
+			slog.Warn("bad manifest config", "name", n, "error", err)
+			continue
+		}
+
+		// tag should never be masked
+		models = append(models, api.ModelResponse{
+			Model:      n.DisplayShortest(),
+			Name:       n.DisplayShortest(),
+			Size:       m.Size(),
+			Digest:     m.digest,
+			ModifiedAt: m.fi.ModTime(),
+			Details: api.ModelDetails{
+				Format:            cf.ModelFormat,
+				Family:            cf.ModelFamily,
+				Families:          cf.ModelFamilies,
+				ParameterSize:     cf.ModelType,
+				QuantizationLevel: cf.FileType,
+			},
+		})
 	}
 	slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
@@ -818,6 +792,11 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
 		return
 	}
 
+	if err := checkNameExists(dst); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
 	if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
 		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
 	} else if err != nil {
@@ -1025,7 +1004,7 @@ func Serve(ln net.Listener) error {
 		level = slog.LevelDebug
 	}
 
-	slog.Info("server config", "env", envconfig.AsMap())
+	slog.Info("server config", "env", envconfig.Values())
 	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
 		Level:     level,
 		AddSource: true,


@@ -0,0 +1,160 @@
package server
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
var stream bool = false
func createBinFile(t *testing.T) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
t.Fatal(err)
}
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
t.Fatal(err)
}
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
t.Fatal(err)
}
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
t.Fatal(err)
}
return f.Name()
}
type responseRecorder struct {
*httptest.ResponseRecorder
http.CloseNotifier
}
func NewRecorder() *responseRecorder {
return &responseRecorder{
ResponseRecorder: httptest.NewRecorder(),
}
}
func (t *responseRecorder) CloseNotify() <-chan bool {
return make(chan bool)
}
func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
t.Helper()
w := NewRecorder()
c, _ := gin.CreateTestContext(w)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(body); err != nil {
t.Fatal(err)
}
c.Request = &http.Request{
Body: io.NopCloser(&b),
}
fn(c)
return w.ResponseRecorder
}
func checkFileExists(t *testing.T, p string, expect []string) {
t.Helper()
actual, err := filepath.Glob(p)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(actual, expect) {
t.Fatalf("expected slices to be equal %v", actual)
}
}
func TestCreateFromBin(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
})
}
func TestCreateFromModel(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: "FROM test",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
})
}
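createBinFile in these tests writes the smallest valid GGUF v3 payload: the 4-byte magic, a little-endian uint32 version, then two little-endian uint64 counts (tensors and metadata key/values), both zero. The same 24-byte header standalone:

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	// Minimal GGUF v3 header: magic, version, tensor count, metadata kv count.
	// Writes to a bytes.Buffer cannot fail, so errors are ignored here.
	var buf bytes.Buffer
	buf.WriteString("GGUF")
	binary.Write(&buf, binary.LittleEndian, uint32(3)) // version
	binary.Write(&buf, binary.LittleEndian, uint64(0)) // tensor count
	binary.Write(&buf, binary.LittleEndian, uint64(0)) // metadata kv count
	fmt.Println(buf.Len(), "bytes") // 24 bytes
}
```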


@@ -0,0 +1,71 @@
package server
import (
"fmt"
"net/http"
"path/filepath"
"testing"
"github.com/ollama/ollama/api"
)
func TestDelete(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)),
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test2"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
}


@@ -0,0 +1,61 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"slices"
"testing"
"github.com/ollama/ollama/api"
)
func TestList(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
expectNames := []string{
"mistral:7b-instruct-q4_0",
"zephyr:7b-beta-q5_K_M",
"apple/OpenELM:latest",
"boreas:2b-code-v1.5-q6_K",
"notus:7b-v1-IQ2_S",
// TODO: host:port currently fails on windows (#4107)
// "localhost:5000/library/eurus:700b-v0.5-iq3_XXS",
"mynamespace/apeliotes:latest",
"myhost/mynamespace/lips:code",
}
var s Server
for _, n := range expectNames {
createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: n,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
})
}
w := createRequest(t, s.ListModelsHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
var resp api.ListResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if len(resp.Models) != len(expectNames) {
t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models))
}
actualNames := make([]string, len(resp.Models))
for i, m := range resp.Models {
actualNames[i] = m.Name
}
slices.Sort(actualNames)
slices.Sort(expectNames)
if !slices.Equal(actualNames, expectNames) {
t.Fatalf("expected slices to be equal %v", actualNames)
}
}


@@ -21,6 +21,28 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
func createTestFile(t *testing.T, name string) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), name)
assert.Nil(t, err)
defer f.Close()
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint32(3))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
return f.Name()
}
func Test_Routes(t *testing.T) { func Test_Routes(t *testing.T) {
type testCase struct { type testCase struct {
Name string Name string
@@ -30,28 +52,6 @@ func Test_Routes(t *testing.T) {
 		Expected func(t *testing.T, resp *http.Response)
 	}
 
-	createTestFile := func(t *testing.T, name string) string {
-		t.Helper()
-
-		f, err := os.CreateTemp(t.TempDir(), name)
-		assert.Nil(t, err)
-		defer f.Close()
-
-		err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
-		assert.Nil(t, err)
-
-		err = binary.Write(f, binary.LittleEndian, uint32(3))
-		assert.Nil(t, err)
-
-		err = binary.Write(f, binary.LittleEndian, uint64(0))
-		assert.Nil(t, err)
-
-		err = binary.Write(f, binary.LittleEndian, uint64(0))
-		assert.Nil(t, err)
-
-		return f.Name()
-	}
-
 	createTestModel := func(t *testing.T, name string) {
 		fname := createTestFile(t, "ollama-model")
@@ -209,14 +209,14 @@ func Test_Routes(t *testing.T) {
}, },
} }
t.Setenv("OLLAMA_MODELS", t.TempDir())
s := &Server{} s := &Server{}
router := s.GenerateRoutes() router := s.GenerateRoutes()
httpSrv := httptest.NewServer(router) httpSrv := httptest.NewServer(router)
t.Cleanup(httpSrv.Close) t.Cleanup(httpSrv.Close)
t.Setenv("OLLAMA_MODELS", t.TempDir())
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
u := httpSrv.URL + tc.Path u := httpSrv.URL + tc.Path
@@ -237,3 +237,82 @@ func Test_Routes(t *testing.T) {
		})
	}
}

+func TestCase(t *testing.T) {
+	t.Setenv("OLLAMA_MODELS", t.TempDir())
+
+	cases := []string{
+		"mistral",
+		"llama3:latest",
+		"library/phi3:q4_0",
+		"registry.ollama.ai/library/gemma:q5_K_M",
+		// TODO: host:port currently fails on windows (#4107)
+		// "localhost:5000/alice/bob:latest",
+	}
+
+	var s Server
+	for _, tt := range cases {
+		t.Run(tt, func(t *testing.T) {
+			w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+				Name:      tt,
+				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
+				Stream:    &stream,
+			})
+
+			if w.Code != http.StatusOK {
+				t.Fatalf("expected status 200 got %d", w.Code)
+			}
+
+			expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			t.Run("create", func(t *testing.T) {
+				w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+					Name:      strings.ToUpper(tt),
+					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
+					Stream:    &stream,
+				})
+
+				if w.Code != http.StatusBadRequest {
+					t.Fatalf("expected status 400 got %d", w.Code)
+				}
+
+				if !bytes.Equal(w.Body.Bytes(), expect) {
+					t.Fatalf("expected error %s got %s", expect, w.Body.String())
+				}
+			})
+
+			t.Run("pull", func(t *testing.T) {
+				w := createRequest(t, s.PullModelHandler, api.PullRequest{
+					Name:   strings.ToUpper(tt),
+					Stream: &stream,
+				})
+
+				if w.Code != http.StatusBadRequest {
+					t.Fatalf("expected status 400 got %d", w.Code)
+				}
+
+				if !bytes.Equal(w.Body.Bytes(), expect) {
+					t.Fatalf("expected error %s got %s", expect, w.Body.String())
+				}
+			})
+
+			t.Run("copy", func(t *testing.T) {
+				w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
+					Source:      tt,
+					Destination: strings.ToUpper(tt),
+				})
+
+				if w.Code != http.StatusBadRequest {
+					t.Fatalf("expected status 400 got %d", w.Code)
+				}
+
+				if !bytes.Equal(w.Body.Bytes(), expect) {
+					t.Fatalf("expected error %s got %s", expect, w.Body.String())
+				}
+			})
+		})
+	}
+}
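
All of these tests drive handlers through a shared createRequest helper that predates this change and is not shown here. A minimal sketch, assuming the handlers are gin.HandlerFunc values: encode the payload as JSON, invoke the handler on a test context, and return the recorder so callers can assert on Code and Body.

// A minimal sketch only. Requires "bytes", "encoding/json", "io",
// "net/http", "net/http/httptest", and "github.com/gin-gonic/gin".
func createRequest(t *testing.T, fn gin.HandlerFunc, body any) *httptest.ResponseRecorder {
	t.Helper()

	// Encode the request payload as JSON, as the API handlers expect.
	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(body); err != nil {
		t.Fatal(err)
	}

	// Run the handler against a recorder so tests can inspect Code and Body.
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = &http.Request{Body: io.NopCloser(&b)}

	fn(c)
	return w
}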

View File

@@ -16,7 +16,7 @@ import (
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/gpu"
	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/server/envconfig"
+	"github.com/ollama/ollama/envconfig"
	"golang.org/x/exp/slices"
)

View File

@@ -15,7 +15,7 @@ import (
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/gpu"
	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/server/envconfig"
+	"github.com/ollama/ollama/envconfig"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

View File

@@ -251,6 +251,10 @@ func (n Name) DisplayShortest() string {
	return sb.String()
}

+func IsValidNamespace(namespace string) bool {
+	return isValidPart(kindNamespace, namespace)
+}
+
// IsValid reports whether all parts of the name are present and valid. The
// digest is a special case, and is checked for validity only if present.
func (n Name) IsValid() bool {

View File

@@ -385,3 +385,30 @@ func FuzzName(f *testing.F) {
	})
}

+func TestIsValidNamespace(t *testing.T) {
+	cases := []struct {
+		username string
+		expected bool
+	}{
+		{"", false},
+		{"a", true},
+		{"a:b", false},
+		{"a/b", false},
+		{"a:b/c", false},
+		{"a/b:c", false},
+		{"a/b:c/d", false},
+		{"a/b:c/d@e", false},
+		{"a/b:c/d@sha256-100", false},
+		{"himynameisjoe", true},
+		{"himynameisreallyreallyreallyreallylongbutitshouldstillbevalid", true},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.username, func(t *testing.T) {
+			if got := IsValidNamespace(tt.username); got != tt.expected {
+				t.Errorf("IsValidNamespace(%q) = %v; want %v", tt.username, got, tt.expected)
+			}
+		})
+	}
+}
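
isValidPart and kindNamespace are unexported, so the exact validation rules are not visible in this diff. As a behavioral sketch only, a function consistent with the table above treats a namespace as a single nonempty path part, i.e. one containing no name separators:

// Not the real isValidPart, which likely also restricts the allowed
// character set and caps the length; this only reproduces the behavior
// the table above exercises. Requires "strings".
func isValidNamespaceSketch(s string) bool {
	return s != "" && !strings.ContainsAny(s, ":/@")
}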