Mirror of https://github.com/likelovewant/ollama-for-amd.git
Synced 2025-12-22 14:53:56 +00:00

Compare commits: v0.1.38-al ... v0.1.43-al (128 commits)
Commits in this range (SHA1):

ff50cfb582, c69bc19e46, bba5d177aa, c16f8af911, edaec3183a, 217f60c3d9,
7bdcd1da94, ead259d877, 2ff45d571d, 0f3cf1d42e, 5bc029c529, e9a9c6a8e8,
515f497e6d, b27268aaef, f5f245cc15, 94d37fdcae, b84aea1685, 896495de7b,
5528dd9d11, 943172cbf4, 1b5848cbf2, 76026b4a35, 85169e8d6f, 34f142797a,
46a7f1e74a, 620d5c569e, b9ce7bf75e, cddc63381c, 385a32ecb5, 030e765e76,
ab8c929e20, 27e7397b11, a6390a8992, ce0dc33cb8, 78f81fc0e5, 9b6c2e6eb6,
1a29e9a879, 4bf1da4944, de5beb06b3, 98e65929dc, 66ab48772f, 22fcf8f7de,
28c7813ac4, 1d8616d30f, d61ef8b954, 89d9900152, 4a048715b6, 6297f85606,
ed56428dd7, ad40b92b6a, 8ce4032e72, 42660466f8, e919f6811f, bf7edb0d5d,
f38353d6b9, 201d853fdf, e40145a39d, c895a7d13f, dad7a987ae, 8ffb51749f,
55f6eba049, 04f3c12bb7, 60323e0805, 71ae05239e, a4a435bf8f, 2490a69f7b,
d4a86102fd, 476fb8e892, 829ff87bd1, f6b622c4b3, 2e4da8eec2, 16ce79eb3b,
763bb65dbb, 7ca9605f54, eb2c443a79, 278e25ea44, a50a87a7b8, 98085015d5,
bf54c845e9, c365f195a8, e91d0ef737, 22f5c12ced, 298c996e54, 0fc0cfc6d2,
914f68f021, bd1d119ba9, a03be18189, 96bc232b43, bca7b12284, 32cb1960c1,
de781b37c8, 3e21799377, 26a00a0410, cafde1f8ce, 646371f56d, 1f5008544b,
45cbfc5aee, 6d423b383b, ad897080a2, b7d316d98d, d7339fad52, 92c81e8117,
9db0996ed4, 6f43898b17, 7487229c34, 8a8e7afa96, c79f8c9c39, 485016bfbb,
2a80d6f743, 0165ba1651, c4209d6d21, 6adca97f37, 9a3c8003c8, d51f15257c,
8f440d579a, 4cc3be3035, db2ffa79f1, 73c49d57e8, 6b50b2f3bf, fd5971be0b,
f77713bf1f, 85a57006d1, c5e892cb3e, 81fb06f530, a385382ff5, b8772a353f,
c2714fcbfd, a2fc933fed
.github/workflows/test.yaml (14 changes, vendored)

```diff
@@ -34,13 +34,13 @@ jobs:
       git diff-tree -r --no-commit-id --name-only \
         $(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
         ${{ github.event.pull_request.head.sha }} \
-        | xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
+        | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
     }

     {
-      echo GENERATE=$(changed llm/)
-      echo GENERATE_CUDA=$(changed llm/)
-      echo GENERATE_ROCM=$(changed llm/)
+      echo GENERATE=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
+      echo GENERATE_CUDA=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
+      echo GENERATE_ROCM=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
     } >>$GITHUB_OUTPUT

   generate:
@@ -269,9 +269,9 @@ jobs:
           mkdir -p llm/build/darwin/$ARCH/stub/bin
           touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'macos-') }}
-      - uses: golangci/golangci-lint-action@v4
+      - uses: golangci/golangci-lint-action@v6
        with:
-          args: --timeout 8m0s -v
+          args: --timeout 8m0s -v ${{ startsWith(matrix.os, 'windows-') && '' || '--disable gofmt --disable goimports' }}
   test:
     strategy:
       matrix:
@@ -287,6 +287,8 @@ jobs:
       GOARCH: ${{ matrix.arch }}
       CGO_ENABLED: '1'
       OLLAMA_CPU_TARGET: 'static'
+      OLLAMA_SKIP_CPU_GENERATE: '1'
+      OLLAMA_SKIP_METAL_GENERATE: '1'
     steps:
       - uses: actions/checkout@v4
         with:
```
(golangci-lint configuration)

```diff
@@ -9,9 +9,26 @@ linters:
     - contextcheck
     - exportloopref
     - gocheckcompilerdirectives
-    # FIXME: for some reason this errors on windows
+    # conditionally enable this on linux/macos
     # - gofmt
     # - goimports
+    - intrange
     - misspell
     - nilerr
+    - nolintlint
+    - nosprintfhostport
+    - testifylint
+    - unconvert
     - unused
+    - wastedassign
+    - whitespace
+    - usestdlibvars
+severity:
+  default-severity: error
+  rules:
+    - linters:
+        - gofmt
+        - goimports
+        - intrange
+        - usestdlibvars
+      severity: info
```
README.md (15 changes)

````diff
@@ -6,7 +6,7 @@
 
 [](https://discord.gg/ollama)
 
-Get up and running with large language models locally.
+Get up and running with large language models.
 
 ### macOS
 
@@ -30,7 +30,7 @@ Example extra list add on this repo.
 ```
 Please follow the [wiki](https://github.com/likelovewant/ollama-for-amd/wiki) guide to build or use the pre-release version.
 
-Note: `gfx803, gfx1010` reported not working by the wiki method ,expected a future support
+Note: `gfx803` reported partialy working by the wiki method ,expected a future support
 
 
@@ -301,6 +301,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
 - [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
 - [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
+- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
+- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
+- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
 
 ### Terminal
 
@@ -323,6 +326,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [ShellOracle](https://github.com/djcopley/ShellOracle)
 - [tlm](https://github.com/yusufcanb/tlm)
 - [podman-ollama](https://github.com/ericcurtin/podman-ollama)
+- [gollama](https://github.com/sammcj/gollama)
 
 ### Database
 
@@ -340,11 +344,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa)
 - [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example)
 - [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java)
+- [LangChainRust](https://github.com/Abraxas-365/langchain-rust) with [example](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
 - [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html)
 - [LiteLLM](https://github.com/BerriAI/litellm)
 - [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
 - [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
 - [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
+- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
 - [Ollama4j for Java](https://github.com/amithkoujalgi/ollama4j)
 - [ModelFusion Typescript Library](https://modelfusion.dev/integration/model-provider/ollama)
 - [OllamaKit for Swift](https://github.com/kevinhermawan/OllamaKit)
@@ -362,6 +368,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
 - [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) with an [example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama)
 - [LlamaScript](https://github.com/Project-Llama/llamascript)
 
 ### Mobile
 
 - [Enchanted](https://github.com/AugustDev/enchanted)
@@ -394,7 +401,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
 - [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
 - [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
+- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
 
+### Supported backends
 
-### Supported backends
 - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
 
````
api/client.go

```diff
@@ -23,11 +23,9 @@ import (
 	"net"
 	"net/http"
 	"net/url"
-	"os"
 	"runtime"
-	"strconv"
-	"strings"
 
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/version"
 )
@@ -65,10 +63,7 @@ func checkError(resp *http.Response, body []byte) error {
 // If the variable is not specified, a default ollama host and port will be
 // used.
 func ClientFromEnvironment() (*Client, error) {
-	ollamaHost, err := GetOllamaHost()
-	if err != nil {
-		return nil, err
-	}
+	ollamaHost := envconfig.Host
 
 	return &Client{
 		base: &url.URL{
@@ -79,52 +74,6 @@ func ClientFromEnvironment() (*Client, error) {
 	}, nil
 }
 
-type OllamaHost struct {
-	Scheme string
-	Host   string
-	Port   string
-}
-
-func GetOllamaHost() (OllamaHost, error) {
-	defaultPort := "11434"
-
-	hostVar := os.Getenv("OLLAMA_HOST")
-	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
-
-	scheme, hostport, ok := strings.Cut(hostVar, "://")
-	switch {
-	case !ok:
-		scheme, hostport = "http", hostVar
-	case scheme == "http":
-		defaultPort = "80"
-	case scheme == "https":
-		defaultPort = "443"
-	}
-
-	// trim trailing slashes
-	hostport = strings.TrimRight(hostport, "/")
-
-	host, port, err := net.SplitHostPort(hostport)
-	if err != nil {
-		host, port = "127.0.0.1", defaultPort
-		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
-			host = ip.String()
-		} else if hostport != "" {
-			host = hostport
-		}
-	}
-
-	if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
-		return OllamaHost{}, ErrInvalidHostPort
-	}
-
-	return OllamaHost{
-		Scheme: scheme,
-		Host:   host,
-		Port:   port,
-	}, nil
-}
-
 func NewClient(base *url.URL, http *http.Client) *Client {
 	return &Client{
 		base: base,
@@ -355,8 +304,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 }
 
 // List running models.
-func (c *Client) ListRunning(ctx context.Context) (*ListResponse, error) {
-	var lr ListResponse
+func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
+	var lr ProcessResponse
 	if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
 		return nil, err
 	}
```
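For orientation, a minimal sketch (not part of the diff) of what a caller sees after this change: `ClientFromEnvironment` still honors `OLLAMA_HOST`, but parsing now lives in the `envconfig` package instead of `GetOllamaHost`. The `Heartbeat` call is the same one used by `waitForServer` further down in this compare.

```go
// Sketch, assuming the post-change github.com/ollama/ollama module layout.
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment() // honors OLLAMA_HOST via envconfig
	if err != nil {
		log.Fatal(err)
	}
	// Heartbeat pings the server, as waitForServer in cmd/start.go does.
	if err := client.Heartbeat(context.Background()); err != nil {
		log.Fatal(err)
	}
	fmt.Println("ollama server is reachable")
}
```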
api/client_test.go

```diff
@@ -1,11 +1,9 @@
 package api
 
 import (
-	"fmt"
-	"net"
 	"testing"
 
-	"github.com/stretchr/testify/assert"
+	"github.com/ollama/ollama/envconfig"
 )
 
 func TestClientFromEnvironment(t *testing.T) {
@@ -35,6 +33,7 @@ func TestClientFromEnvironment(t *testing.T) {
 	for k, v := range testCases {
 		t.Run(k, func(t *testing.T) {
 			t.Setenv("OLLAMA_HOST", v.value)
+			envconfig.LoadConfig()
 
 			client, err := ClientFromEnvironment()
 			if err != v.err {
@@ -46,40 +45,4 @@ func TestClientFromEnvironment(t *testing.T) {
 			}
 		})
 	}
-
-	hostTestCases := map[string]*testCase{
-		"empty":               {value: "", expect: "127.0.0.1:11434"},
-		"only address":        {value: "1.2.3.4", expect: "1.2.3.4:11434"},
-		"only port":           {value: ":1234", expect: ":1234"},
-		"address and port":    {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
-		"hostname":            {value: "example.com", expect: "example.com:11434"},
-		"hostname and port":   {value: "example.com:1234", expect: "example.com:1234"},
-		"zero port":           {value: ":0", expect: ":0"},
-		"too large port":      {value: ":66000", err: ErrInvalidHostPort},
-		"too small port":      {value: ":-1", err: ErrInvalidHostPort},
-		"ipv6 localhost":      {value: "[::1]", expect: "[::1]:11434"},
-		"ipv6 world open":     {value: "[::]", expect: "[::]:11434"},
-		"ipv6 no brackets":    {value: "::1", expect: "[::1]:11434"},
-		"ipv6 + port":         {value: "[::1]:1337", expect: "[::1]:1337"},
-		"extra space":         {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
-		"extra quotes":        {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
-		"extra space+quotes":  {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
-		"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
-	}
-
-	for k, v := range hostTestCases {
-		t.Run(k, func(t *testing.T) {
-			t.Setenv("OLLAMA_HOST", v.value)
-
-			oh, err := GetOllamaHost()
-			if err != v.err {
-				t.Fatalf("expected %s, got %s", v.err, err)
-			}
-
-			if err == nil {
-				host := net.JoinHostPort(oh.Host, oh.Port)
-				assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
-			}
-		})
-	}
 }
```
api/types.go (31 changes)

```diff
@@ -2,7 +2,6 @@ package api
 
 import (
 	"encoding/json"
-	"errors"
 	"fmt"
 	"log/slog"
 	"math"
@@ -282,19 +281,33 @@ type PushRequest struct {
 
 // ListResponse is the response from [Client.List].
 type ListResponse struct {
-	Models []ModelResponse `json:"models"`
+	Models []ListModelResponse `json:"models"`
 }
 
-// ModelResponse is a single model description in [ListResponse].
-type ModelResponse struct {
+// ProcessResponse is the response from [Client.Process].
+type ProcessResponse struct {
+	Models []ProcessModelResponse `json:"models"`
+}
+
+// ListModelResponse is a single model description in [ListResponse].
+type ListModelResponse struct {
 	Name       string       `json:"name"`
 	Model      string       `json:"model"`
-	ModifiedAt time.Time    `json:"modified_at,omitempty"`
+	ModifiedAt time.Time    `json:"modified_at"`
 	Size       int64        `json:"size"`
 	Digest     string       `json:"digest"`
 	Details    ModelDetails `json:"details,omitempty"`
-	ExpiresAt  time.Time    `json:"expires_at,omitempty"`
-	SizeVRAM   int64        `json:"size_vram,omitempty"`
+}
+
+// ProcessModelResponse is a single model description in [ProcessResponse].
+type ProcessModelResponse struct {
+	Name      string       `json:"name"`
+	Model     string       `json:"model"`
+	Size      int64        `json:"size"`
+	Digest    string       `json:"digest"`
+	Details   ModelDetails `json:"details,omitempty"`
+	ExpiresAt time.Time    `json:"expires_at"`
+	SizeVRAM  int64        `json:"size_vram"`
 }
 
 type TokenResponse struct {
@@ -306,7 +319,7 @@ type GenerateResponse struct {
 	// Model is the model name that generated the response.
 	Model string `json:"model"`
 
-	//CreatedAt is the timestamp of the response.
+	// CreatedAt is the timestamp of the response.
 	CreatedAt time.Time `json:"created_at"`
 
 	// Response is the textual response itself.
@@ -363,8 +376,6 @@ func (m *Metrics) Summary() {
 	}
 }
 
-var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
-
 func (opts *Options) FromMap(m map[string]interface{}) error {
 	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct
```
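A hedged illustration of the split introduced above: `/api/ps` payloads now decode into `ProcessResponse` and `ProcessModelResponse` rather than the list types. Field names follow the struct tags in the diff; the payload values are sample data.

```go
// Sketch, assuming the post-change github.com/ollama/ollama/api package.
package main

import (
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Sample /api/ps-style payload (values are illustrative).
	payload := []byte(`{"models":[{"name":"mistral:latest","model":"mistral:latest","size":5137025024,"digest":"2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8","expires_at":"2024-06-04T14:38:31.83753-07:00","size_vram":5137025024}]}`)

	var resp api.ProcessResponse
	if err := json.Unmarshal(payload, &resp); err != nil {
		log.Fatal(err)
	}
	for _, m := range resp.Models {
		fmt.Printf("%s: %d bytes in VRAM, expires %s\n", m.Name, m.SizeVRAM, m.ExpiresAt)
	}
}
```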
api/types_test.go

```diff
@@ -72,13 +72,13 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
 		},
 		{
 			"positive duration",
-			time.Duration(42 * time.Second),
-			time.Duration(42 * time.Second),
+			42 * time.Second,
+			42 * time.Second,
 		},
 		{
 			"another positive duration",
-			time.Duration(42 * time.Minute),
-			time.Duration(42 * time.Minute),
+			42 * time.Minute,
+			42 * time.Minute,
 		},
 		{
 			"zero duration",
```
app/lifecycle/logging.go

```diff
@@ -6,7 +6,7 @@ import (
 	"os"
 	"path/filepath"
 
-	"github.com/ollama/ollama/server/envconfig"
+	"github.com/ollama/ollama/envconfig"
 )
 
 func InitLogging() {
```
app/lifecycle/paths.go

```diff
@@ -69,7 +69,6 @@ func init() {
 				slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err))
 			}
 		}
-
 	} else if runtime.GOOS == "darwin" {
 		// TODO
 		AppName += ".app"
```
app/lifecycle/server.go

```diff
@@ -15,7 +15,7 @@ import (
 )
 
 func getCLIFullPath(command string) string {
-	cmdPath := ""
+	var cmdPath string
 	appExe, err := os.Executable()
 	if err == nil {
 		cmdPath = filepath.Join(filepath.Dir(appExe), command)
@@ -65,7 +65,6 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) {
 	if err != nil {
 		if !errors.Is(err, os.ErrNotExist) {
 			return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
-
 		}
 
 		if err := os.MkdirAll(logDir, 0o755); err != nil {
```
app/lifecycle/server_windows.go

```diff
@@ -24,7 +24,8 @@ func terminate(cmd *exec.Cmd) error {
 	if err != nil {
 		return err
 	}
-	defer dll.Release() // nolint: errcheck
+	//nolint:errcheck
+	defer dll.Release()
 
 	pid := cmd.Process.Pid
 
@@ -73,7 +74,8 @@ func isProcessExited(pid int) (bool, error) {
 	if err != nil {
 		return false, fmt.Errorf("failed to open process: %v", err)
 	}
-	defer windows.CloseHandle(hProcess) // nolint: errcheck
+	//nolint:errcheck
+	defer windows.CloseHandle(hProcess)
 
 	var exitCode uint32
 	err = windows.GetExitCodeProcess(hProcess, &exitCode)
```
app/lifecycle/updater.go

```diff
@@ -78,7 +78,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
 	}
 	defer resp.Body.Close()
 
-	if resp.StatusCode == 204 {
+	if resp.StatusCode == http.StatusNoContent {
 		slog.Debug("check update response 204 (current version is up to date)")
 		return false, updateResp
 	}
@@ -87,7 +87,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
 		slog.Warn(fmt.Sprintf("failed to read body response: %s", err))
 	}
 
-	if resp.StatusCode != 200 {
+	if resp.StatusCode != http.StatusOK {
 		slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body)))
 		return false, updateResp
 	}
@@ -114,7 +114,7 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
 	if err != nil {
 		return fmt.Errorf("error checking update: %w", err)
 	}
-	if resp.StatusCode != 200 {
+	if resp.StatusCode != http.StatusOK {
 		return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
 	}
 	resp.Body.Close()
```
app/ollama_welcome.ps1

```diff
@@ -4,5 +4,5 @@ write-host "Welcome to Ollama!"
 write-host ""
 write-host "Run your first model:"
 write-host ""
-write-host "`tollama run llama2"
+write-host "`tollama run llama3"
 write-host ""
```
app/store/store.go

```diff
@@ -29,7 +29,6 @@ func GetID() string {
 		initStore()
 	}
 	return store.ID
-
 }
 
 func GetFirstTimeRun() bool {
```
app/tray/wintray (Windows system tray)

```diff
@@ -47,7 +47,6 @@ func nativeLoop() {
 		default:
 			pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
 			pDispatchMessage.Call(uintptr(unsafe.Pointer(m)))  //nolint:errcheck
-
 		}
 	}
 }
@@ -160,8 +159,8 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
 		lResult, _, _ = pDefWindowProc.Call(
 			uintptr(hWnd),
 			uintptr(message),
-			uintptr(wParam),
-			uintptr(lParam),
+			wParam,
+			lParam,
 		)
 	}
 	return
@@ -186,7 +186,7 @@ func (t *winTray) initInstance() error {
 	t.muNID.Lock()
 	defer t.muNID.Unlock()
 	t.nid = &notifyIconData{
-		Wnd:             windows.Handle(t.window),
+		Wnd:             t.window,
 		ID:              100,
 		Flags:           NIF_MESSAGE,
 		CallbackMessage: t.wmSystrayMessage,
@@ -197,7 +197,6 @@ func (t *winTray) initInstance() error {
 }
 
 func (t *winTray) createMenu() error {
-
 	menuHandle, _, err := pCreatePopupMenu.Call()
 	if menuHandle == 0 {
 		return err
@@ -246,7 +245,7 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
 	mi := menuItemInfo{
 		Mask:     MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE,
 		Type:     MFT_STRING,
-		ID:       uint32(menuItemId),
+		ID:       menuItemId,
 		TypeData: titlePtr,
 		Cch:      uint32(len(title)),
 	}
@@ -302,11 +301,10 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
 }
 
 func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
-
 	mi := menuItemInfo{
 		Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE,
 		Type: MFT_SEPARATOR,
-		ID:   uint32(menuItemId),
+		ID:   menuItemId,
 	}
 
 	mi.Size = uint32(unsafe.Sizeof(mi))
@@ -426,7 +424,6 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) {
 // Loads an image from file and shows it in tray.
 // Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx
 func (t *winTray) setIcon(src string) error {
-
 	h, err := t.loadIconFrom(src)
 	if err != nil {
 		return err
@@ -444,7 +441,6 @@ func (t *winTray) setIcon(src string) error {
 // Loads an image from file to be shown in tray or menu item.
 // LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx
 func (t *winTray) loadIconFrom(src string) (windows.Handle, error) {
-
 	// Save and reuse handles of loaded images
 	t.muLoadedImages.RLock()
 	h, ok := t.loadedImages[src]
```
cmd/cmd.go (79 changes)

```diff
@@ -20,6 +20,7 @@ import (
 	"path/filepath"
 	"regexp"
 	"runtime"
+	"slices"
 	"strings"
 	"syscall"
 	"time"
@@ -29,11 +30,11 @@ import (
 	"github.com/olekukonko/tablewriter"
 	"github.com/spf13/cobra"
 	"golang.org/x/crypto/ssh"
-	"golang.org/x/exp/slices"
 	"golang.org/x/term"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
@@ -745,7 +746,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
 	if wordWrap && termWidth >= 10 {
 		for _, ch := range content {
 			if state.lineLength+1 > termWidth-5 {
-
 				if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
 					fmt.Printf("%s%c", state.wordBuffer, ch)
 					state.wordBuffer = ""
@@ -754,7 +754,11 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
 				}
 
 				// backtrack the length of the last word and clear to the end of the line
-				fmt.Printf("\x1b[%dD\x1b[K\n", runewidth.StringWidth(state.wordBuffer))
+				a := runewidth.StringWidth(state.wordBuffer)
+				if a > 0 {
+					fmt.Printf("\x1b[%dD", a)
+				}
+				fmt.Printf("\x1b[K\n")
 				fmt.Printf("%s%c", state.wordBuffer, ch)
 				chWidth := runewidth.RuneWidth(ch)
 
@@ -956,17 +960,11 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 }
 
 func RunServer(cmd *cobra.Command, _ []string) error {
-	// retrieve the OLLAMA_HOST environment variable
-	ollamaHost, err := api.GetOllamaHost()
-	if err != nil {
-		return err
-	}
-
 	if err := initializeKeypair(); err != nil {
 		return err
 	}
 
-	ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
+	ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port))
 	if err != nil {
 		return err
 	}
@@ -1025,24 +1023,6 @@ func initializeKeypair() error {
 	return nil
 }
 
-//nolint:unused
-func waitForServer(ctx context.Context, client *api.Client) error {
-	// wait for the server to start
-	timeout := time.After(5 * time.Second)
-	tick := time.Tick(500 * time.Millisecond)
-	for {
-		select {
-		case <-timeout:
-			return errors.New("timed out waiting for server to start")
-		case <-tick:
-			if err := client.Heartbeat(ctx); err == nil {
-				return nil // server has started
-			}
-		}
-	}
-}
-
 func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
@@ -1079,12 +1059,7 @@ func versionHandler(cmd *cobra.Command, _ []string) {
 	}
 }
 
-type EnvironmentVar struct {
-	Name        string
-	Description string
-}
-
-func appendEnvDocs(cmd *cobra.Command, envs []EnvironmentVar) {
+func appendEnvDocs(cmd *cobra.Command, envs []envconfig.EnvVar) {
 	if len(envs) == 0 {
 		return
 	}
@@ -1093,7 +1068,7 @@ func appendEnvDocs(cmd *cobra.Command, envs []EnvironmentVar) {
 Environment Variables:
 `
 	for _, e := range envs {
-		envUsage += fmt.Sprintf("  %-16s %s\n", e.Name, e.Description)
+		envUsage += fmt.Sprintf("  %-24s %s\n", e.Name, e.Description)
 	}
 
 	cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
@@ -1172,15 +1147,6 @@ func NewCLI() *cobra.Command {
 		Args: cobra.ExactArgs(0),
 		RunE: RunServer,
 	}
-	serveCmd.SetUsageTemplate(serveCmd.UsageTemplate() + `
-Environment Variables:
-
-      OLLAMA_HOST        The host:port to bind to (default "127.0.0.1:11434")
-      OLLAMA_ORIGINS     A comma separated list of allowed origins
-      OLLAMA_MODELS      The path to the models directory (default "~/.ollama/models")
-      OLLAMA_KEEP_ALIVE  The duration that models stay loaded in memory (default "5m")
-      OLLAMA_DEBUG       Set to 1 to enable additional debug logging
-`)
 
 	pullCmd := &cobra.Command{
 		Use:   "pull MODEL",
@@ -1233,9 +1199,9 @@ Environment Variables:
 		RunE:    DeleteHandler,
 	}
 
-	ollamaHostEnv := EnvironmentVar{"OLLAMA_HOST", "The host:port or base URL of the Ollama server (e.g. http://localhost:11434)"}
-	ollamaNoHistoryEnv := EnvironmentVar{"OLLAMA_NOHISTORY", "Disable readline history"}
-	envs := []EnvironmentVar{ollamaHostEnv}
+	envVars := envconfig.AsMap()
+
+	envs := []envconfig.EnvVar{envVars["OLLAMA_HOST"]}
 
 	for _, cmd := range []*cobra.Command{
 		createCmd,
@@ -1247,10 +1213,27 @@ Environment Variables:
 		psCmd,
 		copyCmd,
 		deleteCmd,
+		serveCmd,
 	} {
 		switch cmd {
 		case runCmd:
-			appendEnvDocs(cmd, []EnvironmentVar{ollamaHostEnv, ollamaNoHistoryEnv})
+			appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
+		case serveCmd:
+			appendEnvDocs(cmd, []envconfig.EnvVar{
+				envVars["OLLAMA_DEBUG"],
+				envVars["OLLAMA_HOST"],
+				envVars["OLLAMA_KEEP_ALIVE"],
+				envVars["OLLAMA_MAX_LOADED_MODELS"],
+				envVars["OLLAMA_MAX_QUEUE"],
+				envVars["OLLAMA_MODELS"],
+				envVars["OLLAMA_NUM_PARALLEL"],
+				envVars["OLLAMA_NOPRUNE"],
+				envVars["OLLAMA_ORIGINS"],
+				envVars["OLLAMA_TMPDIR"],
+				envVars["OLLAMA_FLASH_ATTENTION"],
+				envVars["OLLAMA_LLM_LIBRARY"],
+				envVars["OLLAMA_MAX_VRAM"],
+			})
 		default:
 			appendEnvDocs(cmd, envs)
 		}
```
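As used in the diff, `envconfig.AsMap()` supplies the documented variables, and each `EnvVar` exposes at least `Name` and `Description` (both are accessed above). A small standalone sketch of the usage-text pattern `appendEnvDocs` now follows; the variable selection here is illustrative only.

```go
// Sketch, assuming the post-change github.com/ollama/ollama/envconfig package.
package main

import (
	"fmt"

	"github.com/ollama/ollama/envconfig"
)

func main() {
	envVars := envconfig.AsMap()
	// Illustrative subset; NewCLI wires up the full per-command lists.
	for _, name := range []string{"OLLAMA_HOST", "OLLAMA_MODELS", "OLLAMA_KEEP_ALIVE"} {
		e := envVars[name]
		// %-24s is the widened column from this change.
		fmt.Printf("  %-24s %s\n", e.Name, e.Description)
	}
}
```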
cmd/interactive.go

```diff
@@ -8,13 +8,14 @@ import (
 	"os"
 	"path/filepath"
 	"regexp"
+	"slices"
 	"sort"
 	"strings"
 
 	"github.com/spf13/cobra"
-	"golang.org/x/exp/slices"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/types/errtypes"
@@ -183,7 +184,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		return err
 	}
 
-	if os.Getenv("OLLAMA_NOHISTORY") != "" {
+	if envconfig.NoHistory {
 		scanner.HistoryDisable()
 	}
 
```
cmd/interactive_test.go

```diff
@@ -6,6 +6,7 @@ import (
 	"text/template"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"github.com/ollama/ollama/api"
 )
@@ -85,11 +86,11 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
 `
 
 	tmpl, err := template.New("").Parse(expectedModelfile)
-	assert.Nil(t, err)
+	require.NoError(t, err)
 
 	var buf bytes.Buffer
 	err = tmpl.Execute(&buf, opts)
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, buf.String(), mf)
 
 	opts.ParentModel = "horseshark"
@@ -107,10 +108,10 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
 `
 
 	tmpl, err = template.New("").Parse(expectedModelfile)
-	assert.Nil(t, err)
+	require.NoError(t, err)
 
 	var parentBuf bytes.Buffer
 	err = tmpl.Execute(&parentBuf, opts)
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, parentBuf.String(), mf)
 }
```
cmd/start.go (new file, 27 lines)

```diff
@@ -0,0 +1,27 @@
+//go:build darwin || windows
+
+package cmd
+
+import (
+	"context"
+	"errors"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+func waitForServer(ctx context.Context, client *api.Client) error {
+	// wait for the server to start
+	timeout := time.After(5 * time.Second)
+	tick := time.Tick(500 * time.Millisecond)
+	for {
+		select {
+		case <-timeout:
+			return errors.New("timed out waiting for server to start")
+		case <-tick:
+			if err := client.Heartbeat(ctx); err == nil {
+				return nil // server has started
+			}
+		}
+	}
+}
```
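A hypothetical call site for the new helper, assuming an `ensureServer` wrapper that is not part of the diff:

```go
//go:build darwin || windows

package cmd

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/api"
)

// ensureServer is a hypothetical wrapper (not in the diff) showing how
// waitForServer above would typically be called after launching the server.
func ensureServer(ctx context.Context) error {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}
	if err := waitForServer(ctx, client); err != nil {
		return fmt.Errorf("ollama server did not start: %w", err)
	}
	return nil
}
```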
convert/convert.go

```diff
@@ -189,7 +189,7 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
 	if params.VocabSize > len(v.Tokens) {
 		missingTokens := params.VocabSize - len(v.Tokens)
 		slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
-		for cnt := 0; cnt < missingTokens; cnt++ {
+		for cnt := range missingTokens {
 			v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
 			v.Scores = append(v.Scores, -1)
 			v.Types = append(v.Types, tokenTypeUserDefined)
```
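The rewritten loop uses Go 1.22's range-over-integer form, which the newly enabled `intrange` linter favors. A standalone illustration (assumes Go 1.22+):

```go
package main

import "fmt"

func main() {
	missingTokens := 3
	// Go 1.22+: ranging over an int yields 0, 1, ..., n-1,
	// replacing the classic for cnt := 0; cnt < n; cnt++ form.
	for cnt := range missingTokens {
		fmt.Printf("<dummy%05d>\n", cnt+1)
	}
}
```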
convert/gemma.go

```diff
@@ -35,7 +35,6 @@ func addOnes(data []float32, vectorSize int) ([]float32, error) {
 		f32s = append(f32s, t...)
 	}
 
-
 	return f32s, nil
 }
 
```
convert/llama.go

```diff
@@ -119,11 +119,12 @@ func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([
 	}
 
 	var heads int
-	if strings.HasSuffix(name, "attn_q.weight") {
+	switch {
+	case strings.HasSuffix(name, "attn_q.weight"):
 		heads = params.AttentionHeads
-	} else if strings.HasSuffix(name, "attn_k.weight") {
+	case strings.HasSuffix(name, "attn_k.weight"):
 		heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
-	} else {
+	default:
 		return nil, fmt.Errorf("unknown tensor name: %s", name)
 	}
 
```
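`cmp.Or` returns its first non-zero argument, so `cmp.Or(params.KeyValHeads, params.AttentionHeads)` falls back to the attention-head count when `KeyValHeads` is zero. A standalone illustration (assumes Go 1.22+):

```go
package main

import (
	"cmp"
	"fmt"
)

func main() {
	// A zero KeyValHeads falls back to AttentionHeads, as in llamaRepack above.
	keyValHeads, attentionHeads := 0, 32
	heads := cmp.Or(keyValHeads, attentionHeads)
	fmt.Println(heads) // 32
}
```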
convert/safetensors.go

```diff
@@ -120,7 +120,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 			Name:   name,
 			Kind:   kind,
 			Offset: offset,
-			Shape:  shape[:],
+			Shape:  shape,
 		}
 
 		t.WriterTo = safetensorWriterTo{
```
convert/tokenizer.go

```diff
@@ -85,11 +85,8 @@ func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, e
 
 	sha256sum := sha256.New()
 	for _, pt := range t.PreTokenizer.PreTokenizers {
-		switch pt.Type {
-		case "Split":
-			if pt.Pattern.Regex != "" {
-				sha256sum.Write([]byte(pt.Pattern.Regex))
-			}
+		if pt.Type == "Split" && pt.Pattern.Regex != "" {
+			sha256sum.Write([]byte(pt.Pattern.Regex))
 		}
 	}
 
```
convert/torch.go

```diff
@@ -88,7 +88,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
 			Name:   ggufName,
 			Kind:   kind,
 			Offset: offset, // calculate the offset
-			Shape:  shape[:],
+			Shape:  shape,
 		}
 
 		tensor.WriterTo = torchWriterTo{
@@ -104,7 +104,6 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
 	}
 
 	return tensors, nil
-
 }
 
 func getAltParams(dirpath string) (*Params, error) {
```
docs/api.md (50 changes)

````diff
@@ -12,6 +12,7 @@
 - [Pull a Model](#pull-a-model)
 - [Push a Model](#push-a-model)
 - [Generate Embeddings](#generate-embeddings)
+- [List Running Models](#list-running-models)
 
 ## Conventions
 
@@ -249,7 +250,7 @@ curl http://localhost:11434/api/generate -d '{
 
 #### Request (Reproducible outputs)
 
-For reproducible outputs, set `temperature` to 0 and `seed` to a number:
+For reproducible outputs, set `seed` to a number:
 
 ##### Request
 
@@ -258,8 +259,7 @@ curl http://localhost:11434/api/generate -d '{
   "model": "mistral",
   "prompt": "Why is the sky blue?",
   "options": {
-    "seed": 123,
-    "temperature": 0
+    "seed": 123
   }
 }'
 ```
@@ -1035,3 +1035,47 @@ curl http://localhost:11434/api/embeddings -d '{
   ]
 }
 ```
+
+## List Running Models
+
+```shell
+GET /api/ps
+```
+
+List models that are currently loaded into memory.
+
+#### Examples
+
+### Request
+
+```shell
+curl http://localhost:11434/api/ps
+```
+
+#### Response
+
+A single JSON object will be returned.
+
+```json
+{
+  "models": [
+    {
+      "name": "mistral:latest",
+      "model": "mistral:latest",
+      "size": 5137025024,
+      "digest": "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8",
+      "details": {
+        "parent_model": "",
+        "format": "gguf",
+        "family": "llama",
+        "families": [
+          "llama"
+        ],
+        "parameter_size": "7.2B",
+        "quantization_level": "Q4_0"
+      },
+      "expires_at": "2024-06-04T14:38:31.83753-07:00",
+      "size_vram": 5137025024
+    }
+  ]
+}
+```
````
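For completeness, a sketch of the Go-client equivalent of `curl http://localhost:11434/api/ps`, using the `ListRunning` signature changed earlier in this compare (sample field access only):

```go
// Sketch, assuming the post-change github.com/ollama/ollama/api package.
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	running, err := client.ListRunning(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	for _, m := range running.Models {
		fmt.Printf("%s (%s), expires %s\n", m.Name, m.Details.QuantizationLevel, m.ExpiresAt)
	}
}
```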
docs/import.md (227 changes)

````diff
@@ -1,170 +1,99 @@
-# Import a model
+# Import
 
-This guide walks through importing a GGUF, PyTorch or Safetensors model.
+GGUF models and select Safetensors models can be imported directly into Ollama.
 
-## Importing (GGUF)
+## Import GGUF
 
-### Step 1: Write a `Modelfile`
+A binary GGUF file can be imported directly into Ollama through a Modelfile.
 
-Start by creating a `Modelfile`. This file is the blueprint for your model, specifying weights, parameters, prompt templates and more.
-
-```
-FROM ./mistral-7b-v0.1.Q4_0.gguf
+```dockerfile
+FROM /path/to/file.gguf
 ```
 
-(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
+## Import Safetensors
 
-```
-FROM ./mistral-7b-v0.1.Q4_0.gguf
-TEMPLATE "[INST] {{ .Prompt }} [/INST]"
+If the model being imported is one of these architectures, it can be imported directly into Ollama through a Modelfile:
+
+- LlamaForCausalLM
+- MistralForCausalLM
+- GemmaForCausalLM
+
+```dockerfile
+FROM /path/to/safetensors/directory
 ```
 
-### Step 2: Create the Ollama model
+For architectures not directly convertable by Ollama, see llama.cpp's [guide](https://github.com/ggerganov/llama.cpp/blob/master/README.md#prepare-and-quantize) on conversion. After conversion, see [Import GGUF](#import-gguf).
 
-Finally, create a model from your `Modelfile`:
+## Automatic Quantization
+
+> [!NOTE]
+> Automatic quantization requires v0.1.35 or higher.
+
+Ollama is capable of quantizing FP16 or FP32 models to any of the supported quantizations with the `-q/--quantize` flag in `ollama create`.
+
+```dockerfile
+FROM /path/to/my/gemma/f16/model
 ```
-ollama create example -f Modelfile
-```
 
-### Step 3: Run your model
-
-Next, test the model with `ollama run`:
-
-```
-ollama run example "What is your favourite condiment?"
-```
-
-## Importing (PyTorch & Safetensors)
-
-> Importing from PyTorch and Safetensors is a longer process than importing from GGUF. Improvements that make it easier are a work in progress.
-
-### Setup
-
-First, clone the `ollama/ollama` repo:
-
-```
-git clone git@github.com:ollama/ollama.git ollama
-cd ollama
-```
-
-and then fetch its `llama.cpp` submodule:
-
 ```shell
-git submodule init
-git submodule update llm/llama.cpp
+$ ollama create -q Q4_K_M mymodel
+transferring model data
+quantizing F16 model to Q4_K_M
+creating new layer sha256:735e246cc1abfd06e9cdcf95504d6789a6cd1ad7577108a70d9902fef503c1bd
+creating new layer sha256:0853f0ad24e5865173bbf9ffcc7b0f5d56b66fd690ab1009867e45e7d2c4db0f
+writing manifest
+success
 ```
 
-Next, install the Python dependencies:
+### Supported Quantizations
 
-```
-python3 -m venv llm/llama.cpp/.venv
-source llm/llama.cpp/.venv/bin/activate
-pip install -r llm/llama.cpp/requirements.txt
+<details>
+<summary>Legacy Quantization</summary>
+
+- `Q4_0`
+- `Q4_1`
+- `Q5_0`
+- `Q5_1`
+- `Q8_0`
+
+</details>
+
+<details>
+<summary>K-means Quantization</summary>`
+
+- `Q3_K_S`
+- `Q3_K_M`
+- `Q3_K_L`
+- `Q4_K_S`
+- `Q4_K_M`
+- `Q5_K_S`
+- `Q5_K_M`
+- `Q6_K`
+
+</details>
+
+> [!NOTE]
+> Activation-aware Weight Quantization (i.e. IQ) are not currently supported for automatic quantization however you can still import the quantized model into Ollama, see [Import GGUF](#import-gguf).
+
+## Template Detection
+
+> [!NOTE]
+> Template detection requires v0.1.42 or higher.
+
+Ollama uses model metadata, specifically `tokenizer.chat_template`, to automatically create a template appropriate for the model you're importing.
+
+```dockerfile
+FROM /path/to/my/gemma/model
 ```
 
-Then build the `quantize` tool:
-
-```
-make -C llm/llama.cpp quantize
+```shell
+$ ollama create mymodel
+transferring model data
+using autodetected template gemma-instruct
+creating new layer sha256:baa2a0edc27d19cc6b7537578a9a7ba1a4e3214dc185ed5ae43692b319af7b84
+creating new layer sha256:ba66c3309914dbef07e5149a648fd1877f030d337a4f240d444ea335008943cb
+writing manifest
+success
 ```
 
-### Clone the HuggingFace repository (optional)
+Defining a template in the Modelfile will disable this feature which may be useful if you want to use a different template than the autodetected one.
 
-If the model is currently hosted in a HuggingFace repository, first clone that repository to download the raw model.
-
-Install [Git LFS](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage), verify it's installed, and then clone the model's repository:
-
-```
-git lfs install
-git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 model
-```
-
-### Convert the model
-
-> Note: some model architectures require using specific convert scripts. For example, Qwen models require running `convert-hf-to-gguf.py` instead of `convert.py`
-
-```
-python llm/llama.cpp/convert.py ./model --outtype f16 --outfile converted.bin
-```
-
-### Quantize the model
-
-```
-llm/llama.cpp/quantize converted.bin quantized.bin q4_0
-```
-
-### Step 3: Write a `Modelfile`
-
-Next, create a `Modelfile` for your model:
-
-```
-FROM quantized.bin
-TEMPLATE "[INST] {{ .Prompt }} [/INST]"
-```
-
-### Step 4: Create the Ollama model
-
-Finally, create a model from your `Modelfile`:
-
-```
-ollama create example -f Modelfile
-```
-
-### Step 5: Run your model
-
-Next, test the model with `ollama run`:
-
-```
-ollama run example "What is your favourite condiment?"
-```
-
-## Publishing your model (optional – early alpha)
-
-Publishing models is in early alpha. If you'd like to publish your model to share with others, follow these steps:
-
-1. Create [an account](https://ollama.com/signup)
-2. Copy your Ollama public key:
-   - macOS: `cat ~/.ollama/id_ed25519.pub | pbcopy`
-   - Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
-   - Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
-3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
-
-Next, copy your model to your username's namespace:
-
-```
-ollama cp example <your username>/example
-```
-
-> Note: model names may only contain lowercase letters, digits, and the characters `.`, `-`, and `_`.
-
-Then push the model:
-
-```
-ollama push <your username>/example
-```
-
-After publishing, your model will be available at `https://ollama.com/<your username>/example`.
-
-## Quantization reference
-
-The quantization options are as follow (from highest highest to lowest levels of quantization). Note: some architectures such as Falcon do not support K quants.
-
-- `q2_K`
-- `q3_K`
-- `q3_K_S`
-- `q3_K_M`
-- `q3_K_L`
-- `q4_0` (recommended)
-- `q4_1`
-- `q4_K`
-- `q4_K_S`
-- `q4_K_M`
-- `q5_0`
-- `q5_1`
-- `q5_K`
-- `q5_K_S`
-- `q5_K_M`
-- `q6_K`
-- `q8_0`
-- `f16`
````
|
|||||||
@@ -100,6 +100,16 @@ sudo curl -L https://ollama.com/download/ollama-linux-amd64 -o /usr/bin/ollama
 sudo chmod +x /usr/bin/ollama
 ```
 
+## Installing specific versions
+
+Use the `OLLAMA_VERSION` environment variable with the install script to install a specific version of Ollama, including pre-releases. You can find the version numbers in the [releases page](https://github.com/ollama/ollama/releases).
+
+For example:
+
+```
+curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.1.32 sh
+```
+
 ## Viewing logs
 
 To view logs of Ollama running as a startup service, run:
@@ -76,6 +76,7 @@ Make sure you've set up the container runtime first as described in [docker.md](
 Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem:
 
+- Is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama won't be able to see your NVIDIA GPU.
 - Is the uvm driver not loaded? `sudo nvidia-modprobe -u`
 - Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
 - Try rebooting
@@ -45,7 +45,7 @@ all_splits = text_splitter.split_documents(data)
 ```
 
 It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb`
+We also need to pull an embedding model: `ollama pull nomic-embed-text`
 ```python
 from langchain.embeddings import OllamaEmbeddings
 from langchain.vectorstores import Chroma
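For reference, the embedding call that `OllamaEmbeddings` wraps can also be made against Ollama's REST API directly. A minimal Go sketch (model name assumed to be pulled already):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Request body for Ollama's /api/embeddings endpoint.
	body, _ := json.Marshal(map[string]string{
		"model":  "nomic-embed-text",
		"prompt": "Neil Armstrong was the first person to walk on the moon.",
	})

	resp, err := http.Post("http://localhost:11434/api/embeddings", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The response carries a single embedding vector.
	var out struct {
		Embedding []float64 `json:"embedding"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println("embedding dimensions:", len(out.Embedding))
}
```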
@@ -68,7 +68,8 @@ The next thing is to send the question and the relevant parts of the docs to the
 ```python
 from langchain.chains import RetrievalQA
 qachain=RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever())
-qachain.invoke({"query": question})
+res = qachain.invoke({"query": question})
+print(res['result'])
 ```
 
 The answer received from this chain was:
@@ -1,8 +1,10 @@
 package envconfig
 
 import (
+	"errors"
 	"fmt"
 	"log/slog"
+	"net"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -10,11 +12,27 @@ import (
 	"strings"
 )
 
+type OllamaHost struct {
+	Scheme string
+	Host   string
+	Port   string
+}
+
+func (o OllamaHost) String() string {
+	return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port)
+}
+
+var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
+
 var (
 	// Set via OLLAMA_ORIGINS in the environment
 	AllowOrigins []string
 	// Set via OLLAMA_DEBUG in the environment
 	Debug bool
+	// Experimental flash attention
+	FlashAttention bool
+	// Set via OLLAMA_KEEP_ALIVE in the environment
+	KeepAlive string
 	// Set via OLLAMA_LLM_LIBRARY in the environment
 	LLMLibrary string
 	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
@@ -23,34 +41,54 @@ var (
 	MaxQueuedRequests int
 	// Set via OLLAMA_MAX_VRAM in the environment
 	MaxVRAM uint64
+	// Set via OLLAMA_NOHISTORY in the environment
+	NoHistory bool
 	// Set via OLLAMA_NOPRUNE in the environment
 	NoPrune bool
 	// Set via OLLAMA_NUM_PARALLEL in the environment
 	NumParallel int
+	// Set via OLLAMA_HOST in the environment
+	Host *OllamaHost
 	// Set via OLLAMA_RUNNERS_DIR in the environment
 	RunnersDir string
 	// Set via OLLAMA_TMPDIR in the environment
 	TmpDir string
-	// Experimental flash attention
-	FlashAttention bool
 )
 
-func AsMap() map[string]string {
-	return map[string]string{
-		"OLLAMA_ORIGINS":           fmt.Sprintf("%v", AllowOrigins),
-		"OLLAMA_DEBUG":             fmt.Sprintf("%v", Debug),
-		"OLLAMA_LLM_LIBRARY":       fmt.Sprintf("%v", LLMLibrary),
-		"OLLAMA_MAX_LOADED_MODELS": fmt.Sprintf("%v", MaxRunners),
-		"OLLAMA_MAX_QUEUE":         fmt.Sprintf("%v", MaxQueuedRequests),
-		"OLLAMA_MAX_VRAM":          fmt.Sprintf("%v", MaxVRAM),
-		"OLLAMA_NOPRUNE":           fmt.Sprintf("%v", NoPrune),
-		"OLLAMA_NUM_PARALLEL":      fmt.Sprintf("%v", NumParallel),
-		"OLLAMA_RUNNERS_DIR":       fmt.Sprintf("%v", RunnersDir),
-		"OLLAMA_TMPDIR":            fmt.Sprintf("%v", TmpDir),
-		"OLLAMA_FLASH_ATTENTION":   fmt.Sprintf("%v", FlashAttention),
+type EnvVar struct {
+	Name        string
+	Value       any
+	Description string
+}
+
+func AsMap() map[string]EnvVar {
+	return map[string]EnvVar{
+		"OLLAMA_DEBUG":             {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
+		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
+		"OLLAMA_HOST":              {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
+		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
+		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
+		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
+		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
+		"OLLAMA_MAX_VRAM":          {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
+		"OLLAMA_MODELS":            {"OLLAMA_MODELS", "", "The path to the models directory"},
+		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
+		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
+		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
+		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
+		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
+		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
 	}
 }
+
+func Values() map[string]string {
+	vals := make(map[string]string)
+	for k, v := range AsMap() {
+		vals[k] = fmt.Sprintf("%v", v.Value)
+	}
+	return vals
+}
 
 var defaultAllowOrigins = []string{
 	"localhost",
 	"127.0.0.1",
@@ -104,7 +142,7 @@ func LoadConfig() {
 	var paths []string
 	for _, root := range []string{filepath.Dir(appExe), cwd} {
 		paths = append(paths,
-			filepath.Join(root),
+			root,
 			filepath.Join(root, "windows-"+runtime.GOARCH),
 			filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
 		)
@@ -147,6 +185,10 @@ func LoadConfig() {
 		}
 	}
 
+	if nohistory := clean("OLLAMA_NOHISTORY"); nohistory != "" {
+		NoHistory = true
+	}
+
 	if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
 		NoPrune = true
 	}
@@ -158,11 +200,17 @@ func LoadConfig() {
 		AllowOrigins = append(AllowOrigins,
 			fmt.Sprintf("http://%s", allowOrigin),
 			fmt.Sprintf("https://%s", allowOrigin),
-			fmt.Sprintf("http://%s:*", allowOrigin),
-			fmt.Sprintf("https://%s:*", allowOrigin),
+			fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")),
+			fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")),
 		)
 	}
 
+	AllowOrigins = append(AllowOrigins,
+		"app://*",
+		"file://*",
+		"tauri://*",
+	)
+
 	maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
 	if maxRunners != "" {
 		m, err := strconv.Atoi(maxRunners)
@@ -181,4 +229,56 @@ func LoadConfig() {
 			MaxQueuedRequests = p
 		}
 	}
+
+	KeepAlive = clean("OLLAMA_KEEP_ALIVE")
+
+	var err error
+	Host, err = getOllamaHost()
+	if err != nil {
+		slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port)
+	}
+}
+
+func getOllamaHost() (*OllamaHost, error) {
+	defaultPort := "11434"
+
+	hostVar := os.Getenv("OLLAMA_HOST")
+	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
+
+	scheme, hostport, ok := strings.Cut(hostVar, "://")
+	switch {
+	case !ok:
+		scheme, hostport = "http", hostVar
+	case scheme == "http":
+		defaultPort = "80"
+	case scheme == "https":
+		defaultPort = "443"
+	}
+
+	// trim trailing slashes
+	hostport = strings.TrimRight(hostport, "/")
+
+	host, port, err := net.SplitHostPort(hostport)
+	if err != nil {
+		host, port = "127.0.0.1", defaultPort
+		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
+			host = ip.String()
+		} else if hostport != "" {
+			host = hostport
+		}
+	}
+
+	if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
+		return &OllamaHost{
+			Scheme: scheme,
+			Host:   host,
+			Port:   defaultPort,
+		}, ErrInvalidHostPort
+	}
+
+	return &OllamaHost{
+		Scheme: scheme,
+		Host:   host,
+		Port:   port,
+	}, nil
 }
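The parsing rules in `getOllamaHost` are easiest to see by example. The standalone sketch below mirrors the same defaulting behavior in simplified form; it reimplements the logic for illustration rather than importing the package:

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// parseHost mirrors getOllamaHost's defaulting rules in simplified form:
// 127.0.0.1 and port 11434 unless a scheme or explicit port overrides them.
func parseHost(hostVar string) (scheme, host, port string) {
	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))

	scheme, hostport, ok := strings.Cut(hostVar, "://")
	defaultPort := "11434"
	switch {
	case !ok:
		scheme, hostport = "http", hostVar
	case scheme == "http":
		defaultPort = "80"
	case scheme == "https":
		defaultPort = "443"
	}

	host, port, err := net.SplitHostPort(strings.TrimRight(hostport, "/"))
	if err != nil {
		// No explicit port: fall back to the default host and/or port.
		host, port = "127.0.0.1", defaultPort
		if hostport != "" {
			host = strings.Trim(hostport, "[]")
		}
	}
	return scheme, host, port
}

func main() {
	for _, v := range []string{"", "1.2.3.4", "example.com:1234", "https://ollama.local", "[::1]"} {
		s, h, p := parseHost(v)
		fmt.Printf("%-25q -> %s://%s\n", v, s, net.JoinHostPort(h, p))
	}
}
```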
envconfig/config_test.go (new file, 71 lines)
package envconfig

import (
	"fmt"
	"net"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestConfig(t *testing.T) {
	Debug = false // Reset whatever was loaded in init()
	t.Setenv("OLLAMA_DEBUG", "")
	LoadConfig()
	require.False(t, Debug)
	t.Setenv("OLLAMA_DEBUG", "false")
	LoadConfig()
	require.False(t, Debug)
	t.Setenv("OLLAMA_DEBUG", "1")
	LoadConfig()
	require.True(t, Debug)
	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
	LoadConfig()
	require.True(t, FlashAttention)
}

func TestClientFromEnvironment(t *testing.T) {
	type testCase struct {
		value  string
		expect string
		err    error
	}

	hostTestCases := map[string]*testCase{
		"empty":               {value: "", expect: "127.0.0.1:11434"},
		"only address":        {value: "1.2.3.4", expect: "1.2.3.4:11434"},
		"only port":           {value: ":1234", expect: ":1234"},
		"address and port":    {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
		"hostname":            {value: "example.com", expect: "example.com:11434"},
		"hostname and port":   {value: "example.com:1234", expect: "example.com:1234"},
		"zero port":           {value: ":0", expect: ":0"},
		"too large port":      {value: ":66000", err: ErrInvalidHostPort},
		"too small port":      {value: ":-1", err: ErrInvalidHostPort},
		"ipv6 localhost":      {value: "[::1]", expect: "[::1]:11434"},
		"ipv6 world open":     {value: "[::]", expect: "[::]:11434"},
		"ipv6 no brackets":    {value: "::1", expect: "[::1]:11434"},
		"ipv6 + port":         {value: "[::1]:1337", expect: "[::1]:1337"},
		"extra space":         {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
		"extra quotes":        {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
		"extra space+quotes":  {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
		"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
	}

	for k, v := range hostTestCases {
		t.Run(k, func(t *testing.T) {
			t.Setenv("OLLAMA_HOST", v.value)
			LoadConfig()

			oh, err := getOllamaHost()
			if err != v.err {
				t.Fatalf("expected %s, got %s", v.err, err)
			}

			if err == nil {
				host := net.JoinHostPort(oh.Host, oh.Port)
				assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
			}
		})
	}
}
@@ -77,13 +77,21 @@ LOADER_MAPPING = {
 
 
 def load_single_document(file_path: str) -> List[Document]:
-    ext = "." + file_path.rsplit(".", 1)[-1]
-    if ext in LOADER_MAPPING:
-        loader_class, loader_args = LOADER_MAPPING[ext]
-        loader = loader_class(file_path, **loader_args)
-        return loader.load()
-
-    raise ValueError(f"Unsupported file extension '{ext}'")
+    if os.path.getsize(file_path) != 0:
+        filename, ext = os.path.splitext(file_path)
+        if ext in LOADER_MAPPING:
+            loader_class, loader_args = LOADER_MAPPING[ext]
+            try:
+                loader = loader_class(file_path, **loader_args)
+                if loader:
+                    return loader.load()
+            except:
+                print(f"Corrupted file {file_path}. Ignoring it.")
+        else:
+            print(f"Unsupported file {file_path}. Ignoring it.")
+    else:
+        print(f"Empty file {file_path}. Ignoring it.")
 
 
 def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Document]:
     """
@@ -100,7 +108,8 @@ def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Docum
     results = []
     with tqdm(total=len(filtered_files), desc='Loading new documents', ncols=80) as pbar:
         for i, docs in enumerate(pool.imap_unordered(load_single_document, filtered_files)):
-            results.extend(docs)
+            if docs:
+                results.extend(docs)
             pbar.update()
 
     return results
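The same guard pattern, skipping empty, unsupported, or unreadable files rather than aborting the whole ingest, translates directly to other languages. A Go sketch of the idea under those assumptions (the extension set and file names are illustrative):

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// supported mimics LOADER_MAPPING: only these extensions get loaded.
var supported = map[string]bool{".txt": true, ".md": true, ".csv": true}

// loadSingleDocument returns the file contents, or nil when the file
// should be skipped rather than failing the whole batch.
func loadSingleDocument(path string) []byte {
	info, err := os.Stat(path)
	if err != nil || info.Size() == 0 {
		fmt.Printf("Empty or unreadable file %s. Ignoring it.\n", path)
		return nil
	}
	if !supported[filepath.Ext(path)] {
		fmt.Printf("Unsupported file %s. Ignoring it.\n", path)
		return nil
	}
	data, err := os.ReadFile(path)
	if err != nil {
		fmt.Printf("Corrupted file %s. Ignoring it.\n", path)
		return nil
	}
	return data
}

func main() {
	for _, p := range []string{"notes.md", "image.png"} {
		if doc := loadSingleDocument(p); doc != nil {
			fmt.Printf("loaded %s (%d bytes)\n", p, len(doc))
		}
	}
}
```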
@@ -11,4 +11,5 @@ tabulate==0.9.0
 pandoc==2.3
 pypandoc==1.11
 tqdm==4.66.1
 sentence_transformers==2.2.2
+numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
@@ -9,6 +9,7 @@ def chat(messages):
     r = requests.post(
         "http://0.0.0.0:11434/api/chat",
         json={"model": model, "messages": messages, "stream": True},
+        stream=True
     )
     r.raise_for_status()
     output = ""
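The added `stream=True` is what lets the client read the response incrementally: `/api/chat` with `"stream": true` returns one JSON object per line until a final object with `"done": true`. The same consumption loop in Go might look like this (model name assumed):

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	body, _ := json.Marshal(map[string]any{
		"model":    "llama3",
		"messages": []map[string]string{{"role": "user", "content": "Why is the sky blue?"}},
		"stream":   true,
	})

	resp, err := http.Post("http://localhost:11434/api/chat", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Each line of the response body is a standalone JSON object.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var chunk struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
			Done bool `json:"done"`
		}
		if err := json.Unmarshal(scanner.Bytes(), &chunk); err != nil {
			panic(err)
		}
		fmt.Print(chunk.Message.Content)
		if chunk.Done {
			fmt.Println()
			break
		}
	}
}
```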
@@ -5,7 +5,6 @@ import (
 )
 
 func TestHumanNumber(t *testing.T) {
-
 	type testCase struct {
 		input    uint64
 		expected string
go.mod (1 line changed)
@@ -16,6 +16,7 @@ require (
 )
 
 require (
+	github.com/agnivade/levenshtein v1.1.1
 	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
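For context, `agnivade/levenshtein` exposes a single edit-distance function. A small sketch of the API in isolation (the strings are arbitrary; this diff does not show what the dependency is used for):

```go
package main

import (
	"fmt"

	"github.com/agnivade/levenshtein"
)

func main() {
	// ComputeDistance counts the minimum number of single-character
	// edits (insertions, deletions, substitutions) between two strings.
	fmt.Println(levenshtein.ComputeDistance("llama2", "llama3"))   // 1
	fmt.Println(levenshtein.ComputeDistance("mistral", "mixtral")) // 1
}
```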
go.sum (6 lines changed)
@@ -4,10 +4,14 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7
 gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8=
 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
+github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8=
+github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo=
 github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
 github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
 github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6ICHXqG5hm0ZW5IHyeEJXoIJSOZeBLmWPNeIQ=
 github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
+github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
+github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
 github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
 github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
 github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
@@ -36,6 +40,8 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
+github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
 github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
 github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -7,7 +7,7 @@ import (
 	"os"
 	"path/filepath"
 	"slices"
-	"strings"
+	// "strings"
 
 	"github.com/ollama/ollama/format"
 )
@@ -65,7 +65,7 @@ func AMDGetGPUInfo() []GpuInfo {
 
 	slog.Debug("detected hip devices", "count", count)
 	// TODO how to determine the underlying device ID when visible devices is causing this to subset?
-	for i := 0; i < count; i++ {
+	for i := range count {
 		err = hl.HipSetDevice(i)
 		if err != nil {
 			slog.Warn("set device", "id", i, "error", err)
@@ -108,10 +108,10 @@ func AMDGetGPUInfo() []GpuInfo {
 		}
 
 		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
-		//if totalMemory < IGPUMemLimit {
-		//	slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
-		//	continue
-		//}
+		if totalMemory < IGPUMemLimit {
+			slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
+			continue
+		}
 
 		// TODO revisit this once ROCm v6 is available on windows.
 		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
@@ -13,7 +13,7 @@ import (
 	"syscall"
 	"time"
 
-	"github.com/ollama/ollama/server/envconfig"
+	"github.com/ollama/ollama/envconfig"
 )
 
 var (
@@ -80,7 +80,7 @@ func cleanupTmpDirs() {
 		if err == nil {
 			pid, err := strconv.Atoi(string(raw))
 			if err == nil {
-				if proc, err := os.FindProcess(int(pid)); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
+				if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
 					// Another running ollama, ignore this tmpdir
 					continue
 				}
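The `proc.Signal(syscall.Signal(0))` test above is the standard Unix liveness probe: signal 0 delivers nothing, but the kernel still checks that the target process exists. A standalone sketch of the same technique (Unix-only):

```go
//go:build !windows

package main

import (
	"errors"
	"fmt"
	"os"
	"syscall"
)

// alive reports whether a process with the given pid still exists.
func alive(pid int) bool {
	proc, err := os.FindProcess(pid) // always succeeds on Unix
	if err != nil {
		return false
	}
	// Signal 0 sends no signal but errors if the process is gone.
	return !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone)
}

func main() {
	fmt.Println(alive(os.Getpid())) // true: we are running
	fmt.Println(alive(999999))      // very likely false
}
```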
@@ -18,5 +18,4 @@ func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
 		ids = append(ids, info.ID)
 	}
 	return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
-
 }
gpu/gpu.go (100 lines changed)
@@ -20,14 +20,15 @@ import (
 	"sync"
 	"unsafe"
 
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
-	"github.com/ollama/ollama/server/envconfig"
 )
 
 type handles struct {
 	deviceCount int
 	cudart      *C.cudart_handle_t
 	nvcuda      *C.nvcuda_handle_t
+	oneapi      *C.oneapi_handle_t
 }
 
 const (
@@ -80,6 +81,15 @@ var NvcudaWindowsGlobs = []string{
 	"c:\\windows\\system*\\nvcuda.dll",
 }
 
+var OneapiWindowsGlobs = []string{
+	"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
+}
+
+var OneapiLinuxGlobs = []string{
+	"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
+	"/usr/lib*/libze_intel_gpu.so*",
+}
+
 // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
 // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
@@ -141,6 +151,7 @@ func initGPUHandles() *handles {
 			return gpuHandles
 		}
 	}
 
 	return gpuHandles
 }
 
@@ -176,44 +187,46 @@ func GetGPUInfo() GpuInfoList {
 	resp := []GpuInfo{}
 
 	// NVIDIA first
-	for i := 0; i < gpuHandles.deviceCount; i++ {
+	for i := range gpuHandles.deviceCount {
 		// TODO once we support CPU compilation variants of GPU libraries refine this...
 		if cpuVariant == "" && runtime.GOARCH == "amd64" {
 			continue
 		}
-		gpuInfo := GpuInfo{
-			Library: "cuda",
-		}
-		var driverMajor int
-		var driverMinor int
-		if gpuHandles.cudart != nil {
-			C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
-		} else {
-			C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
-			driverMajor = int(gpuHandles.nvcuda.driver_major)
-			driverMinor = int(gpuHandles.nvcuda.driver_minor)
-		}
-		if memInfo.err != nil {
-			slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
-			C.free(unsafe.Pointer(memInfo.err))
-			continue
-		}
-		if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
-			slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
-			continue
-		}
-		gpuInfo.TotalMemory = uint64(memInfo.total)
-		gpuInfo.FreeMemory = uint64(memInfo.free)
-		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
-		gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
-		gpuInfo.MinimumMemory = cudaMinimumMemory
-		gpuInfo.DependencyPath = depPath
-		gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
-		gpuInfo.DriverMajor = int(driverMajor)
-		gpuInfo.DriverMinor = int(driverMinor)
-
-		// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
-		resp = append(resp, gpuInfo)
+		if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil {
+			gpuInfo := GpuInfo{
+				Library: "cuda",
+			}
+			var driverMajor int
+			var driverMinor int
+			if gpuHandles.cudart != nil {
+				C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
+			} else {
+				C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
+				driverMajor = int(gpuHandles.nvcuda.driver_major)
+				driverMinor = int(gpuHandles.nvcuda.driver_minor)
+			}
+			if memInfo.err != nil {
+				slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+				C.free(unsafe.Pointer(memInfo.err))
+				continue
+			}
+			if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
+				slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+				continue
+			}
+			gpuInfo.TotalMemory = uint64(memInfo.total)
+			gpuInfo.FreeMemory = uint64(memInfo.free)
+			gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+			gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
+			gpuInfo.MinimumMemory = cudaMinimumMemory
+			gpuInfo.DependencyPath = depPath
+			gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
+			gpuInfo.DriverMajor = driverMajor
+			gpuInfo.DriverMinor = driverMinor
+
+			// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
+			resp = append(resp, gpuInfo)
+		}
 	}
 
 	// Then AMD
@@ -348,6 +361,23 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
 	return 0, nil, ""
 }
 
+func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
+	var resp C.oneapi_init_resp_t
+	resp.oh.verbose = getVerboseState()
+	for _, libPath := range oneapiLibPaths {
+		lib := C.CString(libPath)
+		defer C.free(unsafe.Pointer(lib))
+		C.oneapi_init(lib, &resp)
+		if resp.err != nil {
+			slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
+			C.free(unsafe.Pointer(resp.err))
+		} else {
+			return int(resp.num_devices), &resp.oh, libPath
+		}
+	}
+	return 0, nil, ""
+}
+
 func getVerboseState() C.uint16_t {
 	if envconfig.Debug {
 		return C.uint16_t(1)
@@ -368,6 +398,8 @@ func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
 		return cudaGetVisibleDevicesEnv(l)
 	case "rocm":
 		return rocmGetVisibleDevicesEnv(l)
+	case "oneapi":
+		return oneapiGetVisibleDevicesEnv(l)
 	default:
 		slog.Debug("no filter required for library " + l[0].Library)
 		return "", ""
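Note that `GetVisibleDevicesEnv` returns a key/value pair instead of mutating the process environment, so a caller can scope the setting to a child process. A hedged sketch of how such a pair might be applied when launching a runner (the command and values here are placeholders, not Ollama's actual scheduler code):

```go
package main

import (
	"fmt"
	"os"
	"os/exec"
)

func main() {
	// In Ollama this pair would come from GpuInfoList.GetVisibleDevicesEnv();
	// the values below are stand-ins.
	key, value := "ONEAPI_DEVICE_SELECTOR", "level_zero:0"

	cmd := exec.Command("./runner") // placeholder child process
	cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", key, value))
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	if err := cmd.Run(); err != nil {
		fmt.Fprintln(os.Stderr, "runner failed:", err)
	}
}
```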
@@ -62,6 +62,7 @@ void cpu_check_ram(mem_info_t *resp);
 
 #include "gpu_info_cudart.h"
 #include "gpu_info_nvcuda.h"
+#include "gpu_info_oneapi.h"
 
 #endif // __GPU_INFO_H__
 #endif // __APPLE__
gpu/gpu_info_oneapi.c (new file, 214 lines)
#ifndef __APPLE__

#include "gpu_info_oneapi.h"

#include <string.h>

void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
{
  ze_result_t ret;
  resp->err = NULL;
  const int buflen = 256;
  char buf[buflen + 1];
  int i;
  struct lookup
  {
    char *s;
    void **p;
  } l[] = {
      {"zesInit", (void *)&resp->oh.zesInit},
      {"zesDriverGet", (void *)&resp->oh.zesDriverGet},
      {"zesDeviceGet", (void *)&resp->oh.zesDeviceGet},
      {"zesDeviceGetProperties", (void *)&resp->oh.zesDeviceGetProperties},
      {"zesDeviceEnumMemoryModules",
       (void *)&resp->oh.zesDeviceEnumMemoryModules},
      {"zesMemoryGetProperties", (void *)&resp->oh.zesMemoryGetProperties},
      {"zesMemoryGetState", (void *)&resp->oh.zesMemoryGetState},
      {NULL, NULL},
  };

  resp->oh.handle = LOAD_LIBRARY(oneapi_lib_path, RTLD_LAZY);
  if (!resp->oh.handle)
  {
    char *msg = LOAD_ERR();
    snprintf(buf, buflen,
             "Unable to load %s library to query for Intel GPUs: %s\n",
             oneapi_lib_path, msg);
    free(msg);
    resp->err = strdup(buf);
    return;
  }

  // TODO once we've squashed the remaining corner cases remove this log
  LOG(resp->oh.verbose,
      "wiring Level-Zero management library functions in %s\n",
      oneapi_lib_path);

  for (i = 0; l[i].s != NULL; i++)
  {
    // TODO once we've squashed the remaining corner cases remove this log
    LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s);

    *l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s);
    if (!l[i].p)
    {
      resp->oh.handle = NULL;
      char *msg = LOAD_ERR();
      LOG(resp->oh.verbose, "dlerr: %s\n", msg);
      UNLOAD_LIBRARY(resp->oh.handle);
      snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, msg);
      free(msg);
      resp->err = strdup(buf);
      return;
    }
  }

  ret = (*resp->oh.zesInit)(0);
  if (ret != ZE_RESULT_SUCCESS)
  {
    LOG(resp->oh.verbose, "zesInit err: %d\n", ret);
    UNLOAD_LIBRARY(resp->oh.handle);
    resp->oh.handle = NULL;
    snprintf(buf, buflen, "oneapi vram init failure: %d", ret);
    resp->err = strdup(buf);
  }

  (*resp->oh.zesDriverGet)(&resp->num_devices, NULL);

  return;
}

void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
{
  ze_result_t ret;
  resp->err = NULL;
  uint64_t totalMem = 0;
  uint64_t usedMem = 0;
  const int buflen = 256;
  char buf[buflen + 1];
  int i, d, m;

  if (h.handle == NULL)
  {
    resp->err = strdup("Level-Zero handle not initialized");
    return;
  }

  uint32_t driversCount = 0;
  ret = (*h.zesDriverGet)(&driversCount, NULL);
  if (ret != ZE_RESULT_SUCCESS)
  {
    snprintf(buf, buflen, "unable to get driver count: %d", ret);
    resp->err = strdup(buf);
    return;
  }
  LOG(h.verbose, "discovered %d Level-Zero drivers\n", driversCount);

  zes_driver_handle_t *allDrivers =
      malloc(driversCount * sizeof(zes_driver_handle_t));
  (*h.zesDriverGet)(&driversCount, allDrivers);

  resp->total = 0;
  resp->free = 0;

  for (d = 0; d < driversCount; d++)
  {
    uint32_t deviceCount = 0;
    ret = (*h.zesDeviceGet)(allDrivers[d], &deviceCount, NULL);
    if (ret != ZE_RESULT_SUCCESS)
    {
      snprintf(buf, buflen, "unable to get device count: %d", ret);
      resp->err = strdup(buf);
      free(allDrivers);
      return;
    }

    LOG(h.verbose, "discovered %d Level-Zero devices\n", deviceCount);

    zes_device_handle_t *devices =
        malloc(deviceCount * sizeof(zes_device_handle_t));
    (*h.zesDeviceGet)(allDrivers[d], &deviceCount, devices);

    for (i = 0; i < deviceCount; i++)
    {
      zes_device_ext_properties_t ext_props;
      ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
      ext_props.pNext = NULL;

      zes_device_properties_t props;
      props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
      props.pNext = &ext_props;

      ret = (*h.zesDeviceGetProperties)(devices[i], &props);
      if (ret != ZE_RESULT_SUCCESS)
      {
        snprintf(buf, buflen, "unable to get device properties: %d", ret);
        resp->err = strdup(buf);
        free(allDrivers);
        free(devices);
        return;
      }

      if (h.verbose)
      {
        // When in verbose mode, report more information about
        // the card we discover.
        LOG(h.verbose, "[%d] oneAPI device name: %s\n", i,
            props.modelName);
        LOG(h.verbose, "[%d] oneAPI brand: %s\n", i,
            props.brandName);
        LOG(h.verbose, "[%d] oneAPI vendor: %s\n", i,
            props.vendorName);
        LOG(h.verbose, "[%d] oneAPI S/N: %s\n", i,
            props.serialNumber);
        LOG(h.verbose, "[%d] oneAPI board number: %s\n", i,
            props.boardNumber);
      }

      uint32_t memCount = 0;
      ret = (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, NULL);
      if (ret != ZE_RESULT_SUCCESS)
      {
        snprintf(buf, buflen,
                 "unable to enumerate Level-Zero memory modules: %d", ret);
        resp->err = strdup(buf);
        free(allDrivers);
        free(devices);
        return;
      }

      LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);

      zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
      (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, mems);

      for (m = 0; m < memCount; m++)
      {
        zes_mem_state_t state;
        state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
        state.pNext = NULL;
        ret = (*h.zesMemoryGetState)(mems[m], &state);
        if (ret != ZE_RESULT_SUCCESS)
        {
          snprintf(buf, buflen, "unable to get memory state: %d", ret);
          resp->err = strdup(buf);
          free(allDrivers);
          free(devices);
          free(mems);
          return;
        }

        resp->total += state.size;
        resp->free += state.free;
      }

      free(mems);
    }

    free(devices);
  }

  free(allDrivers);
}

#endif // __APPLE__
gpu/gpu_info_oneapi.h (new file, 211 lines)
#ifndef __APPLE__
#ifndef __GPU_INFO_ONEAPI_H__
#define __GPU_INFO_ONEAPI_H__
#include "gpu_info.h"

#define ZE_MAX_DEVICE_NAME 256
#define ZE_MAX_DEVICE_UUID_SIZE 16
#define ZES_STRING_PROPERTY_SIZE 64
#define ZE_BIT(_i) (1 << _i)

// Just enough typedef's to dlopen/dlsym for memory information
typedef enum ze_result_t
{
  ZE_RESULT_SUCCESS = 0,
  // Other values omitted for now...
} ze_result_t;

typedef uint8_t ze_bool_t;
typedef struct _zes_driver_handle_t *zes_driver_handle_t;
typedef struct _zes_device_handle_t *zes_device_handle_t;
typedef struct _zes_mem_handle_t *zes_mem_handle_t;

typedef enum _ze_structure_type_t
{
  ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
} ze_structure_type_t;

typedef enum _zes_structure_type_t
{
  ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1,
  ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb,
  ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e,
  ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES = 0x2d,
  ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
} zes_structure_type_t;

typedef enum _zes_mem_type_t
{
  ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff
} zes_mem_type_t;

typedef enum _zes_mem_loc_t
{
  ZES_MEM_LOC_SYSTEM = 0,
  ZES_MEM_LOC_DEVICE = 1,
  ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff
} zes_mem_loc_t;

typedef enum _zes_mem_health_t
{
  ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff
} zes_mem_health_t;

typedef struct _ze_device_uuid_t
{
  uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
} ze_device_uuid_t;

typedef struct _zes_uuid_t
{
  uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
} zes_uuid_t;

typedef enum _ze_device_type_t
{
  ZE_DEVICE_TYPE_GPU = 1,
  ZE_DEVICE_TYPE_CPU = 2,
  ZE_DEVICE_TYPE_FPGA = 3,
  ZE_DEVICE_TYPE_MCA = 4,
  ZE_DEVICE_TYPE_VPU = 5,
  ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
} ze_device_type_t;

typedef enum _zes_device_type_t
{
  ZES_DEVICE_TYPE_GPU = 1,
  ZES_DEVICE_TYPE_CPU = 2,
  ZES_DEVICE_TYPE_FPGA = 3,
  ZES_DEVICE_TYPE_MCA = 4,
  ZES_DEVICE_TYPE_VPU = 5,
  ZES_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
} zes_device_type_t;

typedef uint32_t ze_device_property_flags_t;
typedef enum _ze_device_property_flag_t
{
  ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
  ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
  ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
  ZE_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3),
  ZE_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
} ze_device_property_flag_t;

typedef uint32_t zes_device_property_flags_t;
typedef enum _zes_device_property_flag_t
{
  ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
  ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
  ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
  ZES_DEVICE_PROPERTY_FLAG_ONDEMANDPAGING = ZE_BIT(3),
  ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
} zes_device_property_flag_t;

typedef struct _ze_device_properties_t
{
  ze_structure_type_t stype;
  void *pNext;
  ze_device_type_t type;
  uint32_t vendorId;
  uint32_t deviceId;
  ze_device_property_flags_t flags;
  uint32_t subdeviceId;
  uint32_t coreClockRate;
  uint64_t maxMemAllocSize;
  uint32_t maxHardwareContexts;
  uint32_t maxCommandQueuePriority;
  uint32_t numThreadsPerEU;
  uint32_t physicalEUSimdWidth;
  uint32_t numEUsPerSubslice;
  uint32_t numSubslicesPerSlice;
  uint32_t numSlices;
  uint64_t timerResolution;
  uint32_t timestampValidBits;
  uint32_t kernelTimestampValidBits;
  ze_device_uuid_t uuid;
  char name[ZE_MAX_DEVICE_NAME];
} ze_device_properties_t;

typedef struct _zes_device_properties_t
{
  zes_structure_type_t stype;
  void *pNext;
  ze_device_properties_t core;
  uint32_t numSubdevices;
  char serialNumber[ZES_STRING_PROPERTY_SIZE];
  char boardNumber[ZES_STRING_PROPERTY_SIZE];
  char brandName[ZES_STRING_PROPERTY_SIZE];
  char modelName[ZES_STRING_PROPERTY_SIZE];
  char vendorName[ZES_STRING_PROPERTY_SIZE];
  char driverVersion[ZES_STRING_PROPERTY_SIZE];
} zes_device_properties_t;

typedef struct _zes_device_ext_properties_t
{
  zes_structure_type_t stype;
  void *pNext;
  zes_uuid_t uuid;
  zes_device_type_t type;
  zes_device_property_flags_t flags;
} zes_device_ext_properties_t;

typedef struct _zes_mem_properties_t
{
  zes_structure_type_t stype;
  void *pNext;
  zes_mem_type_t type;
  ze_bool_t onSubdevice;
  uint32_t subdeviceId;
  zes_mem_loc_t location;
  uint64_t physicalSize;
  int32_t busWidth;
  int32_t numChannels;
} zes_mem_properties_t;

typedef struct _zes_mem_state_t
{
  zes_structure_type_t stype;
  const void *pNext;
  zes_mem_health_t health;
  uint64_t free;
  uint64_t size;
} zes_mem_state_t;

typedef struct oneapi_handle
{
  void *handle;
  uint16_t verbose;
  ze_result_t (*zesInit)(int);
  ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
  ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
                              zes_device_handle_t *phDevices);
  ze_result_t (*zesDeviceGetProperties)(zes_device_handle_t hDevice,
                                        zes_device_properties_t *pProperties);
  ze_result_t (*zesDeviceEnumMemoryModules)(zes_device_handle_t hDevice,
                                            uint32_t *pCount,
                                            zes_mem_handle_t *phMemory);
  ze_result_t (*zesMemoryGetProperties)(zes_mem_handle_t hMemory,
                                        zes_mem_properties_t *pProperties);
  ze_result_t (*zesMemoryGetState)(zes_mem_handle_t hMemory,
                                   zes_mem_state_t *pState);

} oneapi_handle_t;

typedef struct oneapi_init_resp
{
  char *err; // If err is non-null handle is invalid
  int num_devices;
  oneapi_handle_t oh;
} oneapi_init_resp_t;

typedef struct oneapi_version_resp
{
  ze_result_t status;
  char *str; // Contains version or error string if status != 0
} oneapi_version_resp_t;

void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
void oneapi_check_vram(oneapi_handle_t rh, mem_info_t *resp);

#endif // __GPU_INFO_INTEL_H__
#endif // __APPLE__
gpu/gpu_oneapi.go (new file, 21 lines)
//go:build linux || windows

package gpu

import (
	"log/slog"
	"strings"
)

func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
	ids := []string{}
	for _, info := range gpuInfo {
		if info.Library != "oneapi" {
			// TODO shouldn't happen if things are wired correctly...
			slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
			continue
		}
		ids = append(ids, info.ID)
	}
	return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
}
@@ -5,11 +5,12 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestBasicGetGPUInfo(t *testing.T) {
 	info := GetGPUInfo()
-	assert.Greater(t, len(info), 0)
+	assert.NotEmpty(t, len(info))
 	assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
 	if info[0].Library != "cpu" {
 		assert.Greater(t, info[0].TotalMemory, uint64(0))
@@ -19,7 +20,7 @@ func TestBasicGetGPUInfo(t *testing.T) {
 
 func TestCPUMemInfo(t *testing.T) {
 	info, err := GetCPUMem()
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	switch runtime.GOOS {
 	case "darwin":
 		t.Skip("CPU memory not populated on darwin")
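The `assert.NoError` to `require.NoError` switch is behavioral, not cosmetic: testify's `assert` records the failure and keeps executing, while `require` ends the test immediately, so later lines never run against invalid state. A small illustration:

```go
package example

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func fetch() (*string, error) { return nil, errors.New("boom") }

func TestAssertVsRequire(t *testing.T) {
	v, err := fetch()

	// assert.NoError would mark the test failed but keep running, and a
	// later dereference of v would then panic on the nil pointer.
	// require.Error stops the test here if the expectation fails,
	// protecting everything below from operating on invalid state.
	require.Error(t, err)
	assert.Nil(t, v)
}
```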
116  llm/ext_server/server.cpp  (vendored)

@@ -140,7 +140,6 @@ struct server_slot {
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;

-    bool infill = false;
     bool embedding = false;
     bool has_next_token = true;
     bool truncated = false;
@@ -187,7 +186,6 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        infill = false;
         ga_i = 0;
         n_past_se = 0;

@@ -361,7 +359,6 @@ struct llama_server_context

     // slots / clients
     std::vector<server_slot> slots;
-    json default_generation_settings_for_props;

     llama_server_queue queue_tasks;
     llama_server_response queue_results;
@@ -485,9 +482,6 @@ struct llama_server_context
         slots.push_back(slot);
     }

-    default_generation_settings_for_props = get_formated_generation(slots.front());
-    default_generation_settings_for_props["seed"] = -1;
-
     batch = llama_batch_init(n_ctx, 0, params.n_parallel);
 }

@@ -586,7 +580,7 @@ struct llama_server_context
         slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
         slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
         slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
-        slot->params.seed = json_value(data, "seed", default_params.seed);
+        slot->sparams.seed = json_value(data, "seed", default_params.seed);
         slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
         slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
@@ -600,16 +594,6 @@ struct llama_server_context
             slot->params.n_predict = slot->n_predict;
         }

-        // infill
-        if (data.count("input_prefix") != 0)
-        {
-            slot->params.input_prefix = data["input_prefix"];
-        }
-        else
-        {
-            slot->params.input_prefix = "";
-        }
-
         if (data.count("input_suffix") != 0)
         {
             slot->params.input_suffix = data["input_suffix"];
@@ -823,7 +807,6 @@ struct llama_server_context
             llama_sampling_free(slot->ctx_sampling);
         }
         slot->ctx_sampling = llama_sampling_init(slot->sparams);
-        llama_set_rng_seed(ctx, slot->params.seed);
         slot->command = LOAD_PROMPT;

         all_slots_are_idle = false;
@@ -847,7 +830,7 @@ struct llama_server_context
         system_tokens.clear();

         if (!system_prompt.empty()) {
-            system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
+            system_tokens = ::llama_tokenize(ctx, system_prompt, true);

             llama_batch_clear(batch);

@@ -897,15 +880,6 @@ struct llama_server_context
         system_need_update = true;
     }

-    void system_prompt_process(const json &sys_props) {
-        system_prompt = sys_props.value("prompt", "");
-        name_user = sys_props.value("anti_prompt", "");
-        name_assistant = sys_props.value("assistant_name", "");
-
-
-        system_prompt_notify();
-    }
-
     static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
                                         const stop_type type, server_slot &slot)
     {
@@ -1263,13 +1237,12 @@ struct llama_server_context
         queue_results.send(res);
     }

-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void request_completion(int task_id, json data, bool embedding, int multitask_id)
     {
         task_server task;
         task.id = task_id;
         task.target_id = 0;
         task.data = std::move(data);
-        task.infill_mode = infill;
         task.embedding_mode = embedding;
         task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
@@ -1415,8 +1388,8 @@ struct llama_server_context
             json subtask_data = multiprompt_task.data;
             subtask_data["prompt"] = subtask_data["prompt"][i];

-            // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            // subtasks inherit everything else (embedding mode, etc.)
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.embedding_mode, multitask_id);
         }
     }

@@ -1434,26 +1407,8 @@ struct llama_server_context
                     break;
                 }

-                if (task.data.contains("system_prompt"))
-                {
-                    if (!all_slots_are_idle) {
-                        send_error(task, "system prompt can only be updated when all slots are idle");
-                        break;
-                    }
-                    system_prompt_process(task.data["system_prompt"]);
-
-                    // reset cache_tokens for all slots
-                    for (server_slot &slot : slots)
-                    {
-                        slot.cache_tokens.clear();
-                        slot.n_past = 0;
-                        slot.n_past_se = 0;
-                    }
-                }
-
                 slot->reset();

-                slot->infill = task.infill_mode;
                 slot->embedding = task.embedding_mode;
                 slot->task_id = task.id;
                 slot->multitask_id = task.multitask_id;
@@ -1679,8 +1634,7 @@ struct llama_server_context
                 const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();

                 // empty prompt passed -> release the slot and send empty response
-                // note: infill mode allows empty prompt
-                if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill)
+                if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt)
                 {
                     slot.release();
                     slot.print_timings();
@@ -1697,33 +1651,7 @@ struct llama_server_context
                     slot.t_start_process_prompt = ggml_time_us();
                     slot.t_start_genereration = 0;

-                    if (slot.infill)
-                    {
-                        bool suff_rm_leading_spc = true;
-                        if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1)
-                        {
-                            params.input_suffix.erase(0, 1);
-                            suff_rm_leading_spc = false;
-                        }
-                        auto prefix_tokens = tokenize(slot.params.input_prefix, false);
-                        auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
-                        const int space_token = 29871; // TODO: this should not be hardcoded
-                        if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
-                            suffix_tokens.erase(suffix_tokens.begin());
-                        }
-
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
-                        prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
-                        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-                        prefix_tokens.push_back(llama_token_middle(model));
-                        prompt_tokens = prefix_tokens;
-                    }
-                    else
-                    {
-                        prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
-                    }
+                    prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt

                     slot.n_prompt_tokens = prompt_tokens.size();

@@ -2130,8 +2058,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("\n");
 }

-static void server_params_parse(int argc, char **argv, server_params &sparams,
-                                gpt_params &params, llama_server_context& llama)
+static void server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
 {
     gpt_params default_params;
     server_params default_sparams;
@@ -2546,27 +2473,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_predict = std::stoi(argv[i]);
         }
-        else if (arg == "-spf" || arg == "--system-prompt-file")
-        {
-            if (++i >= argc)
-            {
-                invalid_param = true;
-                break;
-            }
-            std::ifstream file(argv[i]);
-            if (!file) {
-                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
-                invalid_param = true;
-                break;
-            }
-            std::string systm_content;
-            std::copy(
-                std::istreambuf_iterator<char>(file),
-                std::istreambuf_iterator<char>(),
-                std::back_inserter(systm_content)
-            );
-            llama.system_prompt_process(json::parse(systm_content));
-        }
         else if (arg == "-ctk" || arg == "--cache-type-k") {
             params.cache_type_k = argv[++i];
         }
@@ -2818,7 +2724,7 @@ int main(int argc, char **argv) {
     // struct that contains llama context and inference
     llama_server_context llama;

-    server_params_parse(argc, argv, sparams, params, llama);
+    server_params_parse(argc, argv, sparams, params);

     if (params.model_alias == "unknown")
     {
@@ -3150,7 +3056,7 @@ int main(int argc, char **argv) {
             json data = json::parse(req.body);
             const int task_id = llama.queue_tasks.get_new_id();
             llama.queue_results.add_waiting_task_id(task_id);
-            llama.request_completion(task_id, data, false, false, -1);
+            llama.request_completion(task_id, data, false, -1);
             if (!json_value(data, "stream", false)) {
                 std::string completion_text;
                 task_result result = llama.queue_results.recv(task_id);
@@ -3272,7 +3178,7 @@ int main(int argc, char **argv) {
             // create and queue the task
             const int task_id = llama.queue_tasks.get_new_id();
             llama.queue_results.add_waiting_task_id(task_id);
-            llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+            llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);

             // get the result
             task_result result = llama.queue_results.recv(task_id);
@@ -32,42 +32,43 @@ case "${GOARCH}" in
     echo "Building static library"
     build

-    #
-    # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
-    #
-    init_vars
-    CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
-    BUILD_DIR="../build/darwin/${ARCH}/cpu"
-    echo "Building LCD CPU"
-    build
-    sign ${BUILD_DIR}/bin/ollama_llama_server
-    compress
-
-    #
-    # ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
-    # Approximately 400% faster than LCD on same CPU
-    #
-    init_vars
-    CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
-    BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
-    echo "Building AVX CPU"
-    build
-    sign ${BUILD_DIR}/bin/ollama_llama_server
-    compress
-
-    #
-    # ~2013 CPU Dynamic library
-    # Approximately 10% faster than AVX on same CPU
-    #
-    init_vars
-    CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
-    BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
-    echo "Building AVX2 CPU"
-    EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
-    build
-    sign ${BUILD_DIR}/bin/ollama_llama_server
-    compress
+    if [ -z "$OLLAMA_SKIP_CPU_GENERATE" ]; then
+        #
+        # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
+        #
+        init_vars
+        CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+        BUILD_DIR="../build/darwin/${ARCH}/cpu"
+        echo "Building LCD CPU"
+        build
+        sign ${BUILD_DIR}/bin/ollama_llama_server
+        compress
+
+        #
+        # ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
+        # Approximately 400% faster than LCD on same CPU
+        #
+        init_vars
+        CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+        BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
+        echo "Building AVX CPU"
+        build
+        sign ${BUILD_DIR}/bin/ollama_llama_server
+        compress
+
+        #
+        # ~2013 CPU Dynamic library
+        # Approximately 10% faster than AVX on same CPU
+        #
+        init_vars
+        CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
+        BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
+        echo "Building AVX2 CPU"
+        EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
+        build
+        sign ${BUILD_DIR}/bin/ollama_llama_server
+        compress
+    fi
     ;;
 "arm64")

@@ -79,13 +80,15 @@ case "${GOARCH}" in
     echo "Building static library"
     build

-    init_vars
-    CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
-    BUILD_DIR="../build/darwin/${ARCH}/metal"
-    EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
-    build
-    sign ${BUILD_DIR}/bin/ollama_llama_server
-    compress
+    if [ -z "$OLLAMA_SKIP_METAL_GENERATE" ]; then
+        init_vars
+        CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
+        BUILD_DIR="../build/darwin/${ARCH}/metal"
+        EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
+        build
+        sign ${BUILD_DIR}/bin/ollama_llama_server
+        compress
+    fi
     ;;
 *)
     echo "GOARCH must be set"
@@ -215,6 +215,36 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
|
|||||||
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ -z "${ONEAPI_ROOT}" ]; then
|
||||||
|
# Try the default location in case it exists
|
||||||
|
ONEAPI_ROOT=/opt/intel/oneapi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -z "${OLLAMA_SKIP_ONEAPI_GENERATE}" -a -d "${ONEAPI_ROOT}" ]; then
|
||||||
|
echo "OneAPI libraries detected - building dynamic OneAPI library"
|
||||||
|
init_vars
|
||||||
|
source ${ONEAPI_ROOT}/setvars.sh --force # set up environment variables for oneAPI
|
||||||
|
CC=icx
|
||||||
|
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL=ON -DLLAMA_SYCL_F16=OFF"
|
||||||
|
BUILD_DIR="../build/linux/${ARCH}/oneapi"
|
||||||
|
EXTRA_LIBS="-fsycl -Wl,-rpath,${ONEAPI_ROOT}/compiler/latest/lib,-rpath,${ONEAPI_ROOT}/mkl/latest/lib,-rpath,${ONEAPI_ROOT}/tbb/latest/lib,-rpath,${ONEAPI_ROOT}/compiler/latest/opt/oclfpga/linux64/lib -lOpenCL -lmkl_core -lmkl_sycl_blas -lmkl_intel_ilp64 -lmkl_tbb_thread -ltbb"
|
||||||
|
DEBUG_FLAGS="" # icx compiles with -O0 if we pass -g, so we must remove it
|
||||||
|
build
|
||||||
|
|
||||||
|
# copy oneAPI dependencies
|
||||||
|
for dep in $(ldd "${BUILD_DIR}/bin/ollama_llama_server" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e sycl -e mkl -e tbb); do
|
||||||
|
cp "${dep}" "${BUILD_DIR}/bin/"
|
||||||
|
done
|
||||||
|
cp "${ONEAPI_ROOT}/compiler/latest/lib/libOpenCL.so" "${BUILD_DIR}/bin/"
|
||||||
|
cp "${ONEAPI_ROOT}/compiler/latest/lib/libimf.so" "${BUILD_DIR}/bin/"
|
||||||
|
cp "${ONEAPI_ROOT}/compiler/latest/lib/libintlc.so.5" "${BUILD_DIR}/bin/"
|
||||||
|
cp "${ONEAPI_ROOT}/compiler/latest/lib/libirng.so" "${BUILD_DIR}/bin/"
|
||||||
|
cp "${ONEAPI_ROOT}/compiler/latest/lib/libpi_level_zero.so" "${BUILD_DIR}/bin/"
|
||||||
|
cp "${ONEAPI_ROOT}/compiler/latest/lib/libsvml.so" "${BUILD_DIR}/bin/"
|
||||||
|
cp "${ONEAPI_ROOT}/compiler/latest/lib/libur_loader.so.0" "${BUILD_DIR}/bin/"
|
||||||
|
compress
|
||||||
|
fi
|
||||||
|
|
||||||
if [ -z "${ROCM_PATH}" ]; then
|
if [ -z "${ROCM_PATH}" ]; then
|
||||||
# Try the default location in case it exists
|
# Try the default location in case it exists
|
||||||
ROCM_PATH=/opt/rocm
|
ROCM_PATH=/opt/rocm
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ function amdGPUs {
         "gfx900"
         "gfx902"
         "gfx904"
+        "gfx90c"
         "gfx906:xnack-"
         "gfx908:xnack-"
         "gfx90a:xnack+"
@@ -25,6 +26,7 @@ function amdGPUs {
         "gfx1030"
         "gfx1031"
         "gfx1032"
+        "gfx1033"
        "gfx1034"
         "gfx1035"
         "gfx1036"
@@ -299,6 +301,49 @@ function build_cuda() {
     }
 }

+function build_oneapi() {
+    if ((-not "${env:OLLAMA_SKIP_ONEAPI_GENERATE}") -and ("${env:ONEAPI_ROOT}")) {
+        # Get oneAPI version
+        $script:ONEAPI_VERSION = icpx --version
+        $script:ONEAPI_VERSION = [regex]::Match($script:ONEAPI_VERSION, '(?<=oneAPI DPC\+\+/C\+\+ Compiler )(?<version>\d+\.\d+\.\d+)').Value
+        if ($null -ne $script:ONEAPI_VERSION) {
+            $script:ONEAPI_VARIANT = "_v" + $script:ONEAPI_VERSION
+        }
+        init_vars
+        $script:buildDir = "../build/windows/${script:ARCH}/oneapi$script:ONEAPI_VARIANT"
+        $script:distDir ="$script:DIST_BASE\oneapi$script:ONEAPI_VARIANT"
+        $script:cmakeDefs += @(
+            "-G", "MinGW Makefiles",
+            "-DLLAMA_SYCL=ON",
+            "-DCMAKE_C_COMPILER=icx",
+            "-DCMAKE_CXX_COMPILER=icx",
+            "-DCMAKE_BUILD_TYPE=Release"
+        )
+
+        Write-Host "Building oneAPI"
+        build
+        # Ninja doesn't prefix with config name
+        if ($null -ne $script:DUMPBIN) {
+            & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | Select-String ".dll"
+        }
+        sign
+        install
+
+        cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libirngmd.dll" "${script:distDir}"
+        cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libmmd.dll" "${script:distDir}"
+        cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_level_zero.dll" "${script:distDir}"
+        cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_unified_runtime.dll" "${script:distDir}"
+        cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_win_proxy_loader.dll" "${script:distDir}"
+        cp "${env:ONEAPI_ROOT}\compiler\latest\bin\svml_dispmd.dll" "${script:distDir}"
+        cp "${env:ONEAPI_ROOT}\compiler\latest\bin\sycl7.dll" "${script:distDir}"
+        cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_core.2.dll" "${script:distDir}"
+        cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_sycl_blas.4.dll" "${script:distDir}"
+        cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_tbb_thread.2.dll" "${script:distDir}"
+    } else {
+        Write-Host "Skipping oneAPI generation step"
+    }
+}
+
 function build_rocm() {
     if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
         $script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
@@ -366,6 +411,7 @@ if ($($args.count) -eq 0) {
     build_cpu_avx
     build_cpu_avx2
     build_cuda
+    build_oneapi
     build_rocm
 }
@@ -81,6 +81,11 @@ func (kv KV) ContextLength() uint64 {
 	return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
 }

+func (kv KV) ChatTemplate() string {
+	s, _ := kv["tokenizer.chat_template"].(string)
+	return s
+}
+
 type Tensors []*Tensor

 func (ts Tensors) Layers() map[string]Layer {
@@ -125,9 +130,9 @@ type Tensor struct {

 func (t Tensor) blockSize() uint64 {
 	switch t.Kind {
-	case 0, 1, 24, 25, 26, 27, 28, 31: // F32, F16, I8, I16, I32, I64, F64, BF16
+	case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
 		return 1
-	case 2, 3, 8, 9, 20: // Q4_0, Q4_1, Q8_0, Q8_1, IQ4_NL
+	case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
 		return 32
 	default: // All others
 		return 256
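The blockSize fix above adds kinds 4 through 7 (the Q5 family per the comment) to the 32-element group and corrects BF16's kind from 31 to 30. As a rough sketch of how a block size is used downstream (an assumed helper for illustration, not code from this file): a quantized tensor's byte size is computed per block, e.g. Q4_0 packs 32 elements into an 18-byte block (an fp16 scale plus 16 bytes of 4-bit values).

    // hypothetical helper illustrating the arithmetic
    func tensorBytes(elems, blockSize, bytesPerBlock uint64) uint64 {
    	blocks := (elems + blockSize - 1) / blockSize // round up to whole blocks
    	return blocks * bytesPerBlock                 // e.g. Q4_0: blockSize=32, bytesPerBlock=18
    }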
26  llm/gguf.go

@@ -592,8 +592,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 		return err
 	}

-	dims := 0
-	for cnt := 0; cnt < len(tensor.Shape); cnt++ {
+	var dims int
+	for cnt := range len(tensor.Shape) {
 		if tensor.Shape[cnt] > 0 {
 			dims++
 		}
@@ -603,8 +603,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 		return err
 	}

-	for i := 0; i < dims; i++ {
-		if err := binary.Write(ws, llm.ByteOrder, uint64(tensor.Shape[dims-1-i])); err != nil {
+	for i := range dims {
+		if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil {
 			return err
 		}
 	}
@@ -618,22 +618,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 		}
 	}

-	offset, err := ws.Seek(0, io.SeekCurrent)
-	if err != nil {
-		return err
-	}
-
 	var alignment int64 = 32
-	padding := llm.padding(offset, alignment)
-	if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
-		return err
-	}
-
 	for _, tensor := range tensors {
-		if _, err := tensor.WriteTo(ws); err != nil {
-			return err
-		}
-
 		offset, err := ws.Seek(0, io.SeekCurrent)
 		if err != nil {
 			return err
@@ -643,6 +629,10 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 		if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
 			return err
 		}
+
+		if _, err := tensor.WriteTo(ws); err != nil {
+			return err
+		}
 	}

 	return nil
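The Encode change above moves the alignment padding, and the tensor write itself, inside the per-tensor loop, so every tensor (not just the first) begins on a 32-byte boundary. Assuming llm.padding computes the usual gap to the next boundary, it would look like this sketch:

    // padding returns how many zero bytes are needed to reach the next
    // multiple of align; it is 0 when offset is already aligned
    func padding(offset, align int64) int64 {
    	return (align - offset%align) % align
    }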
Submodule llm/llama.cpp updated: 74f33adf5f...5921b8f089
@@ -5,9 +5,9 @@ import (
 	"log/slog"

 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
-	"github.com/ollama/ollama/server/envconfig"
 )

 // This algorithm looks for a complete fit to determine if we need to unload other models
@@ -103,7 +103,7 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 	}

 	var layerCount int
-	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
+	for i := range int(ggml.KV().BlockCount()) {
 		if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
 			memoryLayer := blk.size()
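This loop rewrite, and the similar ones in gguf.go, server.go, and progress.go, uses Go 1.22's range-over-int form, which is equivalent to the classic counted loop. A self-contained example:

    package main

    import "fmt"

    func main() {
    	// Go 1.22: for i := range 3 is equivalent to for i := 0; i < 3; i++
    	for i := range 3 {
    		fmt.Println(i) // prints 0, 1, 2
    	}
    }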
@@ -1,35 +1,32 @@
-From d02a06f3f45a09255ace8684a66590e06ce44605 Mon Sep 17 00:00:00 2001
-From: Michael Yang <mxyng@pm.me>
-Date: Thu, 23 May 2024 11:33:20 -0700
-Subject: [PATCH] default pretokenizer on unrecognized type
-
----
- llama.cpp | 5 +----
- 1 file changed, 1 insertion(+), 4 deletions(-)
-
 diff --git a/llama.cpp b/llama.cpp
-index 15c66077..af1aede3 100644
+index 40d2ec2c..74f3ee9c 100644
 --- a/llama.cpp
 +++ b/llama.cpp
-@@ -4504,9 +4504,6 @@ static void llm_load_vocab(
-             LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
-             LLAMA_LOG_WARN("%s:                                      \n", __func__);
-             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
--        } else if (
--            tokenizer_pre == "default") {
+@@ -4642,16 +4642,7 @@ static void llm_load_vocab(
+
+         // for now, only BPE models have pre-tokenizers
+         if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
+-            if (tokenizer_pre.empty()) {
+-                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
+-                LLAMA_LOG_WARN("%s:                                      \n", __func__);
+-                LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
+-                LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__);
+-                LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL      \n", __func__);
+-                LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
+-                LLAMA_LOG_WARN("%s:                                      \n", __func__);
 -            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+-            } else if (
++            if (
+                 tokenizer_pre == "default") {
+                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
         } else if (
-                tokenizer_pre == "llama3" ||
-                tokenizer_pre == "llama-v3" ||
-@@ -4553,7 +4550,7 @@ static void llm_load_vocab(
-                tokenizer_pre == "dbrx") {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
+@@ -4703,7 +4694,8 @@ static void llm_load_vocab(
+                 tokenizer_pre == "smaug-bpe") {
+                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
             } else {
 -            throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
++                LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
 +            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
         }
     } else {
         vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
---
-2.45.1
-
13  llm/patches/06-qwen2.diff  (new file)

@@ -0,0 +1,13 @@
+diff --git a/llama.cpp b/llama.cpp
+index 40d2ec2c..f34eb79a 100644
+--- a/llama.cpp
++++ b/llama.cpp
+@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
+         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+         cb(kq, "kq", il);
+ 
+-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
++        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
+             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
+             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -10,9 +10,9 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
+	"slices"
 	"strings"

-	"golang.org/x/exp/slices"
 	"golang.org/x/sync/errgroup"

 	"github.com/ollama/ollama/gpu"
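This import swap works because the slices package was promoted to the standard library in Go 1.21; the call sites keep working unchanged. For example:

    package main

    import (
    	"fmt"
    	"slices" // stdlib since Go 1.21, replacing golang.org/x/exp/slices
    )

    func main() {
    	libs := []string{"cpu", "cuda", "rocm"}
    	fmt.Println(slices.Contains(libs, "rocm")) // true
    }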
@@ -24,9 +24,9 @@ import (
 	"golang.org/x/sync/semaphore"

 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
-	"github.com/ollama/ollama/server/envconfig"
 )

 type LlamaServer interface {
@@ -85,7 +85,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	var systemMemory uint64
 	gpuCount := len(gpus)
 	if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
 		// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
-
 		cpuRunner = serverForCpu()
@@ -104,21 +103,22 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		var layers int
 		layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)

-		if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
+		switch {
+		case gpus[0].Library == "metal" && estimatedVRAM > systemMemory:
 			// disable partial offloading when model is greater than total system memory as this
 			// can lead to locking up the system
 			opts.NumGPU = 0
-		} else if gpus[0].Library != "metal" && layers == 0 {
+		case gpus[0].Library != "metal" && layers == 0:
 			// Don't bother loading into the GPU if no layers can fit
 			cpuRunner = serverForCpu()
 			gpuCount = 0
-		} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
+		case opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu":
 			opts.NumGPU = layers
 		}
 	}

 	// Loop through potential servers
-	finalErr := fmt.Errorf("no suitable llama servers found")
+	finalErr := errors.New("no suitable llama servers found")

 	if len(adapters) > 1 {
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
@@ -189,35 +189,38 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--memory-f32")
 	}

-	if opts.UseMLock {
-		params = append(params, "--mlock")
+	flashAttnEnabled := envconfig.FlashAttention
+
+	for _, g := range gpus {
+		// only cuda (compute capability 7+) and metal support flash attention
+		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
+			flashAttnEnabled = false
+		}
+
+		// mmap has issues with partial offloading on metal
+		if g.Library == "metal" &&
+			uint64(opts.NumGPU) > 0 &&
+			uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
+			opts.UseMMap = false
+		}
+	}
+
+	if flashAttnEnabled {
+		params = append(params, "--flash-attn")
 	}

 	if !opts.UseMMap {
 		params = append(params, "--no-mmap")
 	}

+	if opts.UseMLock {
+		params = append(params, "--mlock")
+	}
+
 	if opts.UseNUMA {
 		params = append(params, "--numa")
 	}

-	flashAttnEnabled := envconfig.FlashAttention
-
-	// partial offloading does not support flash attention
-	if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
-		flashAttnEnabled = false
-	}
-
-	// only cuda (compute capability 7+) and metal support flash attention
-	for _, g := range gpus {
-		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
-			flashAttnEnabled = false
-		}
-	}
-	if flashAttnEnabled {
-		params = append(params, "--flash-attn")
-	}
-
 	numParallel := envconfig.NumParallel

 	// TODO (jmorganca): multimodal models don't support parallel yet
@@ -229,7 +232,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr

 	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))

-	for i := 0; i < len(servers); i++ {
+	for i := range len(servers) {
 		dir := availableServers[servers[i]]
 		if dir == "" {
 			// Shouldn't happen
@@ -243,7 +246,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 			gpuCount = 0
 		}

-		// Find an availableServers  port, retry on each iterration in case the failure was a port conflict race
+		// Find an availableServers  port, retry on each iteration in case the failure was a port conflict race
 		port := 0
 		if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
 			var l *net.TCPListener
@@ -281,7 +284,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr

 		server := filepath.Join(dir, "ollama_llama_server")
 		if runtime.GOOS == "windows" {
-			server = server + ".exe"
+			server += ".exe"
 		}

 		// Detect tmp cleaners wiping out the file
@@ -312,7 +315,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		s.cmd.Stdout = os.Stdout
 		s.cmd.Stderr = s.status

-		visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
+		visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv()
 		pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))

 		// Update or add the path and visible devices variable with our adjusted version
@@ -456,7 +459,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	resp, err := http.DefaultClient.Do(req)
 	if err != nil {
 		if errors.Is(err, context.DeadlineExceeded) {
-			return ServerStatusNotResponding, fmt.Errorf("server not responding")
+			return ServerStatusNotResponding, errors.New("server not responding")
 		}
 		return ServerStatusError, fmt.Errorf("health resp: %w", err)
 	}
@@ -519,16 +522,18 @@ func (s *llmServer) Ping(ctx context.Context) error {

 func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
-	stallDuration := 60 * time.Second
-	stallTimer := time.Now().Add(stallDuration) // give up if we stall for
+	stallDuration := 5 * time.Minute            // If no progress happens
+	finalLoadDuration := 5 * time.Minute        // After we hit 100%, give the runner more time to come online
+	stallTimer := time.Now().Add(stallDuration) // give up if we stall

 	slog.Info("waiting for llama runner to start responding")
 	var lastStatus ServerStatus = -1
+	fullyLoaded := false

 	for {
 		select {
 		case <-ctx.Done():
-			slog.Info("context expired before server started")
+			slog.Warn("client connection closed before server finished loading, aborting load")
 			return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
 		case err := <-s.done:
 			msg := ""
@@ -572,6 +577,10 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			if priorProgress != s.loadProgress {
 				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
 				stallTimer = time.Now().Add(stallDuration)
+			} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
+				slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
+				stallTimer = time.Now().Add(finalLoadDuration)
+				fullyLoaded = true
 			}
 			time.Sleep(time.Millisecond * 250)
 			continue
@@ -597,7 +606,7 @@ array ::=

 string ::=
   "\"" (
-    [^"\\] |
+    [^"\\\x7F\x00-\x1F] |
     "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
   )* "\"" ws

@@ -756,7 +765,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu

 	var c completion
 	if err := json.Unmarshal(evt, &c); err != nil {
-		return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+		return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 	}

 	switch {
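A sketch of the flash-attention gate that the --flash-attn hunks above reorder (my paraphrase of the loop using the same GpuInfo fields; not the file's exact code):

    // flash attention stays enabled only if every GPU is metal, or cuda
    // with a driver major version of at least 7
    func flashAttnSupported(gpus []gpu.GpuInfo) bool {
    	for _, g := range gpus {
    		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
    			return false
    		}
    	}
    	return true
    }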
@@ -245,7 +245,6 @@ func (w *writer) writeResponse(data []byte) (int, error) {
 		d, err := json.Marshal(toChunk(w.id, chatResponse))
 		if err != nil {
 			return 0, err
-
 		}

 		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
@@ -3,12 +3,15 @@ package parser

 import (
 	"bufio"
 	"bytes"
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
+	"log/slog"
 	"strconv"
 	"strings"
-	"unicode"
+	"unicode/utf16"
+	"unicode/utf8"
 )

 type File struct {
@@ -69,33 +72,31 @@ func ParseFile(r io.Reader) (*File, error) {
 	var b bytes.Buffer
 	var role string

-	var lineCount int
-	var linePos int
-
-	var utf16 bool
-
 	var f File

 	br := bufio.NewReader(r)
-	for {
-		r, _, err := br.ReadRune()
-		if errors.Is(err, io.EOF) {
-			break
-		} else if err != nil {
+
+	var sc scannerDecoder = utf8ScannerDecoder{}
+	if bom, err := br.Peek(2); err != nil {
+		slog.Warn("error reading byte-order mark", "error", err)
+	} else if bytes.Equal(bom, []byte{0xFE, 0xFF}) {
+		sc = utf16ScannerDecoder{binary.LittleEndian}
+		//nolint:errcheck
+		br.Discard(2)
+	} else if bytes.Equal(bom, []byte{0xFF, 0xFE}) {
+		sc = utf16ScannerDecoder{binary.BigEndian}
+		//nolint:errcheck
+		br.Discard(2)
+	}
+
+	scanner := bufio.NewScanner(br)
+	scanner.Split(sc.ScanBytes)
+	for scanner.Scan() {
+		r, err := sc.DecodeRune(scanner.Bytes())
+		if err != nil {
 			return nil, err
 		}

-		// the utf16 byte order mark will be read as "unreadable" by ReadRune()
-		if isUnreadable(r) && lineCount == 0 && linePos == 0 {
-			utf16 = true
-			continue
-		}
-
-		// skip the second byte if we're reading utf16
-		if utf16 && r == 0 {
-			continue
-		}
-
 		next, r, err := parseRuneForState(r, curr)
 		if errors.Is(err, io.ErrUnexpectedEOF) {
 			return nil, fmt.Errorf("%w: %s", err, b.String())
@@ -103,13 +104,6 @@ func ParseFile(r io.Reader) (*File, error) {
 			return nil, err
 		}

-		if isNewline(r) {
-			lineCount++
-			linePos = 0
-		} else {
-			linePos++
-		}
-
 		// process the state transition, some transitions need to be intercepted and redirected
 		if next != curr {
 			switch curr {
@@ -309,10 +303,6 @@ func isNewline(r rune) bool {
 	return r == '\r' || r == '\n'
 }

-func isUnreadable(r rune) bool {
-	return r == unicode.ReplacementChar
-}
-
 func isValidMessageRole(role string) bool {
 	return role == "system" || role == "user" || role == "assistant"
 }
@@ -325,3 +315,39 @@ func isValidCommand(cmd string) bool {
 		return false
 	}
 }
+
+type scannerDecoder interface {
+	ScanBytes(data []byte, atEOF bool) (advance int, token []byte, err error)
+	DecodeRune([]byte) (rune, error)
+}
+
+type utf8ScannerDecoder struct{}
+
+func (utf8ScannerDecoder) ScanBytes(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	return scanBytesN(data, 1, atEOF)
+}
+
+func (utf8ScannerDecoder) DecodeRune(data []byte) (rune, error) {
+	r, _ := utf8.DecodeRune(data)
+	return r, nil
+}
+
+type utf16ScannerDecoder struct {
+	binary.ByteOrder
+}
+
+func (utf16ScannerDecoder) ScanBytes(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	return scanBytesN(data, 2, atEOF)
+}
+
+func (e utf16ScannerDecoder) DecodeRune(data []byte) (rune, error) {
+	return utf16.Decode([]uint16{e.ByteOrder.Uint16(data)})[0], nil
+}
+
+func scanBytesN(data []byte, n int, atEOF bool) (int, []byte, error) {
+	if atEOF && len(data) == 0 {
+		return 0, nil, nil
+	}
+
+	return n, data[:n], nil
+}
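The parser change swaps rune-by-rune reading for a bufio.Scanner driven by a pluggable split function: 1-byte tokens for UTF-8 input, 2-byte tokens once a UTF-16 byte-order mark is seen. A minimal standalone sketch of that mechanism (a hypothetical example, not repository code):

    package main

    import (
    	"bufio"
    	"fmt"
    	"strings"
    )

    func main() {
    	sc := bufio.NewScanner(strings.NewReader("hi"))
    	// emit one byte per token, mirroring utf8ScannerDecoder.ScanBytes
    	sc.Split(func(data []byte, atEOF bool) (int, []byte, error) {
    		if atEOF && len(data) == 0 {
    			return 0, nil, nil // no more input
    		}
    		return 1, data[:1], nil
    	})
    	for sc.Scan() {
    		fmt.Printf("%q\n", sc.Bytes()) // "h", then "i"
    	}
    }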
@@ -10,6 +10,7 @@ import (
 	"unicode/utf16"

 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )

 func TestParseFileFile(t *testing.T) {
@@ -25,7 +26,7 @@ TEMPLATE template1
 	reader := strings.NewReader(input)

 	modelfile, err := ParseFile(reader)
-	assert.NoError(t, err)
+	require.NoError(t, err)

 	expectedCommands := []Command{
 		{Name: "model", Args: "model1"},
@@ -88,7 +89,7 @@ func TestParseFileFrom(t *testing.T) {
 	for _, c := range cases {
 		t.Run("", func(t *testing.T) {
 			modelfile, err := ParseFile(strings.NewReader(c.input))
-			assert.ErrorIs(t, err, c.err)
+			require.ErrorIs(t, err, c.err)
 			if modelfile != nil {
 				assert.Equal(t, c.expected, modelfile.Commands)
 			}
@@ -105,7 +106,7 @@ PARAMETER param1
 	reader := strings.NewReader(input)

 	_, err := ParseFile(reader)
-	assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
+	require.ErrorIs(t, err, io.ErrUnexpectedEOF)
 }

 func TestParseFileBadCommand(t *testing.T) {
@@ -114,8 +115,7 @@ FROM foo
 BADCOMMAND param1 value1
 `
 	_, err := ParseFile(strings.NewReader(input))
-	assert.ErrorIs(t, err, errInvalidCommand)
-
+	require.ErrorIs(t, err, errInvalidCommand)
 }

 func TestParseFileMessages(t *testing.T) {
@@ -201,7 +201,7 @@ MESSAGE system`,
 	for _, c := range cases {
 		t.Run("", func(t *testing.T) {
 			modelfile, err := ParseFile(strings.NewReader(c.input))
-			assert.ErrorIs(t, err, c.err)
+			require.ErrorIs(t, err, c.err)
 			if modelfile != nil {
 				assert.Equal(t, c.expected, modelfile.Commands)
 			}
@@ -355,7 +355,7 @@ TEMPLATE """
 	for _, c := range cases {
 		t.Run("", func(t *testing.T) {
 			modelfile, err := ParseFile(strings.NewReader(c.multiline))
-			assert.ErrorIs(t, err, c.err)
+			require.ErrorIs(t, err, c.err)
 			if modelfile != nil {
 				assert.Equal(t, c.expected, modelfile.Commands)
 			}
@@ -413,7 +413,7 @@ func TestParseFileParameters(t *testing.T) {
 		fmt.Fprintln(&b, "FROM foo")
 		fmt.Fprintln(&b, "PARAMETER", k)
 		modelfile, err := ParseFile(&b)
-		assert.NoError(t, err)
+		require.NoError(t, err)

 		assert.Equal(t, []Command{
 			{Name: "model", Args: "foo"},
@@ -442,7 +442,7 @@ FROM foo
 	for _, c := range cases {
 		t.Run("", func(t *testing.T) {
 			modelfile, err := ParseFile(strings.NewReader(c.input))
-			assert.NoError(t, err)
+			require.NoError(t, err)
 			assert.Equal(t, c.expected, modelfile.Commands)
 		})
 	}
@@ -501,15 +501,14 @@ SYSTEM ""
 	for _, c := range cases {
 		t.Run("", func(t *testing.T) {
 			modelfile, err := ParseFile(strings.NewReader(c))
-			assert.NoError(t, err)
+			require.NoError(t, err)

 			modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
-			assert.NoError(t, err)
+			require.NoError(t, err)

 			assert.Equal(t, modelfile, modelfile2)
 		})
 	}
-
 }

 func TestParseFileUTF16ParseFile(t *testing.T) {
@@ -522,10 +521,10 @@ SYSTEM You are a utf16 file.
 	utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
 	buf := new(bytes.Buffer)
 	err := binary.Write(buf, binary.LittleEndian, utf16File)
-	assert.NoError(t, err)
+	require.NoError(t, err)

 	actual, err := ParseFile(buf)
-	assert.NoError(t, err)
+	require.NoError(t, err)

 	expected := []Command{
 		{Name: "model", Args: "bob"},
@@ -539,9 +538,9 @@ SYSTEM You are a utf16 file.
 	// simulate a utf16 be file
 	buf = new(bytes.Buffer)
 	err = binary.Write(buf, binary.BigEndian, utf16File)
-	assert.NoError(t, err)
+	require.NoError(t, err)

 	actual, err = ParseFile(buf)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	assert.Equal(t, expected, actual.Commands)
 }
@@ -59,7 +59,7 @@ func (p *Progress) StopAndClear() bool {
 	stopped := p.stop()
 	if stopped {
 		// clear all progress lines
-		for i := 0; i < p.pos; i++ {
+		for i := range p.pos {
 			if i > 0 {
 				fmt.Fprint(p.w, "\033[A")
 			}
@@ -85,7 +85,7 @@ func (p *Progress) render() {
 	defer fmt.Fprint(p.w, "\033[?25h")

 	// clear already rendered progress lines
-	for i := 0; i < p.pos; i++ {
+	for i := range p.pos {
 		if i > 0 {
 			fmt.Fprint(p.w, "\033[A")
 		}
@@ -5,16 +5,20 @@ import (
 	"os"

 	"github.com/emirpasic/gods/lists/arraylist"
+	"github.com/mattn/go-runewidth"
 	"golang.org/x/term"
 )

 type Buffer struct {
-	Pos       int
-	Buf       *arraylist.List
-	Prompt    *Prompt
-	LineWidth int
-	Width     int
-	Height    int
+	DisplayPos int
+	Pos        int
+	Buf        *arraylist.List
+	//LineHasSpace is an arraylist of bools to keep track of whether a line has a space at the end
+	LineHasSpace *arraylist.List
+	Prompt       *Prompt
+	LineWidth    int
+	Width        int
+	Height       int
 }

 func NewBuffer(prompt *Prompt) (*Buffer, error) {
@@ -27,25 +31,56 @@ func NewBuffer(prompt *Prompt) (*Buffer, error) {
 	lwidth := width - len(prompt.prompt())

 	b := &Buffer{
-		Pos:       0,
-		Buf:       arraylist.New(),
-		Prompt:    prompt,
-		Width:     width,
-		Height:    height,
-		LineWidth: lwidth,
+		DisplayPos:   0,
+		Pos:          0,
+		Buf:          arraylist.New(),
+		LineHasSpace: arraylist.New(),
+		Prompt:       prompt,
+		Width:        width,
+		Height:       height,
+		LineWidth:    lwidth,
 	}

 	return b, nil
 }

+func (b *Buffer) GetLineSpacing(line int) bool {
+	hasSpace, _ := b.LineHasSpace.Get(line)
+
+	if hasSpace == nil {
+		return false
+	}
+
+	return hasSpace.(bool)
+}
+
 func (b *Buffer) MoveLeft() {
 	if b.Pos > 0 {
-		if b.Pos%b.LineWidth == 0 {
-			fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
-		} else {
-			fmt.Print(CursorLeft)
+		//asserts that we retrieve a rune
+		if e, ok := b.Buf.Get(b.Pos - 1); ok {
+			if r, ok := e.(rune); ok {
+				rLength := runewidth.RuneWidth(r)
+
+				if b.DisplayPos%b.LineWidth == 0 {
+					fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
+					if rLength == 2 {
+						fmt.Print(CursorLeft)
+					}
+
+					line := b.DisplayPos/b.LineWidth - 1
+					hasSpace := b.GetLineSpacing(line)
+					if hasSpace {
+						b.DisplayPos -= 1
+						fmt.Print(CursorLeft)
+					}
+				} else {
+					fmt.Print(cursorLeftN(rLength))
+				}
+
+				b.Pos -= 1
+				b.DisplayPos -= rLength
+			}
 		}
-		b.Pos -= 1
 	}
 }

@@ -71,18 +106,32 @@ func (b *Buffer) MoveLeftWord() {
 }

 func (b *Buffer) MoveRight() {
-	if b.Pos < b.Size() {
-		b.Pos += 1
-		if b.Pos%b.LineWidth == 0 {
-			fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
-		} else {
-			fmt.Print(CursorRight)
+	if b.Pos < b.Buf.Size() {
+		if e, ok := b.Buf.Get(b.Pos); ok {
+			if r, ok := e.(rune); ok {
+				rLength := runewidth.RuneWidth(r)
+				b.Pos += 1
+				hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)
+				b.DisplayPos += rLength
+
+				if b.DisplayPos%b.LineWidth == 0 {
+					fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
+				} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
+					fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength))
+					b.DisplayPos += 1
+				} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
+					fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
+					b.DisplayPos += 1
+				} else {
+					fmt.Print(cursorRightN(rLength))
+				}
+			}
 		}
 	}
 }

 func (b *Buffer) MoveRightWord() {
-	if b.Pos < b.Size() {
+	if b.Pos < b.Buf.Size() {
 		for {
 			b.MoveRight()
 			v, _ := b.Buf.Get(b.Pos)
@@ -90,7 +139,7 @@ func (b *Buffer) MoveRightWord() {
 				break
 			}

-			if b.Pos == b.Size() {
+			if b.Pos == b.Buf.Size() {
 				break
 			}
 		}
@@ -99,89 +148,200 @@ func (b *Buffer) MoveRightWord() {
|
|||||||
|
|
||||||
func (b *Buffer) MoveToStart() {
|
func (b *Buffer) MoveToStart() {
|
||||||
if b.Pos > 0 {
|
if b.Pos > 0 {
|
||||||
currLine := b.Pos / b.LineWidth
|
currLine := b.DisplayPos / b.LineWidth
|
||||||
if currLine > 0 {
|
if currLine > 0 {
|
||||||
for cnt := 0; cnt < currLine; cnt++ {
|
for range currLine {
|
||||||
fmt.Print(CursorUp)
|
fmt.Print(CursorUp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
||||||
b.Pos = 0
|
b.Pos = 0
|
||||||
|
b.DisplayPos = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) MoveToEnd() {
|
func (b *Buffer) MoveToEnd() {
|
||||||
if b.Pos < b.Size() {
|
if b.Pos < b.Buf.Size() {
|
||||||
currLine := b.Pos / b.LineWidth
|
currLine := b.DisplayPos / b.LineWidth
|
||||||
totalLines := b.Size() / b.LineWidth
|
totalLines := b.DisplaySize() / b.LineWidth
|
||||||
if currLine < totalLines {
|
if currLine < totalLines {
|
||||||
for cnt := 0; cnt < totalLines-currLine; cnt++ {
|
for range totalLines - currLine {
|
||||||
fmt.Print(CursorDown)
|
fmt.Print(CursorDown)
|
||||||
}
|
}
|
||||||
remainder := b.Size() % b.LineWidth
|
remainder := b.DisplaySize() % b.LineWidth
|
||||||
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())+remainder))
|
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())+remainder))
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(cursorRightN(b.Size() - b.Pos))
|
fmt.Print(cursorRightN(b.DisplaySize() - b.DisplayPos))
|
||||||
}
|
}
|
||||||
|
|
||||||
b.Pos = b.Size()
|
b.Pos = b.Buf.Size()
|
||||||
|
b.DisplayPos = b.DisplaySize()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) Size() int {
|
func (b *Buffer) DisplaySize() int {
|
||||||
return b.Buf.Size()
|
sum := 0
|
||||||
|
for i := range b.Buf.Size() {
|
||||||
|
if e, ok := b.Buf.Get(i); ok {
|
||||||
|
if r, ok := e.(rune); ok {
|
||||||
|
sum += runewidth.RuneWidth(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sum
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) Add(r rune) {
|
func (b *Buffer) Add(r rune) {
|
||||||
if b.Pos == b.Buf.Size() {
|
if b.Pos == b.Buf.Size() {
|
||||||
fmt.Printf("%c", r)
|
b.AddChar(r, false)
|
||||||
b.Buf.Add(r)
|
} else {
|
||||||
b.Pos += 1
|
b.AddChar(r, true)
|
||||||
if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) AddChar(r rune, insert bool) {
|
||||||
|
rLength := runewidth.RuneWidth(r)
|
||||||
|
b.DisplayPos += rLength
|
||||||
|
|
||||||
|
if b.Pos > 0 {
|
||||||
|
if b.DisplayPos%b.LineWidth == 0 {
|
||||||
|
fmt.Printf("%c", r)
|
||||||
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
||||||
|
|
||||||
|
if insert {
|
||||||
|
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth-1, false)
|
||||||
|
} else {
|
||||||
|
b.LineHasSpace.Add(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// this case occurs when a double-width rune crosses the line boundary
|
||||||
|
} else if b.DisplayPos%b.LineWidth < (b.DisplayPos-rLength)%b.LineWidth {
|
||||||
|
if insert {
|
||||||
|
fmt.Print(ClearToEOL)
|
||||||
|
}
|
||||||
|
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
||||||
|
b.DisplayPos += 1
|
||||||
|
fmt.Printf("%c", r)
|
||||||
|
|
||||||
|
if insert {
|
||||||
|
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth-1, true)
|
||||||
|
} else {
|
||||||
|
b.LineHasSpace.Add(true)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%c", r)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("%c", r)
|
fmt.Printf("%c", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if insert {
|
||||||
b.Buf.Insert(b.Pos, r)
|
b.Buf.Insert(b.Pos, r)
|
||||||
b.Pos += 1
|
} else {
|
||||||
if b.Pos > 0 && b.Pos%b.LineWidth == 0 {
|
b.Buf.Add(r)
|
||||||
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
}
|
||||||
}
|
|
||||||
|
b.Pos += 1
|
||||||
|
|
||||||
|
if insert {
|
||||||
b.drawRemaining()
|
b.drawRemaining()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) countRemainingLineWidth(place int) int {
|
||||||
|
var sum int
|
||||||
|
counter := -1
|
||||||
|
var prevLen int
|
||||||
|
|
||||||
|
for place <= b.LineWidth {
|
||||||
|
counter += 1
|
||||||
|
sum += prevLen
|
||||||
|
if e, ok := b.Buf.Get(b.Pos + counter); ok {
|
||||||
|
if r, ok := e.(rune); ok {
|
||||||
|
place += runewidth.RuneWidth(r)
|
||||||
|
prevLen = len(string(r))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Buffer) drawRemaining() {
|
func (b *Buffer) drawRemaining() {
|
||||||
var place int
|
var place int
|
||||||
remainingText := b.StringN(b.Pos)
|
remainingText := b.StringN(b.Pos)
|
||||||
if b.Pos > 0 {
|
if b.Pos > 0 {
|
||||||
place = b.Pos % b.LineWidth
|
place = b.DisplayPos % b.LineWidth
|
||||||
}
|
}
|
||||||
fmt.Print(CursorHide)
|
fmt.Print(CursorHide)
|
||||||
|
|
||||||
// render the rest of the current line
|
// render the rest of the current line
|
||||||
currLine := remainingText[:min(b.LineWidth-place, len(remainingText))]
|
currLineLength := b.countRemainingLineWidth(place)
|
||||||
|
|
||||||
|
currLine := remainingText[:min(currLineLength, len(remainingText))]
|
||||||
|
currLineSpace := runewidth.StringWidth(currLine)
|
||||||
|
remLength := runewidth.StringWidth(remainingText)
|
||||||
|
|
||||||
if len(currLine) > 0 {
|
if len(currLine) > 0 {
|
||||||
fmt.Printf(ClearToEOL + currLine)
|
fmt.Printf(ClearToEOL + currLine)
|
||||||
fmt.Print(cursorLeftN(len(currLine)))
|
fmt.Print(cursorLeftN(currLineSpace))
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(ClearToEOL)
|
fmt.Print(ClearToEOL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if currLineSpace != b.LineWidth-place && currLineSpace != remLength {
|
||||||
|
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth, true)
|
||||||
|
} else if currLineSpace != b.LineWidth-place {
|
||||||
|
b.LineHasSpace.Remove(b.DisplayPos / b.LineWidth)
|
||||||
|
} else {
|
||||||
|
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText {
|
||||||
|
fmt.Print(cursorRightN(currLineSpace))
|
||||||
|
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
||||||
|
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width-currLineSpace))
|
||||||
|
}
|
||||||
|
|
||||||
// render the other lines
|
// render the other lines
|
||||||
if len(remainingText) > len(currLine) {
|
if remLength > currLineSpace {
|
||||||
remaining := []rune(remainingText[len(currLine):])
|
remaining := (remainingText[len(currLine):])
|
||||||
var totalLines int
|
var totalLines int
|
||||||
for i, c := range remaining {
|
var displayLength int
|
||||||
if i%b.LineWidth == 0 {
|
var lineLength int = currLineSpace
|
||||||
|
|
||||||
|
for _, c := range remaining {
|
||||||
|
if displayLength == 0 || (displayLength+runewidth.RuneWidth(c))%b.LineWidth < displayLength%b.LineWidth {
|
||||||
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
||||||
totalLines += 1
|
totalLines += 1
|
||||||
|
|
||||||
|
if displayLength != 0 {
|
||||||
|
if lineLength == b.LineWidth {
|
||||||
|
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth+totalLines-1, false)
|
||||||
|
} else {
|
||||||
|
b.LineHasSpace.Set(b.DisplayPos/b.LineWidth+totalLines-1, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lineLength = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
displayLength += runewidth.RuneWidth(c)
|
||||||
|
lineLength += runewidth.RuneWidth(c)
|
||||||
fmt.Printf("%c", c)
|
fmt.Printf("%c", c)
|
||||||
}
|
}
|
||||||
fmt.Print(ClearToEOL)
|
fmt.Print(ClearToEOL)
|
||||||
fmt.Print(cursorUpN(totalLines))
|
fmt.Print(cursorUpN(totalLines))
|
||||||
fmt.Printf(CursorBOL + cursorRightN(b.Width-len(currLine)))
|
fmt.Printf(CursorBOL + cursorRightN(b.Width-currLineSpace))
|
||||||
|
|
||||||
|
hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)
|
||||||
|
|
||||||
|
if hasSpace && b.DisplayPos%b.LineWidth != b.LineWidth-1 {
|
||||||
|
fmt.Print(CursorLeft)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Print(CursorShow)
|
fmt.Print(CursorShow)
|
||||||
@@ -189,46 +349,81 @@ func (b *Buffer) drawRemaining() {
|
|||||||
|
|
||||||
func (b *Buffer) Remove() {
|
func (b *Buffer) Remove() {
|
||||||
if b.Buf.Size() > 0 && b.Pos > 0 {
|
if b.Buf.Size() > 0 && b.Pos > 0 {
|
||||||
if b.Pos%b.LineWidth == 0 {
|
if e, ok := b.Buf.Get(b.Pos - 1); ok {
|
||||||
// if the user backspaces over the word boundary, do this magic to clear the line
|
if r, ok := e.(rune); ok {
|
||||||
// and move to the end of the previous line
|
rLength := runewidth.RuneWidth(r)
|
||||||
fmt.Printf(CursorBOL + ClearToEOL)
|
hasSpace := b.GetLineSpacing(b.DisplayPos/b.LineWidth - 1)
|
||||||
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width) + " " + CursorLeft)
|
|
||||||
} else {
|
|
||||||
fmt.Printf(CursorLeft + " " + CursorLeft)
|
|
||||||
}
|
|
||||||
|
|
||||||
var eraseExtraLine bool
|
if b.DisplayPos%b.LineWidth == 0 {
|
||||||
if (b.Size()-1)%b.LineWidth == 0 {
|
// if the user backspaces over the word boundary, do this magic to clear the line
|
||||||
eraseExtraLine = true
|
// and move to the end of the previous line
|
||||||
}
|
fmt.Printf(CursorBOL + ClearToEOL)
|
||||||
|
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
|
||||||
|
|
||||||
b.Pos -= 1
|
if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth {
|
||||||
b.Buf.Remove(b.Pos)
|
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
|
||||||
|
}
|
||||||
|
|
||||||
if b.Pos < b.Size() {
|
if hasSpace {
|
||||||
b.drawRemaining()
|
b.DisplayPos -= 1
|
||||||
// this erases a line which is left over when backspacing in the middle of a line and there
|
fmt.Print(CursorLeft)
|
||||||
// are trailing characters which go over the line width boundary
|
}
|
||||||
if eraseExtraLine {
|
|
||||||
remainingLines := (b.Size() - b.Pos) / b.LineWidth
|
if rLength == 2 {
|
||||||
fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
|
fmt.Print(CursorLeft + " " + cursorLeftN(2))
|
||||||
place := b.Pos % b.LineWidth
|
} else {
|
||||||
fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.prompt())))
|
fmt.Print(" " + CursorLeft)
|
||||||
|
}
|
||||||
|
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
|
||||||
|
fmt.Printf(CursorBOL + ClearToEOL)
|
||||||
|
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
|
||||||
|
|
||||||
|
if b.Pos == b.Buf.Size() {
|
||||||
|
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
|
||||||
|
}
|
||||||
|
b.DisplayPos -= 1
|
||||||
|
} else {
|
||||||
|
fmt.Print(cursorLeftN(rLength))
|
||||||
|
for range rLength {
|
||||||
|
fmt.Print(" ")
|
||||||
|
}
|
||||||
|
fmt.Print(cursorLeftN(rLength))
|
||||||
|
}
|
||||||
|
|
||||||
|
var eraseExtraLine bool
|
||||||
|
if (b.DisplaySize()-1)%b.LineWidth == 0 || (rLength == 2 && ((b.DisplaySize()-2)%b.LineWidth == 0)) || b.DisplaySize()%b.LineWidth == 0 {
|
||||||
|
eraseExtraLine = true
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Pos -= 1
|
||||||
|
b.DisplayPos -= rLength
|
||||||
|
b.Buf.Remove(b.Pos)
|
||||||
|
|
||||||
|
if b.Pos < b.Buf.Size() {
|
||||||
|
b.drawRemaining()
|
||||||
|
// this erases a line which is left over when backspacing in the middle of a line and there
|
||||||
|
// are trailing characters which go over the line width boundary
|
||||||
|
if eraseExtraLine {
|
||||||
|
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
|
||||||
|
fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
|
||||||
|
place := b.DisplayPos % b.LineWidth
|
||||||
|
fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.prompt())))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) Delete() {
|
func (b *Buffer) Delete() {
|
||||||
if b.Size() > 0 && b.Pos < b.Size() {
|
if b.Buf.Size() > 0 && b.Pos < b.Buf.Size() {
|
||||||
b.Buf.Remove(b.Pos)
|
b.Buf.Remove(b.Pos)
|
||||||
b.drawRemaining()
|
b.drawRemaining()
|
||||||
if b.Size()%b.LineWidth == 0 {
|
if b.DisplaySize()%b.LineWidth == 0 {
|
||||||
if b.Pos != b.Size() {
|
if b.DisplayPos != b.DisplaySize() {
|
||||||
remainingLines := (b.Size() - b.Pos) / b.LineWidth
|
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
|
||||||
fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
|
fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
|
||||||
place := b.Pos % b.LineWidth
|
place := b.DisplayPos % b.LineWidth
|
||||||
fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.prompt())))
|
fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.prompt())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -244,9 +439,9 @@ func (b *Buffer) DeleteBefore() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) DeleteRemaining() {
|
func (b *Buffer) DeleteRemaining() {
|
||||||
if b.Size() > 0 && b.Pos < b.Size() {
|
if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() {
|
||||||
charsToDel := b.Size() - b.Pos
|
charsToDel := b.Buf.Size() - b.Pos
|
||||||
for cnt := 0; cnt < charsToDel; cnt++ {
|
for range charsToDel {
|
||||||
b.Delete()
|
b.Delete()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -281,14 +476,16 @@ func (b *Buffer) ClearScreen() {
|
|||||||
ph := b.Prompt.placeholder()
|
ph := b.Prompt.placeholder()
|
||||||
fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
|
fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
|
||||||
} else {
|
} else {
|
||||||
currPos := b.Pos
|
currPos := b.DisplayPos
|
||||||
|
currIndex := b.Pos
|
||||||
b.Pos = 0
|
b.Pos = 0
|
||||||
|
b.DisplayPos = 0
|
||||||
b.drawRemaining()
|
b.drawRemaining()
|
||||||
fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.prompt())))
|
fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.prompt())))
|
||||||
if currPos > 0 {
|
if currPos > 0 {
|
||||||
targetLine := currPos / b.LineWidth
|
targetLine := currPos / b.LineWidth
|
||||||
if targetLine > 0 {
|
if targetLine > 0 {
|
||||||
for cnt := 0; cnt < targetLine; cnt++ {
|
for range targetLine {
|
||||||
fmt.Print(CursorDown)
|
fmt.Print(CursorDown)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,7 +497,8 @@ func (b *Buffer) ClearScreen() {
|
|||||||
fmt.Printf(CursorBOL + b.Prompt.AltPrompt)
|
fmt.Printf(CursorBOL + b.Prompt.AltPrompt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
b.Pos = currPos
|
b.Pos = currIndex
|
||||||
|
b.DisplayPos = currPos
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -309,9 +507,20 @@ func (b *Buffer) IsEmpty() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) Replace(r []rune) {
|
func (b *Buffer) Replace(r []rune) {
|
||||||
|
b.DisplayPos = 0
|
||||||
b.Pos = 0
|
b.Pos = 0
|
||||||
|
lineNums := b.DisplaySize() / b.LineWidth
|
||||||
|
|
||||||
b.Buf.Clear()
|
b.Buf.Clear()
|
||||||
fmt.Printf(ClearLine + CursorBOL + b.Prompt.prompt())
|
|
||||||
|
fmt.Printf(CursorBOL + ClearToEOL)
|
||||||
|
|
||||||
|
for range lineNums {
|
||||||
|
fmt.Print(CursorUp + CursorBOL + ClearToEOL)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf(CursorBOL + b.Prompt.prompt())
|
||||||
|
|
||||||
for _, c := range r {
|
for _, c := range r {
|
||||||
b.Add(c)
|
b.Add(c)
|
||||||
}
|
}
|
||||||
@@ -328,7 +537,7 @@ func (b *Buffer) StringN(n int) string {
|
|||||||
func (b *Buffer) StringNM(n, m int) string {
|
func (b *Buffer) StringNM(n, m int) string {
|
||||||
var s string
|
var s string
|
||||||
if m == 0 {
|
if m == 0 {
|
||||||
m = b.Size()
|
m = b.Buf.Size()
|
||||||
}
|
}
|
||||||
for cnt := n; cnt < m; cnt++ {
|
for cnt := n; cnt < m; cnt++ {
|
||||||
c, _ := b.Buf.Get(cnt)
|
c, _ := b.Buf.Get(cnt)
|
||||||
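The buffer rewrite above hinges on the difference between a rune index (Pos) and a terminal-cell position (DisplayPos): double-width CJK runes occupy two cells, so byte length, rune count, and display width all disagree. A minimal sketch of the go-runewidth calls it relies on:

package main

import (
	"fmt"

	"github.com/mattn/go-runewidth"
)

func main() {
	s := "日本語abc"
	// len counts bytes, []rune counts code points, and StringWidth counts
	// terminal cells; each CJK rune takes two cells, which is why the
	// buffer tracks DisplayPos separately from Pos.
	fmt.Println(len(s))                   // 12 bytes
	fmt.Println(len([]rune(s)))           // 6 runes
	fmt.Println(runewidth.StringWidth(s)) // 9 cells (3*2 + 3*1)
}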
@@ -91,7 +91,7 @@ func (h *History) Add(l []rune) {
 func (h *History) Compact() {
 	s := h.Buf.Size()
 	if s > h.Limit {
-		for cnt := 0; cnt < s-h.Limit; cnt++ {
+		for range s - h.Limit {
 			h.Buf.Remove(0)
 		}
 	}
@@ -139,7 +139,7 @@ func (h *History) Save() error {
 	defer f.Close()

 	buf := bufio.NewWriter(f)
-	for cnt := 0; cnt < h.Size(); cnt++ {
+	for cnt := range h.Size() {
 		v, _ := h.Buf.Get(cnt)
 		line, _ := v.([]rune)
 		if _, err := buf.WriteString(string(line) + "\n"); err != nil {
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"syscall"
 )

 type Prompt struct {
@@ -63,7 +62,7 @@ func New(prompt Prompt) (*Instance, error) {

 func (i *Instance) Readline() (string, error) {
 	if !i.Terminal.rawmode {
-		fd := int(syscall.Stdin)
+		fd := os.Stdin.Fd()
 		termios, err := SetRawMode(fd)
 		if err != nil {
 			return "", err
@@ -80,8 +79,8 @@ func (i *Instance) Readline() (string, error) {
 	fmt.Print(prompt)

 	defer func() {
-		fd := int(syscall.Stdin)
-		// nolint: errcheck
+		fd := os.Stdin.Fd()
+		//nolint:errcheck
 		UnsetRawMode(fd, i.Terminal.termios)
 		i.Terminal.rawmode = false
 	}()
@@ -136,7 +135,7 @@ func (i *Instance) Readline() (string, error) {
 			buf.MoveRight()
 		case CharBracketedPaste:
 			var code string
-			for cnt := 0; cnt < 3; cnt++ {
+			for range 3 {
 				r, err = i.Terminal.Read()
 				if err != nil {
 					return "", io.EOF
@@ -150,7 +149,7 @@ func (i *Instance) Readline() (string, error) {
 				i.Pasting = false
 			}
 		case KeyDel:
-			if buf.Size() > 0 {
+			if buf.DisplaySize() > 0 {
 				buf.Delete()
 			}
 			metaDel = true
@@ -198,11 +197,11 @@ func (i *Instance) Readline() (string, error) {
 			buf.Remove()
 		case CharTab:
 			// todo: convert back to real tabs
-			for cnt := 0; cnt < 8; cnt++ {
+			for range 8 {
 				buf.Add(' ')
 			}
 		case CharDelete:
-			if buf.Size() > 0 {
+			if buf.DisplaySize() > 0 {
 				buf.Delete()
 			} else {
 				return "", io.EOF
@@ -216,7 +215,7 @@ func (i *Instance) Readline() (string, error) {
 		case CharCtrlW:
 			buf.DeleteWord()
 		case CharCtrlZ:
-			fd := int(syscall.Stdin)
+			fd := os.Stdin.Fd()
 			return handleCharCtrlZ(fd, i.Terminal.termios)
 		case CharEnter, CharCtrlJ:
 			output := buf.String()
@@ -248,7 +247,7 @@ func (i *Instance) HistoryDisable() {
 }

 func NewTerminal() (*Terminal, error) {
-	fd := int(syscall.Stdin)
+	fd := os.Stdin.Fd()
 	termios, err := SetRawMode(fd)
 	if err != nil {
 		return nil, err
@@ -6,7 +6,7 @@ import (
 	"syscall"
 )

-func handleCharCtrlZ(fd int, termios any) (string, error) {
+func handleCharCtrlZ(fd uintptr, termios any) (string, error) {
 	t := termios.(*Termios)
 	if err := UnsetRawMode(fd, t); err != nil {
 		return "", err
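The hunk only shows the top of handleCharCtrlZ: it restores the terminal before handing control back. The rest of the function presumably suspends the process. A POSIX-only sketch, assuming suspension is done by sending SIGSTOP to the current process (an assumption, not shown in this diff):

package main

import "syscall"

// Sketch: once UnsetRawMode has returned the terminal to cooked mode,
// the process can stop itself; the shell's fg later resumes it and the
// caller re-enters raw mode. SIGSTOP here is an assumption.
func suspend() error {
	return syscall.Kill(syscall.Getpid(), syscall.SIGSTOP)
}

func main() {
	_ = suspend
}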
@@ -1,6 +1,6 @@
 package readline

-func handleCharCtrlZ(fd int, state any) (string, error) {
+func handleCharCtrlZ(fd uintptr, state any) (string, error) {
 	// not supported
 	return "", nil
 }
@@ -8,7 +8,7 @@ import (

 type Termios syscall.Termios

-func SetRawMode(fd int) (*Termios, error) {
+func SetRawMode(fd uintptr) (*Termios, error) {
 	termios, err := getTermios(fd)
 	if err != nil {
 		return nil, err
@@ -25,13 +25,13 @@ func SetRawMode(fd int) (*Termios, error) {
 	return termios, setTermios(fd, &newTermios)
 }

-func UnsetRawMode(fd int, termios any) error {
+func UnsetRawMode(fd uintptr, termios any) error {
 	t := termios.(*Termios)
 	return setTermios(fd, t)
 }

 // IsTerminal returns true if the given file descriptor is a terminal.
-func IsTerminal(fd int) bool {
+func IsTerminal(fd uintptr) bool {
 	_, err := getTermios(fd)
 	return err == nil
 }
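The int-to-uintptr signature change follows the type os.Stdin.Fd() actually returns, which is also what syscall.Syscall6 and windows.Handle want, so the per-platform int(syscall.Stdin) conversions disappear. A minimal sketch (note golang.org/x/term still takes an int, hence the cast back):

package main

import (
	"fmt"
	"os"

	"golang.org/x/term"
)

func main() {
	// uintptr end to end: usable as-is for ioctl syscalls and as a
	// Windows console handle.
	fd := os.Stdin.Fd()
	fmt.Println(term.IsTerminal(int(fd)))
}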
@@ -7,17 +7,17 @@ import (
 	"unsafe"
 )

-func getTermios(fd int) (*Termios, error) {
+func getTermios(fd uintptr) (*Termios, error) {
 	termios := new(Termios)
-	_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
+	_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
 	if err != 0 {
 		return nil, err
 	}
 	return termios, nil
 }

-func setTermios(fd int, termios *Termios) error {
-	_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
+func setTermios(fd uintptr, termios *Termios) error {
+	_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
 	if err != 0 {
 		return err
 	}
@@ -10,17 +10,17 @@ import (
 const tcgets = 0x5401
 const tcsets = 0x5402

-func getTermios(fd int) (*Termios, error) {
+func getTermios(fd uintptr) (*Termios, error) {
 	termios := new(Termios)
-	_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
+	_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
 	if err != 0 {
 		return nil, err
 	}
 	return termios, nil
 }

-func setTermios(fd int, termios *Termios) error {
-	_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
+func setTermios(fd uintptr, termios *Termios) error {
+	_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
 	if err != 0 {
 		return err
 	}
@@ -9,13 +9,13 @@ type State struct {
 }

 // IsTerminal checks if the given file descriptor is associated with a terminal
-func IsTerminal(fd int) bool {
+func IsTerminal(fd uintptr) bool {
 	var st uint32
 	err := windows.GetConsoleMode(windows.Handle(fd), &st)
 	return err == nil
 }

-func SetRawMode(fd int) (*State, error) {
+func SetRawMode(fd uintptr) (*State, error) {
 	var st uint32
 	if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
 		return nil, err
@@ -32,7 +32,7 @@ func SetRawMode(fd int) (*State, error) {
 	return &State{st}, nil
 }

-func UnsetRawMode(fd int, state any) error {
+func UnsetRawMode(fd uintptr, state any) error {
 	s := state.(*State)
 	return windows.SetConsoleMode(windows.Handle(fd), s.mode)
 }
@@ -33,9 +33,11 @@ case "$ARCH" in
     *) error "Unsupported architecture: $ARCH" ;;
 esac

+IS_WSL2=false
+
 KERN=$(uname -r)
 case "$KERN" in
-    *icrosoft*WSL2 | *icrosoft*wsl2) ;;
+    *icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;;
     *icrosoft) error "Microsoft WSL1 is not currently supported. Please upgrade to WSL2 with 'wsl --set-version <distro> 2'" ;;
     *) ;;
 esac
@@ -72,7 +74,7 @@ status "Installing ollama to $BINDIR..."
 $SUDO install -o0 -g0 -m755 -d $BINDIR
 $SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama

 install_success() {
     status 'The Ollama API is now available at 127.0.0.1:11434.'
     status 'Install complete. Run "ollama" from the command line.'
 }
@@ -131,6 +133,17 @@ if available systemctl; then
     configure_systemd
 fi

+# WSL2 only supports GPUs via nvidia passthrough
+# so check for nvidia-smi to determine if GPU is available
+if [ "$IS_WSL2" = true ]; then
+    if available nvidia-smi && [ -n "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
+        status "Nvidia GPU detected."
+    fi
+    install_success
+    exit 0
+fi
+
+# Install GPU dependencies on Linux
 if ! available lspci && ! available lshw; then
     warning "Unable to detect NVIDIA/AMD GPU. Install lspci or lshw to automatically detect and install GPU dependencies."
     exit 0
@@ -139,12 +152,12 @@ fi
 check_gpu() {
     # Look for devices based on vendor ID for NVIDIA and AMD
     case $1 in
         lspci)
             case $2 in
                 nvidia) available lspci && lspci -d '10de:' | grep -q 'NVIDIA' || return 1 ;;
                 amdgpu) available lspci && lspci -d '1002:' | grep -q 'AMD' || return 1 ;;
             esac ;;
         lshw)
             case $2 in
                 nvidia) available lshw && $SUDO lshw -c display -numeric | grep -q 'vendor: .* \[10DE\]' || return 1 ;;
                 amdgpu) available lshw && $SUDO lshw -c display -numeric | grep -q 'vendor: .* \[1002\]' || return 1 ;;
@@ -181,7 +194,7 @@ if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
     curl --fail --show-error --location --progress-bar "https://ollama.com/download/ollama-linux-amd64-rocm.tgz${VER_PARAM}" \
         | $SUDO tar zx --owner ollama --group ollama -C /usr/share/ollama/lib/rocm .
     install_success
-    status "AMD GPU dependencies installed."
+    status "AMD GPU ready."
     exit 0
 fi

@@ -274,7 +287,7 @@ if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\
     esac
 fi

-if ! lsmod | grep -q nvidia; then
+if ! lsmod | grep -q nvidia || ! lsmod | grep -q nvidia_uvm; then
     KERNEL_RELEASE="$(uname -r)"
     case $OS_NAME in
         rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;;
@@ -295,7 +308,19 @@ if ! lsmod | grep -q nvidia; then
     fi

     $SUDO modprobe nvidia
+    $SUDO modprobe nvidia_uvm
 fi

+# make sure the NVIDIA modules are loaded on boot with nvidia-persistenced
+if command -v nvidia-persistenced > /dev/null 2>&1; then
+    $SUDO touch /etc/modules-load.d/nvidia.conf
+    MODULES="nvidia nvidia-uvm"
+    for MODULE in $MODULES; do
+        if ! grep -qxF "$MODULE" /etc/modules-load.d/nvidia.conf; then
+            echo "$MODULE" | sudo tee -a /etc/modules-load.d/nvidia.conf > /dev/null
+        fi
+    done
+fi
+
-status "NVIDIA CUDA drivers installed."
+status "NVIDIA GPU ready."
+install_success
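For reference, the script's WSL2 check keys off the kernel release from uname -r; the same heuristic is visible from Go by reading /proc/version. A sketch assuming a Linux host (the substring match mirrors the script's *icrosoft*WSL2 glob):

package main

import (
	"fmt"
	"os"
	"strings"
)

func main() {
	// WSL2 kernels advertise "microsoft" and "WSL2" in their release string.
	b, err := os.ReadFile("/proc/version")
	if err != nil {
		fmt.Println("cannot read /proc/version:", err)
		return
	}
	v := strings.ToLower(string(b))
	fmt.Println("wsl2:", strings.Contains(v, "microsoft") && strings.Contains(v, "wsl2"))
}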
@@ -221,7 +221,7 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
 	}
 	defer resp.Body.Close()

-	n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size)
+	n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed)
 	if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
 		// rollback progress
 		b.Completed.Add(-n)
@@ -340,17 +340,17 @@ type downloadOpts struct {
 }

 // downloadBlob downloads a blob from the registry and stores it in the blobs directory
-func downloadBlob(ctx context.Context, opts downloadOpts) error {
+func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
 	fp, err := GetBlobsPath(opts.digest)
 	if err != nil {
-		return err
+		return false, err
 	}

 	fi, err := os.Stat(fp)
 	switch {
 	case errors.Is(err, os.ErrNotExist):
 	case err != nil:
-		return err
+		return false, err
 	default:
 		opts.fn(api.ProgressResponse{
 			Status: fmt.Sprintf("pulling %s", opts.digest[7:19]),
@@ -359,7 +359,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 			Completed: fi.Size(),
 		})

-		return nil
+		return true, nil
 	}

 	data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
@@ -369,12 +369,12 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 	requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
 	if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
 		blobDownloadManager.Delete(opts.digest)
-		return err
+		return false, err
 	}

-	// nolint: contextcheck
+	//nolint:contextcheck
 		go download.Run(context.Background(), requestURL, opts.regOpts)
 	}

-	return download.Wait(ctx, opts.fn)
+	return false, download.Wait(ctx, opts.fn)
 }
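A note on the downloadChunk fix: when a chunk resumes, part.Completed bytes are already on disk and the server is only asked for the remaining range, so copying part.Size again would wait for bytes that never arrive and double-count progress. A minimal standalone sketch (hypothetical part type) of the corrected arithmetic:

package main

import (
	"fmt"
	"io"
	"strings"
)

type part struct {
	Size      int64
	Completed int64
}

func main() {
	p := part{Size: 10, Completed: 4}
	// the server sends only the 6 remaining bytes of this chunk
	body := strings.NewReader("abcdef")
	var sb strings.Builder
	n, err := io.CopyN(&sb, body, p.Size-p.Completed)
	fmt.Println(n, err, sb.String()) // 6 <nil> abcdef
}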
@@ -1,23 +0,0 @@
-package envconfig
-
-import (
-	"testing"
-
-	"github.com/stretchr/testify/require"
-)
-
-func TestConfig(t *testing.T) {
-	Debug = false // Reset whatever was loaded in init()
-	t.Setenv("OLLAMA_DEBUG", "")
-	LoadConfig()
-	require.False(t, Debug)
-	t.Setenv("OLLAMA_DEBUG", "false")
-	LoadConfig()
-	require.False(t, Debug)
-	t.Setenv("OLLAMA_DEBUG", "1")
-	LoadConfig()
-	require.True(t, Debug)
-	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
-	LoadConfig()
-	require.True(t, FlashAttention)
-}
103 server/images.go
@@ -18,17 +18,16 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
+	"slices"
 	"strconv"
 	"strings"

-	"golang.org/x/exp/slices"
-
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
-	"github.com/ollama/ollama/server/envconfig"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -315,7 +314,7 @@ func realpath(rel, from string) string {
 	return abspath
 }

-func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) {
+func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) {
 	config := ConfigV2{
 		OS:           "linux",
 		Architecture: "amd64",
@@ -333,7 +332,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m

 		switch c.Name {
 		case "model", "adapter":
-			var baseLayers []*layerWithGGML
+			var baseLayers []*layerGGML
 			if name := model.ParseName(c.Args); name.IsValid() {
 				baseLayers, err = parseFromModel(ctx, name, fn)
 				if err != nil {
@@ -440,19 +439,27 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 				layers = append(layers, baseLayer.Layer)
 			}
 		case "license", "template", "system":
+			if c.Name != "license" {
+				// replace
+				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+					if layer.MediaType != mediatype {
+						return false
+					}
+
+					if err := layer.Remove(); err != nil {
+						return false
+					}
+
+					return true
+				})
+			}
+
 			blob := strings.NewReader(c.Args)
 			layer, err := NewLayer(blob, mediatype)
 			if err != nil {
 				return err
 			}

-			if c.Name != "license" {
-				// replace
-				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
-					return layer.MediaType == mediatype
-				})
-			}
-
 			layers = append(layers, layer)
 		case "message":
 			role, content, ok := strings.Cut(c.Args, ": ")
@@ -571,26 +578,15 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 		}
 	}

-	unref := make(map[string]struct{})
-	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
-		for _, layer := range manifest.Layers {
-			if !slices.Contains(digests, layer.Digest) {
-				unref[layer.Digest] = struct{}{}
-			}
-		}
-
-		if manifest.Config.Digest != layer.Digest {
-			unref[manifest.Config.Digest] = struct{}{}
-		}
-	}
+	old, _ := ParseNamedManifest(name)

 	fn(api.ProgressResponse{Status: "writing manifest"})
 	if err := WriteManifest(name, layer, layers); err != nil {
 		return err
 	}

-	if !envconfig.NoPrune {
-		if err := deleteUnusedLayers(nil, unref); err != nil {
+	if !envconfig.NoPrune && old != nil {
+		if err := old.RemoveLayers(); err != nil {
 			return err
 		}
 	}
@@ -662,7 +658,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{})
 		// save (i.e. delete from the deleteMap) any files used in other manifests
 		manifest, _, err := GetManifest(fmp)
 		if err != nil {
-			// nolint: nilerr
+			//nolint:nilerr
 			return nil
 		}

@@ -771,37 +767,6 @@ func PruneDirectory(path string) error {
 	return nil
 }

-func DeleteModel(name string) error {
-	mp := ParseModelPath(name)
-	manifest, _, err := GetManifest(mp)
-	if err != nil {
-		return err
-	}
-
-	deleteMap := make(map[string]struct{})
-	for _, layer := range manifest.Layers {
-		deleteMap[layer.Digest] = struct{}{}
-	}
-	deleteMap[manifest.Config.Digest] = struct{}{}
-
-	err = deleteUnusedLayers(&mp, deleteMap)
-	if err != nil {
-		return err
-	}
-
-	fp, err := mp.GetManifestPath()
-	if err != nil {
-		return err
-	}
-	err = os.Remove(fp)
-	if err != nil {
-		slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
-		return err
-	}
-
-	return nil
-}
-
 func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
@@ -888,23 +853,27 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	layers = append(layers, manifest.Layers...)
 	layers = append(layers, manifest.Config)

+	skipVerify := make(map[string]bool)
 	for _, layer := range layers {
-		if err := downloadBlob(
-			ctx,
-			downloadOpts{
-				mp:      mp,
-				digest:  layer.Digest,
-				regOpts: regOpts,
-				fn:      fn,
-			}); err != nil {
+		cacheHit, err := downloadBlob(ctx, downloadOpts{
+			mp:      mp,
+			digest:  layer.Digest,
+			regOpts: regOpts,
+			fn:      fn,
+		})
+		if err != nil {
 			return err
 		}
+		skipVerify[layer.Digest] = cacheHit
 		delete(deleteMap, layer.Digest)
 	}
 	delete(deleteMap, manifest.Config.Digest)

 	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
 	for _, layer := range layers {
+		if skipVerify[layer.Digest] {
+			continue
+		}
 		if err := verifyBlob(layer.Digest); err != nil {
 			if errors.Is(err, errDigestMismatch) {
 				// something went wrong, delete the blob
@@ -1019,7 +988,7 @@ func getTokenSubject(token string) string {

 func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
 	anonymous := true // access will default to anonymous if no user is found associated with the public key
-	for i := 0; i < 2; i++ {
+	for range 2 {
 		resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
 		if err != nil {
 			if !errors.Is(err, context.Canceled) {
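The import swap from golang.org/x/exp/slices to the stdlib slices package (Go 1.21+) is what makes the DeleteFunc call above available without the experimental module. A minimal sketch of the pattern CreateModel uses, where the predicate also performs the side-effecting removal before reporting true:

package main

import (
	"fmt"
	"slices"
)

func main() {
	// DeleteFunc drops every element for which the predicate returns true;
	// here it stands in for removing the old template/system layer before
	// appending its replacement.
	layers := []string{"template", "system", "license"}
	layers = slices.DeleteFunc(layers, func(s string) bool {
		return s == "template"
	})
	fmt.Println(layers) // [system license]
}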
@@ -88,3 +88,26 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {

 	return os.Open(blob)
 }
+
+func (l *Layer) Remove() error {
+	ms, err := Manifests()
+	if err != nil {
+		return err
+	}
+
+	for _, m := range ms {
+		for _, layer := range append(m.Layers, m.Config) {
+			if layer.Digest == l.Digest {
+				// something is using this layer
+				return nil
+			}
+		}
+	}
+
+	blob, err := GetBlobsPath(l.Digest)
+	if err != nil {
+		return err
+	}
+
+	return os.Remove(blob)
+}
@@ -1,11 +1,12 @@
 package server

 import (
-	"bytes"
 	"crypto/sha256"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
+	"log/slog"
 	"os"
 	"path/filepath"

@@ -14,7 +15,10 @@ import (

 type Manifest struct {
 	ManifestV2
-	Digest string `json:"-"`
+	filepath string
+	fi       os.FileInfo
+	digest   string
 }

 func (m *Manifest) Size() (size int64) {
@@ -25,9 +29,34 @@ func (m *Manifest) Size() (size int64) {
 	return
 }

+func (m *Manifest) Remove() error {
+	if err := os.Remove(m.filepath); err != nil {
+		return err
+	}
+
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return err
+	}
+
+	return PruneDirectory(manifests)
+}
+
+func (m *Manifest) RemoveLayers() error {
+	for _, layer := range append(m.Layers, m.Config) {
+		if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
+			slog.Debug("layer does not exist", "digest", layer.Digest)
+		} else if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
-func ParseNamedManifest(name model.Name) (*Manifest, error) {
-	if !name.IsFullyQualified() {
-		return nil, model.Unqualified(name)
+func ParseNamedManifest(n model.Name) (*Manifest, error) {
+	if !n.IsFullyQualified() {
+		return nil, model.Unqualified(n)
 	}

 	manifests, err := GetManifestPath()
@@ -35,45 +64,101 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) {
 		return nil, err
 	}

-	var manifest ManifestV2
-	manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
+	p := filepath.Join(manifests, n.Filepath())
+
+	var m ManifestV2
+	f, err := os.Open(p)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+
+	fi, err := f.Stat()
 	if err != nil {
 		return nil, err
 	}

 	sha256sum := sha256.New()
-	if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
+	if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
 		return nil, err
 	}

 	return &Manifest{
-		ManifestV2: manifest,
-		Digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
+		ManifestV2: m,
+		filepath:   p,
+		fi:         fi,
+		digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
 	}, nil
 }

-func WriteManifest(name string, config *Layer, layers []*Layer) error {
-	manifest := ManifestV2{
+func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return err
+	}
+
+	p := filepath.Join(manifests, name.Filepath())
+	if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
+		return err
+	}
+
+	f, err := os.Create(p)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	m := ManifestV2{
 		SchemaVersion: 2,
 		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
 		Config:        config,
 		Layers:        layers,
 	}

-	var b bytes.Buffer
-	if err := json.NewEncoder(&b).Encode(manifest); err != nil {
-		return err
-	}
-
-	modelpath := ParseModelPath(name)
-	manifestPath, err := modelpath.GetManifestPath()
-	if err != nil {
-		return err
-	}
-
-	if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
-		return err
-	}
-
-	return os.WriteFile(manifestPath, b.Bytes(), 0o644)
+	return json.NewEncoder(f).Encode(m)
+}
+
+func Manifests() (map[model.Name]*Manifest, error) {
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO(mxyng): use something less brittle
+	matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
+	if err != nil {
+		return nil, err
+	}
+
+	ms := make(map[model.Name]*Manifest)
+	for _, match := range matches {
+		fi, err := os.Stat(match)
+		if err != nil {
+			return nil, err
+		}
+
+		if !fi.IsDir() {
+			rel, err := filepath.Rel(manifests, match)
+			if err != nil {
+				slog.Warn("bad filepath", "path", match, "error", err)
+				continue
+			}
+
+			n := model.ParseNameFromFilepath(rel)
+			if !n.IsValid() {
+				slog.Warn("bad manifest name", "path", rel, "error", err)
+				continue
+			}
+
+			m, err := ParseNamedManifest(n)
+			if err != nil {
+				slog.Warn("bad manifest", "name", n, "error", err)
+				continue
+			}
+
+			ms[n] = m
+		}
+	}
+
+	return ms, nil
 }
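Manifests relies on the fixed-depth layout host/namespace/model/tag under the manifests root, so four "*" segments enumerate exactly the tag files (the TODO in the code acknowledges this glob is brittle if the layout ever changes). A minimal sketch of that enumeration (hypothetical root path):

package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	root := "/tmp/ollama/manifests" // hypothetical manifests root
	// one "*" per path component: host, namespace, model, tag
	matches, err := filepath.Glob(filepath.Join(root, "*", "*", "*", "*"))
	if err != nil {
		panic(err)
	}
	for _, m := range matches {
		rel, _ := filepath.Rel(root, m)
		fmt.Println(rel) // e.g. registry.ollama.ai/library/llama3/latest
	}
}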
|||||||
150
server/manifest_test.go
Normal file
150
server/manifest_test.go
Normal file
@@ -0,0 +1,150 @@
+package server
+
+import (
+	"encoding/json"
+	"os"
+	"path/filepath"
+	"slices"
+	"testing"
+
+	"github.com/ollama/ollama/types/model"
+)
+
+func createManifest(t *testing.T, path, name string) {
+	t.Helper()
+
+	p := filepath.Join(path, "manifests", name)
+	if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
+		t.Fatal(err)
+	}
+
+	f, err := os.Create(p)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestManifests(t *testing.T) {
+	cases := map[string]struct {
+		ps               []string
+		wantValidCount   int
+		wantInvalidCount int
+	}{
+		"empty": {},
+		"single": {
+			ps: []string{
+				filepath.Join("host", "namespace", "model", "tag"),
+			},
+			wantValidCount: 1,
+		},
+		"multiple": {
+			ps: []string{
+				filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
+				filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
+			},
+			wantValidCount: 15,
+		},
+		"hidden": {
+			ps: []string{
+				filepath.Join("host", "namespace", "model", "tag"),
+				filepath.Join("host", "namespace", "model", ".hidden"),
+			},
+			wantValidCount:   1,
+			wantInvalidCount: 1,
+		},
+		"subdir": {
+			ps: []string{
+				filepath.Join("host", "namespace", "model", "tag", "one"),
+				filepath.Join("host", "namespace", "model", "tag", "another", "one"),
+			},
+			wantInvalidCount: 2,
+		},
+		"upper tag": {
+			ps: []string{
+				filepath.Join("host", "namespace", "model", "TAG"),
+			},
+			wantValidCount: 1,
+		},
+		"upper model": {
+			ps: []string{
+				filepath.Join("host", "namespace", "MODEL", "tag"),
+			},
+			wantValidCount: 1,
+		},
+		"upper namespace": {
+			ps: []string{
+				filepath.Join("host", "NAMESPACE", "model", "tag"),
+			},
+			wantValidCount: 1,
+		},
+		"upper host": {
+			ps: []string{
+				filepath.Join("HOST", "namespace", "model", "tag"),
+			},
+			wantValidCount: 1,
+		},
+	}
+
+	for n, wants := range cases {
+		t.Run(n, func(t *testing.T) {
+			d := t.TempDir()
+			t.Setenv("OLLAMA_MODELS", d)
+
+			for _, p := range wants.ps {
+				createManifest(t, d, p)
+			}
+
+			ms, err := Manifests()
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			var ns []model.Name
+			for k := range ms {
+				ns = append(ns, k)
+			}
+
+			var gotValidCount, gotInvalidCount int
+			for _, p := range wants.ps {
+				n := model.ParseNameFromFilepath(p)
+				if n.IsValid() {
+					gotValidCount++
+				} else {
+					gotInvalidCount++
+				}
+
+				if !n.IsValid() && slices.Contains(ns, n) {
+					t.Errorf("unexpected invalid name: %s", p)
+				} else if n.IsValid() && !slices.Contains(ns, n) {
+					t.Errorf("missing valid name: %s", p)
+				}
+			}
+
+			if gotValidCount != wants.wantValidCount {
+				t.Errorf("got valid count %d, want %d", gotValidCount, wants.wantValidCount)
+			}
+
+			if gotInvalidCount != wants.wantInvalidCount {
+				t.Errorf("got invalid count %d, want %d", gotInvalidCount, wants.wantInvalidCount)
+			}
+		})
+	}
+}
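The table above pivots on a single property: a manifest path is included only if it parses into a valid name. A short sketch of that round trip under the same assumptions the test makes (the shortened display form is an expectation about the default registry, not something the test asserts):

	rel := filepath.Join("registry.ollama.ai", "library", "llama3", "latest")
	n := model.ParseNameFromFilepath(rel)
	fmt.Println(n.IsValid())         // true: host/namespace/model/tag, all well formed
	fmt.Println(n.DisplayShortest()) // likely "llama3:latest" for the default registry

	hidden := filepath.Join("host", "namespace", "model", ".hidden")
	fmt.Println(model.ParseNameFromFilepath(hidden).IsValid()) // false, per the "hidden" case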
@@ -7,6 +7,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log/slog"
 	"net/http"
 	"os"
 	"path/filepath"
@@ -14,27 +15,26 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/templates"
 	"github.com/ollama/ollama/types/model"
 )
 
 var intermediateBlobs map[string]string = make(map[string]string)
 
-type layerWithGGML struct {
+type layerGGML struct {
 	*Layer
 	*llm.GGML
 }
 
-func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
-	modelpath := ParseModelPath(name.String())
-	manifest, _, err := GetManifest(modelpath)
+func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
+	m, err := ParseNamedManifest(name)
 	switch {
 	case errors.Is(err, os.ErrNotExist):
 		if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
 			return nil, err
 		}
 
-		modelpath = ParseModelPath(name.String())
-		manifest, _, err = GetManifest(modelpath)
+		m, err = ParseNamedManifest(name)
 		if err != nil {
 			return nil, err
 		}
@@ -42,8 +42,8 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 		return nil, err
 	}
 
-	for _, layer := range manifest.Layers {
-		layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
+	for _, layer := range m.Layers {
+		layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
 		if err != nil {
 			return nil, err
 		}
@@ -68,17 +68,16 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 				return nil, err
 			}
 
-			layers = append(layers, &layerWithGGML{layer, ggml})
+			layers = append(layers, &layerGGML{layer, ggml})
 		default:
-			layers = append(layers, &layerWithGGML{layer, nil})
+			layers = append(layers, &layerGGML{layer, nil})
 		}
 	}
 
 	return layers, nil
 }
 
-func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
 	stat, err := file.Stat()
 	if err != nil {
 		return nil, err
@@ -182,13 +181,13 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
 		return nil, err
 	}
 
-	layers = append(layers, &layerWithGGML{layer, ggml})
+	layers = append(layers, &layerGGML{layer, ggml})
 
 	intermediateBlobs[digest] = layer.Digest
-	return layers, nil
+	return detectChatTemplate(layers)
 }
 
-func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
 	sr := io.NewSectionReader(file, 0, 512)
 	contentType, err := detectContentType(sr)
 	if err != nil {
@@ -230,10 +229,30 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
 			return nil, err
 		}
 
-		layers = append(layers, &layerWithGGML{layer, ggml})
+		layers = append(layers, &layerGGML{layer, ggml})
 		offset = n
 	}
 
+	return detectChatTemplate(layers)
+}
+
+func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
+	for _, layer := range layers {
+		if s := layer.GGML.KV().ChatTemplate(); s != "" {
+			if t, err := templates.NamedTemplate(s); err != nil {
+				slog.Debug("template detection", "error", err)
+			} else {
+				tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
+				if err != nil {
+					return nil, err
+				}
+
+				tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name)
+				layers = append(layers, &layerGGML{tmpl, nil})
+			}
+		}
+	}
+
 	return layers, nil
 }
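detectChatTemplate above is deliberately append-only and non-fatal: an unrecognized GGUF chat template is logged at debug level and ignored, while a match adds one extra template layer. A hedged sketch of the lookup it relies on, using only the templates calls visible in this hunk ("raw" stands in for the GGUF's tokenizer.chat_template value):

	if t, err := templates.NamedTemplate(raw); err != nil {
		slog.Debug("template detection", "error", err) // no match: proceed without a template layer
	} else {
		slog.Info("autodetected template", "name", t.Name)
		_ = t.Reader() // supplies the Go template body for the new layer
	}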
@@ -6,12 +6,13 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestGetBlobsPath(t *testing.T) {
 	// GetBlobsPath expects an actual directory to exist
 	dir, err := os.MkdirTemp("", "ollama-test")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.RemoveAll(dir)
 
 	tests := []struct {
@@ -63,7 +64,7 @@ func TestGetBlobsPath(t *testing.T) {
 
 			got, err := GetBlobsPath(tc.digest)
 
-			assert.ErrorIs(t, tc.err, err, tc.name)
+			require.ErrorIs(t, tc.err, err, tc.name)
 			assert.Equal(t, tc.expected, got, tc.name)
 		})
 	}
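The assert-to-require swap here is behavioral, not cosmetic: testify's assert records a failure and lets the test continue, while require aborts the test immediately, so code after a require can rely on its condition. A minimal illustration:

	dir, err := os.MkdirTemp("", "ollama-test")
	require.NoError(t, err) // test stops here on failure; dir is valid below
	defer os.RemoveAll(dir)
	assert.NotEmpty(t, dir) // a failure here is recorded, but the test keeps running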
220
server/routes.go
@@ -16,6 +16,7 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
+	"slices"
 	"strconv"
 	"strings"
 	"syscall"
@@ -23,14 +24,13 @@ import (
 
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
-	"golang.org/x/exp/slices"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
-	"github.com/ollama/ollama/server/envconfig"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -77,7 +77,6 @@ func isSupportedImageType(image []byte) bool {
 }
 
 func (s *Server) GenerateHandler(c *gin.Context) {
-
 	checkpointStart := time.Now()
 	var req api.GenerateRequest
 	err := c.ShouldBindJSON(&req)
@@ -315,10 +314,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 }
 
 func getDefaultSessionDuration() time.Duration {
-	if t, exists := os.LookupEnv("OLLAMA_KEEP_ALIVE"); exists {
-		v, err := strconv.Atoi(t)
+	if envconfig.KeepAlive != "" {
+		v, err := strconv.Atoi(envconfig.KeepAlive)
 		if err != nil {
-			d, err := time.ParseDuration(t)
+			d, err := time.ParseDuration(envconfig.KeepAlive)
 			if err != nil {
 				return defaultSessionDuration
 			}
@@ -421,13 +420,14 @@ func (s *Server) PullModelHandler(c *gin.Context) {
 		return
 	}
 
-	var model string
-	if req.Model != "" {
-		model = req.Model
-	} else if req.Name != "" {
-		model = req.Name
-	} else {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+	name := model.ParseName(cmp.Or(req.Model, req.Name))
+	if !name.IsValid() {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
+		return
+	}
+
+	if err := checkNameExists(name); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -445,7 +445,7 @@ func (s *Server) PullModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := PullModel(ctx, model, regOpts, fn); err != nil {
+		if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -507,9 +507,24 @@ func (s *Server) PushModelHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func checkNameExists(name model.Name) error {
+	names, err := Manifests()
+	if err != nil {
+		return err
+	}
+
+	for n := range names {
+		if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
+			return fmt.Errorf("a model with that name already exists")
+		}
+	}
+
+	return nil
+}
+
 func (s *Server) CreateModelHandler(c *gin.Context) {
-	var req api.CreateRequest
-	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+	var r api.CreateRequest
+	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
 	} else if err != nil {
@@ -517,30 +532,35 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		return
 	}
 
-	name := model.ParseName(cmp.Or(req.Model, req.Name))
+	name := model.ParseName(cmp.Or(r.Model, r.Name))
 	if !name.IsValid() {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
 		return
 	}
 
-	if req.Path == "" && req.Modelfile == "" {
+	if err := checkNameExists(name); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if r.Path == "" && r.Modelfile == "" {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
 		return
 	}
 
-	var r io.Reader = strings.NewReader(req.Modelfile)
-	if req.Path != "" && req.Modelfile == "" {
-		f, err := os.Open(req.Path)
+	var sr io.Reader = strings.NewReader(r.Modelfile)
+	if r.Path != "" && r.Modelfile == "" {
+		f, err := os.Open(r.Path)
 		if err != nil {
 			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
 			return
 		}
 		defer f.Close()
 
-		r = f
+		sr = f
 	}
 
-	modelfile, err := parser.ParseFile(r)
+	f, err := parser.ParseFile(sr)
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
@@ -556,17 +576,13 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		quantization := req.Quantization
-		if req.Quantize != "" {
-			quantization = req.Quantize
-		}
-
-		if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil {
+		quantization := cmp.Or(r.Quantize, r.Quantization)
+		if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
-	if req.Stream != nil && !*req.Stream {
+	if r.Stream != nil && !*r.Stream {
 		waitForStream(c, ch)
 		return
 	}
@@ -575,48 +591,36 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 }
 
 func (s *Server) DeleteModelHandler(c *gin.Context) {
-	var req api.DeleteRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	var r api.DeleteRequest
+	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	var model string
-	if req.Model != "" {
-		model = req.Model
-	} else if req.Name != "" {
-		model = req.Name
-	} else {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+	n := model.ParseName(cmp.Or(r.Model, r.Name))
+	if !n.IsValid() {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
 		return
 	}
 
-	if err := DeleteModel(model); err != nil {
-		if os.IsNotExist(err) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
-		} else {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		}
-		return
-	}
-
-	manifestsPath, err := GetManifestPath()
+	m, err := ParseNamedManifest(n)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
-	if err := PruneDirectory(manifestsPath); err != nil {
+	if err := m.Remove(); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
-	c.JSON(http.StatusOK, nil)
+	if err := m.RemoveLayers(); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
 }
 
 func (s *Server) ShowModelHandler(c *gin.Context) {
@@ -720,75 +724,45 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 }
 
 func (s *Server) ListModelsHandler(c *gin.Context) {
-	manifests, err := GetManifestPath()
+	ms, err := Manifests()
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
-	models := []api.ModelResponse{}
-	if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
-		if !info.IsDir() {
-			rel, err := filepath.Rel(manifests, path)
-			if err != nil {
-				return err
-			}
-
-			if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
-				return err
-			} else if hidden {
-				return nil
-			}
-
-			n := model.ParseNameFromFilepath(rel)
-			if !n.IsValid() {
-				slog.Warn("bad manifest filepath", "path", rel)
-				return nil
-			}
-
-			m, err := ParseNamedManifest(n)
-			if err != nil {
-				slog.Warn("bad manifest", "name", n, "error", err)
-				return nil
-			}
-
-			f, err := m.Config.Open()
-			if err != nil {
-				slog.Warn("bad manifest config filepath", "name", n, "error", err)
-				return nil
-			}
-			defer f.Close()
-
-			var c ConfigV2
-			if err := json.NewDecoder(f).Decode(&c); err != nil {
-				slog.Warn("bad manifest config", "name", n, "error", err)
-				return nil
-			}
-
-			// tag should never be masked
-			models = append(models, api.ModelResponse{
-				Model:      n.DisplayShortest(),
-				Name:       n.DisplayShortest(),
-				Size:       m.Size(),
-				Digest:     m.Digest,
-				ModifiedAt: info.ModTime(),
-				Details: api.ModelDetails{
-					Format:            c.ModelFormat,
-					Family:            c.ModelFamily,
-					Families:          c.ModelFamilies,
-					ParameterSize:     c.ModelType,
-					QuantizationLevel: c.FileType,
-				},
-			})
-		}
-
-		return nil
-	}); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+	models := []api.ListModelResponse{}
+	for n, m := range ms {
+		f, err := m.Config.Open()
+		if err != nil {
+			slog.Warn("bad manifest filepath", "name", n, "error", err)
+			continue
+		}
+		defer f.Close()
+
+		var cf ConfigV2
+		if err := json.NewDecoder(f).Decode(&cf); err != nil {
+			slog.Warn("bad manifest config", "name", n, "error", err)
+			continue
+		}
+
+		// tag should never be masked
+		models = append(models, api.ListModelResponse{
+			Model:      n.DisplayShortest(),
+			Name:       n.DisplayShortest(),
+			Size:       m.Size(),
+			Digest:     m.digest,
+			ModifiedAt: m.fi.ModTime(),
+			Details: api.ModelDetails{
+				Format:            cf.ModelFormat,
+				Family:            cf.ModelFamily,
+				Families:          cf.ModelFamilies,
+				ParameterSize:     cf.ModelType,
+				QuantizationLevel: cf.FileType,
+			},
+		})
 	}
 
-	slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
+	slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
 		// most recently modified first
 		return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
 	})
@@ -818,6 +792,11 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
 		return
 	}
 
+	if err := checkNameExists(dst); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
 	if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
 		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
 	} else if err != nil {
@@ -963,7 +942,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
 		}
 
 		if allowedHost(host) {
-			if c.Request.Method == "OPTIONS" {
+			if c.Request.Method == http.MethodOptions {
 				c.AbortWithStatus(http.StatusNoContent)
 				return
 			}
@@ -981,6 +960,10 @@ func (s *Server) GenerateRoutes() http.Handler {
 	config.AllowWildcard = true
 	config.AllowBrowserExtensions = true
 	config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
+	openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
+	for _, prop := range openAIProperties {
+		config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
+	}
 	config.AllowOrigins = envconfig.AllowOrigins
 
 	r := gin.Default()
@@ -1025,7 +1008,7 @@ func Serve(ln net.Listener) error {
 		level = slog.LevelDebug
 	}
 
-	slog.Info("server config", "env", envconfig.AsMap())
+	slog.Info("server config", "env", envconfig.Values())
 	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
 		Level:     level,
 		AddSource: true,
@@ -1160,7 +1143,7 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 
 func (s *Server) ProcessHandler(c *gin.Context) {
-	models := []api.ModelResponse{}
+	models := []api.ProcessModelResponse{}
 
 	for _, v := range s.sched.loaded {
 		model := v.model
@@ -1172,7 +1155,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 			QuantizationLevel: model.Config.FileType,
 		}
 
-		mr := api.ModelResponse{
+		mr := api.ProcessModelResponse{
 			Model: model.ShortName,
 			Name:  model.ShortName,
 			Size:  int64(v.estimatedTotal),
@@ -1192,7 +1175,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 		models = append(models, mr)
 	}
 
-	c.JSON(http.StatusOK, api.ListResponse{Models: models})
+	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
@@ -1327,7 +1310,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		defer close(ch)
 
 		fn := func(r llm.CompletionResponse) {
-
 			resp := api.ChatResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
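getDefaultSessionDuration now reads envconfig.KeepAlive rather than calling os.LookupEnv directly, with the parse order visible in the hunk above: try a bare integer, then a Go duration string, then fall back. A sketch of that order; treating a bare integer as seconds is an assumption here, since that branch is elided in the hunk:

	func keepAliveDuration(s string, def time.Duration) time.Duration {
		if s == "" {
			return def
		}
		if n, err := strconv.Atoi(s); err == nil {
			return time.Duration(n) * time.Second // assumed: bare integers count seconds
		}
		if d, err := time.ParseDuration(s); err == nil {
			return d // e.g. "10m" or "24h"
		}
		return def
	}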
560
server/routes_create_test.go
Normal file
@@ -0,0 +1,560 @@
+package server
+
+import (
+	"bytes"
+	"encoding/binary"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"slices"
+	"testing"
+
+	"github.com/gin-gonic/gin"
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
+)
+
+var stream bool = false
+
+func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) string {
+	t.Helper()
+
+	f, err := os.CreateTemp(t.TempDir(), "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	if err := llm.NewGGUFV3(binary.LittleEndian).Encode(f, kv, ti); err != nil {
+		t.Fatal(err)
+	}
+
+	return f.Name()
+}
+
+type responseRecorder struct {
+	*httptest.ResponseRecorder
+	http.CloseNotifier
+}
+
+func NewRecorder() *responseRecorder {
+	return &responseRecorder{
+		ResponseRecorder: httptest.NewRecorder(),
+	}
+}
+
+func (t *responseRecorder) CloseNotify() <-chan bool {
+	return make(chan bool)
+}
+
+func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
+	t.Helper()
+
+	w := NewRecorder()
+	c, _ := gin.CreateTestContext(w)
+
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(body); err != nil {
+		t.Fatal(err)
+	}
+
+	c.Request = &http.Request{
+		Body: io.NopCloser(&b),
+	}
+
+	fn(c)
+	return w.ResponseRecorder
+}
+
+func checkFileExists(t *testing.T, p string, expect []string) {
+	t.Helper()
+
+	actual, err := filepath.Glob(p)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !slices.Equal(actual, expect) {
+		t.Fatalf("expected slices to be equal %v", actual)
+	}
+}
+
+func TestCreateFromBin(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+
+	var s Server
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
+	})
+}
+
+func TestCreateFromModel(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test2",
+		Modelfile: "FROM test",
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
+	})
+}
+
+func TestCreateRemovesLayers(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-b507b9c2f6ca642bffcd06665ea7c91f235fd32daeefdf875a0f938db05fb315"),
+		filepath.Join(p, "blobs", "sha256-bc80b03733773e0728011b2f4adf34c458b400e1aad48cb28d61170f3a2ad2d6"),
+	})
+
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
+	})
+}
+
+func TestCreateUnsetsSystem(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-8585df945d1069bc78b79bd10bb73ba07fbc29b0f5479a31a601c0d12731416e"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-f29e82a8284dbdf5910b1555580ff60b04238b8da9d5e51159ada67a4d0d5851"),
+	})
+
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-67d4b8d106af2a5b100a46e9bdc038c71eef2a35c9abac784092654212f97cf5"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"),
+	})
+
+	bts, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(bts) != "" {
+		t.Fatalf("expected empty string, actual %s", string(bts))
+	}
+}
+
+func TestCreateMergeParameters(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
+		filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+	})
+
+	// in order to merge parameters, the second model must be created FROM the first
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test2",
+		Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7",
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
+		filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
+		filepath.Join(p, "blobs", "sha256-4cd9d4ba6b734d9b4cbd1e5caa60374c00722e993fce5e1e2d15a33698f71187"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"),
+	})
+
+	actual, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	expect, err := json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"USER:", "ASSISTANT:"}})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) {
+		t.Errorf("expected %s, actual %s", string(expect), string(actual))
+	}
+
+	// slices are replaced
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test2",
+		Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7\nPARAMETER stop <|endoftext|>",
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"),
+		filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
+		filepath.Join(p, "blobs", "sha256-257aa726584f24970a4f240765e75a7169bfbe7f4966c1f04513d6b6c860583a"),
+		filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+	})
+
+	actual, err = os.ReadFile(filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	expect, err = json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"<|endoftext|>"}})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) {
+		t.Errorf("expected %s, actual %s", string(expect), string(actual))
+	}
+}
+
+func TestCreateReplacesMessages(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
+	})
+
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test2",
+		Modelfile: "FROM test\nMESSAGE assistant \"You're a test, Harry.\"\nMESSAGE user \"I-I'm a what?\"\nMESSAGE assistant \"A test. And a thumping good one at that, I'd wager.\"",
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
+		filepath.Join(p, "blobs", "sha256-4f48b25fe9969564c82f58eb1cedbdff6484cc0baf474bc6c2a9b37c8da3362a"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"),
+		filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
+	})
+
+	type message struct {
+		Role    string `json:"role"`
+		Content string `json:"content"`
+	}
+
+	f, err := os.Open(filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	var actual []message
+	if err := json.NewDecoder(f).Decode(&actual); err != nil {
+		t.Fatal(err)
+	}
+
+	expect := []message{
+		{Role: "assistant", Content: "You're a test, Harry."},
+		{Role: "user", Content: "I-I'm a what?"},
+		{Role: "assistant", Content: "A test. And a thumping good one at that, I'd wager."},
+	}
+
+	if !slices.Equal(actual, expect) {
+		t.Errorf("expected %s, actual %s", expect, actual)
+	}
+}
+
+func TestCreateTemplateSystem(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-2b5e330885117c82f3fd75169ea323e141070a2947c11ddb9f79ee0b01c589c1"),
+		filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
+	})
+
+	template, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(template) != "{{ .System }} {{ .Prompt }}" {
+		t.Errorf("expected \"{{ .System }} {{ .Prompt }}\", actual %s", template)
+	}
+
+	system, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(system) != "Say bye!" {
+		t.Errorf("expected \"Say bye!\", actual %s", system)
+	}
+}
+
+func TestCreateLicenses(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"),
+		filepath.Join(p, "blobs", "sha256-79a39c37536ddee29cbadd5d5e2dcba8ed7f03e431f626ff38432c1c866bb7e2"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"),
+	})
+
+	mit, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(mit) != "MIT" {
+		t.Errorf("expected MIT, actual %s", mit)
+	}
+
+	apache, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(apache) != "Apache-2.0" {
+		t.Errorf("expected Apache-2.0, actual %s", apache)
+	}
+}
+
+func TestCreateDetectTemplate(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	t.Run("matched", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name: "test",
+			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
+				"tokenizer.chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+			}, nil)),
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status code 200, actual %d", w.Code)
+		}
+
+		checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+			filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
+			filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
+			filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
+		})
+	})
+
+	t.Run("unmatched", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status code 200, actual %d", w.Code)
+		}
+
+		checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+			filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+			filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
+		})
+	})
+}
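These tests drive CreateModelHandler through a gin test context; over HTTP the same handler serves the create endpoint. A sketch of an equivalent request against a running server (default listen address assumed; the modelfile path is illustrative):

	body := strings.NewReader(`{"name": "test", "modelfile": "FROM ./model.gguf", "stream": false}`)
	resp, err := http.Post("http://127.0.0.1:11434/api/create", "application/json", body)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()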
104
server/routes_delete_test.go
Normal file
@@ -0,0 +1,104 @@
+package server
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"path/filepath"
+	"testing"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
+)
+
+func TestDelete(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test2",
+		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
+		filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
+	})
+
+	w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
+	})
+
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
+		filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
+		filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
+	})
+
+	w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test2"})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
+	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
+}
+
+func TestDeleteDuplicateLayers(t *testing.T) {
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	var s Server
+
+	n := model.ParseName("test")
+
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(&ConfigV2{}); err != nil {
+		t.Fatal(err)
+	}
+
+	config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// create a manifest with duplicate layers
+	if err := WriteManifest(n, config, []*Layer{config}); err != nil {
+		t.Fatal(err)
+	}
+
+	w := createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"})
+	if w.Code != http.StatusOK {
+		t.Errorf("expected status code 200, actual %d", w.Code)
+	}
+
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
+}
||||||
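Both tests lean on helpers defined elsewhere in the server test package (createRequest, createBinFile, checkFileExists) that this diff does not show. A minimal sketch of what a glob-based checkFileExists could look like, with the signature inferred from the call sites above — the real helper may differ:

package server

import (
	"path/filepath"
	"slices"
	"testing"
)

// sketch only: assert that exactly the expected paths match the glob pattern.
func checkFileExists(t *testing.T, p string, expect []string) {
	t.Helper()

	// filepath.Glob expands the wildcard pattern; matches come back sorted,
	// so the sorted expectation lists above can be compared element-wise.
	actual, err := filepath.Glob(p)
	if err != nil {
		t.Fatal(err)
	}

	if !slices.Equal(actual, expect) {
		t.Fatalf("expected %v, actual %v", expect, actual)
	}
}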
61 server/routes_list_test.go Normal file
@@ -0,0 +1,61 @@
package server

import (
	"encoding/json"
	"fmt"
	"net/http"
	"slices"
	"testing"

	"github.com/ollama/ollama/api"
)

func TestList(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	expectNames := []string{
		"mistral:7b-instruct-q4_0",
		"zephyr:7b-beta-q5_K_M",
		"apple/OpenELM:latest",
		"boreas:2b-code-v1.5-q6_K",
		"notus:7b-v1-IQ2_S",
		// TODO: host:port currently fails on windows (#4107)
		// "localhost:5000/library/eurus:700b-v0.5-iq3_XXS",
		"mynamespace/apeliotes:latest",
		"myhost/mynamespace/lips:code",
	}

	var s Server
	for _, n := range expectNames {
		createRequest(t, s.CreateModelHandler, api.CreateRequest{
			Name:      n,
			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
		})
	}

	w := createRequest(t, s.ListModelsHandler, nil)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", w.Code)
	}

	var resp api.ListResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if len(resp.Models) != len(expectNames) {
		t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models))
	}

	actualNames := make([]string, len(resp.Models))
	for i, m := range resp.Models {
		actualNames[i] = m.Name
	}

	slices.Sort(actualNames)
	slices.Sort(expectNames)

	if !slices.Equal(actualNames, expectNames) {
		t.Fatalf("expected slices to be equal %v", actualNames)
	}
}
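The order-insensitive comparison at the end works by sorting both sides with the standard library's slices package (stdlib since Go 1.21, not golang.org/x/exp/slices) before slices.Equal compares them element-wise. A self-contained illustration:

package main

import (
	"fmt"
	"slices"
)

func main() {
	actual := []string{"zephyr:7b-beta-q5_K_M", "mistral:7b-instruct-q4_0"}
	expect := []string{"mistral:7b-instruct-q4_0", "zephyr:7b-beta-q5_K_M"}

	// sort both slices, then compare element-wise
	slices.Sort(actual)
	slices.Sort(expect)
	fmt.Println(slices.Equal(actual, expect)) // true
}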
server/routes_test.go
@@ -15,12 +15,36 @@ import (
 	"testing"

 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"

 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )

+func createTestFile(t *testing.T, name string) string {
+	t.Helper()
+
+	f, err := os.CreateTemp(t.TempDir(), name)
+	require.NoError(t, err)
+	defer f.Close()
+
+	err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
+	require.NoError(t, err)
+
+	err = binary.Write(f, binary.LittleEndian, uint32(3))
+	require.NoError(t, err)
+
+	err = binary.Write(f, binary.LittleEndian, uint64(0))
+	require.NoError(t, err)
+
+	err = binary.Write(f, binary.LittleEndian, uint64(0))
+	require.NoError(t, err)
+
+	return f.Name()
+}
+
 func Test_Routes(t *testing.T) {
 	type testCase struct {
 		Name   string
@@ -30,39 +54,19 @@ func Test_Routes(t *testing.T) {
 		Expected func(t *testing.T, resp *http.Response)
 	}

-	createTestFile := func(t *testing.T, name string) string {
+	createTestModel := func(t *testing.T, name string) {
 		t.Helper()

-		f, err := os.CreateTemp(t.TempDir(), name)
-		assert.Nil(t, err)
-		defer f.Close()
-
-		err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
-		assert.Nil(t, err)
-
-		err = binary.Write(f, binary.LittleEndian, uint32(3))
-		assert.Nil(t, err)
-
-		err = binary.Write(f, binary.LittleEndian, uint64(0))
-		assert.Nil(t, err)
-
-		err = binary.Write(f, binary.LittleEndian, uint64(0))
-		assert.Nil(t, err)
-
-		return f.Name()
-	}
-
-	createTestModel := func(t *testing.T, name string) {
 		fname := createTestFile(t, "ollama-model")

 		r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
 		modelfile, err := parser.ParseFile(r)
-		assert.Nil(t, err)
+		require.NoError(t, err)
 		fn := func(resp api.ProgressResponse) {
 			t.Logf("Status: %s", resp.Status)
 		}
-		err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
-		assert.Nil(t, err)
+		err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
+		require.NoError(t, err)
 	}

 	testCases := []testCase{
@@ -74,9 +78,9 @@ func Test_Routes(t *testing.T) {
 		},
 		Expected: func(t *testing.T, resp *http.Response) {
 			contentType := resp.Header.Get("Content-Type")
-			assert.Equal(t, contentType, "application/json; charset=utf-8")
+			assert.Equal(t, "application/json; charset=utf-8", contentType)
 			body, err := io.ReadAll(resp.Body)
-			assert.Nil(t, err)
+			require.NoError(t, err)
 			assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
 		},
 	},
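The recurring assert.Nil(t, err) → require.NoError(t, err) swap throughout these hunks changes failure behavior, not just naming: testify's assert functions record the failure and let the test keep running, while the require variants call t.FailNow and stop the test immediately, so code after a failed precondition never executes (require.NoError also prints the error itself, which assert.Nil does not). A self-contained, deliberately failing illustration:

package server

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestFailureModes(t *testing.T) {
	err := errors.New("boom")

	assert.NoError(t, err)  // marks the test failed, execution continues
	require.NoError(t, err) // calls t.FailNow: nothing below this line runs
}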
@@ -86,17 +90,17 @@ func Test_Routes(t *testing.T) {
 		Path: "/api/tags",
 		Expected: func(t *testing.T, resp *http.Response) {
 			contentType := resp.Header.Get("Content-Type")
-			assert.Equal(t, contentType, "application/json; charset=utf-8")
+			assert.Equal(t, "application/json; charset=utf-8", contentType)
 			body, err := io.ReadAll(resp.Body)
-			assert.Nil(t, err)
+			require.NoError(t, err)

 			var modelList api.ListResponse

 			err = json.Unmarshal(body, &modelList)
-			assert.Nil(t, err)
+			require.NoError(t, err)

 			assert.NotNil(t, modelList.Models)
-			assert.Equal(t, 0, len(modelList.Models))
+			assert.Empty(t, len(modelList.Models))
 		},
 	},
 	{
@@ -108,16 +112,18 @@ func Test_Routes(t *testing.T) {
 		},
 		Expected: func(t *testing.T, resp *http.Response) {
 			contentType := resp.Header.Get("Content-Type")
-			assert.Equal(t, contentType, "application/json; charset=utf-8")
+			assert.Equal(t, "application/json; charset=utf-8", contentType)
 			body, err := io.ReadAll(resp.Body)
-			assert.Nil(t, err)
+			require.NoError(t, err)

+			assert.NotContains(t, string(body), "expires_at")
+
 			var modelList api.ListResponse
 			err = json.Unmarshal(body, &modelList)
-			assert.Nil(t, err)
+			require.NoError(t, err)

-			assert.Equal(t, 1, len(modelList.Models))
-			assert.Equal(t, modelList.Models[0].Name, "test-model:latest")
+			assert.Len(t, modelList.Models, 1)
+			assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
 		},
 	},
 	{
@@ -134,7 +140,7 @@ func Test_Routes(t *testing.T) {
 			Stream: &stream,
 			}
 			jsonData, err := json.Marshal(createReq)
-			assert.Nil(t, err)
+			require.NoError(t, err)

 			req.Body = io.NopCloser(bytes.NewReader(jsonData))
 		},
@@ -142,11 +148,11 @@ func Test_Routes(t *testing.T) {
 			contentType := resp.Header.Get("Content-Type")
 			assert.Equal(t, "application/json", contentType)
 			_, err := io.ReadAll(resp.Body)
-			assert.Nil(t, err)
-			assert.Equal(t, resp.StatusCode, 200)
+			require.NoError(t, err)
+			assert.Equal(t, 200, resp.StatusCode)

 			model, err := GetModel("t-bone")
-			assert.Nil(t, err)
+			require.NoError(t, err)
 			assert.Equal(t, "t-bone:latest", model.ShortName)
 		},
 	},
@@ -161,13 +167,13 @@ func Test_Routes(t *testing.T) {
 			Destination: "beefsteak",
 			}
 			jsonData, err := json.Marshal(copyReq)
-			assert.Nil(t, err)
+			require.NoError(t, err)

 			req.Body = io.NopCloser(bytes.NewReader(jsonData))
 		},
 		Expected: func(t *testing.T, resp *http.Response) {
 			model, err := GetModel("beefsteak")
-			assert.Nil(t, err)
+			require.NoError(t, err)
 			assert.Equal(t, "beefsteak:latest", model.ShortName)
 		},
 	},
@@ -179,18 +185,18 @@ func Test_Routes(t *testing.T) {
 			createTestModel(t, "show-model")
 			showReq := api.ShowRequest{Model: "show-model"}
 			jsonData, err := json.Marshal(showReq)
-			assert.Nil(t, err)
+			require.NoError(t, err)
 			req.Body = io.NopCloser(bytes.NewReader(jsonData))
 		},
 		Expected: func(t *testing.T, resp *http.Response) {
 			contentType := resp.Header.Get("Content-Type")
-			assert.Equal(t, contentType, "application/json; charset=utf-8")
+			assert.Equal(t, "application/json; charset=utf-8", contentType)
 			body, err := io.ReadAll(resp.Body)
-			assert.Nil(t, err)
+			require.NoError(t, err)

 			var showResp api.ShowResponse
 			err = json.Unmarshal(body, &showResp)
-			assert.Nil(t, err)
+			require.NoError(t, err)

 			var params []string
 			paramsSplit := strings.Split(showResp.Parameters, "\n")
@@ -209,26 +215,26 @@ func Test_Routes(t *testing.T) {
 		},
 	}

+	t.Setenv("OLLAMA_MODELS", t.TempDir())

 	s := &Server{}
 	router := s.GenerateRoutes()

 	httpSrv := httptest.NewServer(router)
 	t.Cleanup(httpSrv.Close)

-	t.Setenv("OLLAMA_MODELS", t.TempDir())
-
 	for _, tc := range testCases {
 		t.Run(tc.Name, func(t *testing.T) {
 			u := httpSrv.URL + tc.Path
 			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
-			assert.Nil(t, err)
+			require.NoError(t, err)

 			if tc.Setup != nil {
 				tc.Setup(t, req)
 			}

 			resp, err := httpSrv.Client().Do(req)
-			assert.Nil(t, err)
+			require.NoError(t, err)
 			defer resp.Body.Close()

 			if tc.Expected != nil {
@@ -237,3 +243,82 @@ func Test_Routes(t *testing.T) {
 		})
 	}
 }
+
+func TestCase(t *testing.T) {
+	t.Setenv("OLLAMA_MODELS", t.TempDir())
+
+	cases := []string{
+		"mistral",
+		"llama3:latest",
+		"library/phi3:q4_0",
+		"registry.ollama.ai/library/gemma:q5_K_M",
+		// TODO: host:port currently fails on windows (#4107)
+		// "localhost:5000/alice/bob:latest",
+	}
+
+	var s Server
+	for _, tt := range cases {
+		t.Run(tt, func(t *testing.T) {
+			w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+				Name:      tt,
+				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
+				Stream:    &stream,
+			})
+
+			if w.Code != http.StatusOK {
+				t.Fatalf("expected status 200 got %d", w.Code)
+			}
+
+			expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			t.Run("create", func(t *testing.T) {
+				w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+					Name:      strings.ToUpper(tt),
+					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
+					Stream:    &stream,
+				})
+
+				if w.Code != http.StatusBadRequest {
+					t.Fatalf("expected status 500 got %d", w.Code)
+				}
+
+				if !bytes.Equal(w.Body.Bytes(), expect) {
+					t.Fatalf("expected error %s got %s", expect, w.Body.String())
+				}
+			})
+
+			t.Run("pull", func(t *testing.T) {
+				w := createRequest(t, s.PullModelHandler, api.PullRequest{
+					Name:   strings.ToUpper(tt),
+					Stream: &stream,
+				})
+
+				if w.Code != http.StatusBadRequest {
+					t.Fatalf("expected status 500 got %d", w.Code)
+				}
+
+				if !bytes.Equal(w.Body.Bytes(), expect) {
+					t.Fatalf("expected error %s got %s", expect, w.Body.String())
+				}
+			})
+
+			t.Run("copy", func(t *testing.T) {
+				w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
+					Source:      tt,
+					Destination: strings.ToUpper(tt),
+				})
+
+				if w.Code != http.StatusBadRequest {
+					t.Fatalf("expected status 500 got %d", w.Code)
+				}
+
+				if !bytes.Equal(w.Body.Bytes(), expect) {
+					t.Fatalf("expected error %s got %s", expect, w.Body.String())
+				}
+			})
+		})
+	}
+}
server/sched.go
@@ -7,17 +7,17 @@ import (
 	"log/slog"
 	"reflect"
 	"runtime"
+	"slices"
 	"sort"
 	"strings"
 	"sync"
 	"time"

 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/server/envconfig"
-	"golang.org/x/exp/slices"
 )

 type LlmRequest struct {
@@ -66,7 +66,7 @@ func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options,
 		opts.NumCtx = 4
 	}

-	opts.NumCtx = opts.NumCtx * envconfig.NumParallel
+	opts.NumCtx *= envconfig.NumParallel

 	req := &LlmRequest{
 		ctx: c,
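The rewrite to opts.NumCtx *= envconfig.NumParallel is purely cosmetic (a compound assignment), but the line itself is load-bearing: the runner's context window is sized for all parallel sequences at once, so NumCtx is scaled up so that each parallel sequence keeps its requested window. A worked example with assumed values:

package main

import "fmt"

func main() {
	numCtx := 2048   // requested per-sequence context length (assumed value)
	numParallel := 4 // envconfig.NumParallel-style setting (assumed value)

	// size the runner for all parallel sequences at once
	numCtx *= numParallel
	fmt.Println(numCtx) // 8192
}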
@@ -370,7 +370,6 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
 		r.refMu.Lock()
 		gpuIDs := make([]string, 0, len(r.gpus))
 		if r.llama != nil {
-
 			// TODO this should be broken down by GPU instead of assuming uniform spread
 			estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
 			for _, gpu := range r.gpus {
@@ -529,7 +528,6 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
 		}
 	}()
 	return finished
-
 }

 type ByDuration []*runnerRef
server/sched_test.go
@@ -12,11 +12,10 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/server/envconfig"
-	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -53,10 +52,10 @@ func TestLoad(t *testing.T) {
 	}
 	gpus := gpu.GpuInfoList{}
 	s.load(req, ggml, gpus)
-	require.Len(t, req.successCh, 0)
+	require.Empty(t, req.successCh)
 	require.Len(t, req.errCh, 1)
 	s.loadedMu.Lock()
-	require.Len(t, s.loaded, 0)
+	require.Empty(t, s.loaded)
 	s.loadedMu.Unlock()
 	err := <-req.errCh
 	require.Contains(t, err.Error(), "this model may be incompatible")
@@ -113,7 +112,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 	t.Helper()

 	f, err := os.CreateTemp(t.TempDir(), modelName)
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer f.Close()

 	gguf := llm.NewGGUFV3(binary.LittleEndian)
@@ -131,7 +130,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 	}, []llm.Tensor{
 		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
 	})
-	assert.Nil(t, err)
+	require.NoError(t, err)

 	fname := f.Name()
 	model := &Model{Name: modelName, ModelPath: fname}
@@ -190,8 +189,8 @@ func TestRequests(t *testing.T) {
 	select {
 	case resp := <-scenario1a.req.successCh:
 		require.Equal(t, resp.llama, scenario1a.srv)
-		require.Len(t, s.pendingReqCh, 0)
-		require.Len(t, scenario1a.req.errCh, 0)
+		require.Empty(t, s.pendingReqCh)
+		require.Empty(t, scenario1a.req.errCh)
 	case <-ctx.Done():
 		t.Errorf("timeout")
 	}
@@ -203,8 +202,8 @@ func TestRequests(t *testing.T) {
 	select {
 	case resp := <-scenario1b.req.successCh:
 		require.Equal(t, resp.llama, scenario1a.srv)
-		require.Len(t, s.pendingReqCh, 0)
-		require.Len(t, scenario1b.req.errCh, 0)
+		require.Empty(t, s.pendingReqCh)
+		require.Empty(t, scenario1b.req.errCh)
 	case <-ctx.Done():
 		t.Errorf("timeout")
 	}
@@ -221,8 +220,8 @@ func TestRequests(t *testing.T) {
 	select {
 	case resp := <-scenario2a.req.successCh:
 		require.Equal(t, resp.llama, scenario2a.srv)
-		require.Len(t, s.pendingReqCh, 0)
-		require.Len(t, scenario2a.req.errCh, 0)
+		require.Empty(t, s.pendingReqCh)
+		require.Empty(t, scenario2a.req.errCh)
 	case <-ctx.Done():
 		t.Errorf("timeout")
 	}
@@ -237,8 +236,8 @@ func TestRequests(t *testing.T) {
 	select {
 	case resp := <-scenario3a.req.successCh:
 		require.Equal(t, resp.llama, scenario3a.srv)
-		require.Len(t, s.pendingReqCh, 0)
-		require.Len(t, scenario3a.req.errCh, 0)
+		require.Empty(t, s.pendingReqCh)
+		require.Empty(t, scenario3a.req.errCh)
 	case <-ctx.Done():
 		t.Errorf("timeout")
 	}
@@ -253,8 +252,8 @@ func TestRequests(t *testing.T) {
 	select {
 	case resp := <-scenario3b.req.successCh:
 		require.Equal(t, resp.llama, scenario3b.srv)
-		require.Len(t, s.pendingReqCh, 0)
-		require.Len(t, scenario3b.req.errCh, 0)
+		require.Empty(t, s.pendingReqCh)
+		require.Empty(t, scenario3b.req.errCh)
 	case <-ctx.Done():
 		t.Errorf("timeout")
 	}
@@ -269,8 +268,8 @@ func TestRequests(t *testing.T) {
 	select {
 	case resp := <-scenario3c.req.successCh:
 		require.Equal(t, resp.llama, scenario3c.srv)
-		require.Len(t, s.pendingReqCh, 0)
-		require.Len(t, scenario3c.req.errCh, 0)
+		require.Empty(t, s.pendingReqCh)
+		require.Empty(t, scenario3c.req.errCh)
 	case <-ctx.Done():
 		t.Errorf("timeout")
 	}
@@ -296,8 +295,8 @@ func TestRequests(t *testing.T) {
 	select {
 	case resp := <-scenario3d.req.successCh:
 		require.Equal(t, resp.llama, scenario3d.srv)
-		require.Len(t, s.pendingReqCh, 0)
-		require.Len(t, scenario3d.req.errCh, 0)
+		require.Empty(t, s.pendingReqCh)
+		require.Empty(t, scenario3d.req.errCh)
 	case <-ctx.Done():
 		t.Errorf("timeout")
 	}
@@ -332,7 +331,7 @@ func TestGetRunner(t *testing.T) {
 	slog.Info("scenario1b")
 	successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
 	require.Len(t, s.pendingReqCh, 1)
-	require.Len(t, successCh1b, 0)
+	require.Empty(t, successCh1b)
 	require.Len(t, errCh1b, 1)
 	err := <-errCh1b
 	require.Contains(t, err.Error(), "server busy")
@@ -340,8 +339,8 @@ func TestGetRunner(t *testing.T) {
 	select {
 	case resp := <-successCh1a:
 		require.Equal(t, resp.llama, scenario1a.srv)
-		require.Len(t, s.pendingReqCh, 0)
-		require.Len(t, errCh1a, 0)
+		require.Empty(t, s.pendingReqCh)
+		require.Empty(t, errCh1a)
 	case <-ctx.Done():
 		t.Errorf("timeout")
 	}
@@ -355,9 +354,9 @@ func TestGetRunner(t *testing.T) {
 	successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
 	// Starts in pending channel, then should be quickly processed to return an error
 	time.Sleep(5 * time.Millisecond)
-	require.Len(t, successCh1c, 0)
+	require.Empty(t, successCh1c)
 	s.loadedMu.Lock()
-	require.Len(t, s.loaded, 0)
+	require.Empty(t, s.loaded)
 	s.loadedMu.Unlock()
 	require.Len(t, errCh1c, 1)
 	err = <-errCh1c
@@ -386,8 +385,8 @@ func TestPrematureExpired(t *testing.T) {
 	select {
 	case resp := <-successCh1a:
 		require.Equal(t, resp.llama, scenario1a.srv)
-		require.Len(t, s.pendingReqCh, 0)
-		require.Len(t, errCh1a, 0)
+		require.Empty(t, s.pendingReqCh)
+		require.Empty(t, errCh1a)
 		s.loadedMu.Lock()
 		require.Len(t, s.loaded, 1)
 		s.loadedMu.Unlock()
@@ -401,9 +400,9 @@ func TestPrematureExpired(t *testing.T) {
 	time.Sleep(20 * time.Millisecond)
 	require.LessOrEqual(t, len(s.finishedReqCh), 1)
 	time.Sleep(10 * time.Millisecond)
-	require.Len(t, s.finishedReqCh, 0)
+	require.Empty(t, s.finishedReqCh)
 	s.loadedMu.Lock()
-	require.Len(t, s.loaded, 0)
+	require.Empty(t, s.loaded)
 	s.loadedMu.Unlock()

 	// also shouldn't happen in real life
@@ -487,7 +486,6 @@ func TestFindRunnerToUnload(t *testing.T) {
 	r2.refCount = 1
 	resp = s.findRunnerToUnload()
 	require.Equal(t, r1, resp)
-
 }

 func TestNeedsReload(t *testing.T) {
server/upload.go
@@ -146,7 +146,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
 		case requestURL := <-b.nextURL:
 			g.Go(func() error {
 				var err error
-				for try := 0; try < maxRetries; try++ {
+				for try := range maxRetries {
 					err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
 					switch {
 					case errors.Is(err, context.Canceled):
@@ -190,7 +190,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", "0")

-	for try := 0; try < maxRetries; try++ {
+	for try := range maxRetries {
 		var resp *http.Response
 		resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
 		if errors.Is(err, context.Canceled) {
@@ -253,7 +253,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 	}

 	// retry uploading to the redirect URL
-	for try := 0; try < maxRetries; try++ {
+	for try := range maxRetries {
 		err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
 		switch {
 		case errors.Is(err, context.Canceled):
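All three retry loops move from the classic three-clause form to for try := range maxRetries, the range-over-int syntax added in Go 1.22: it yields try = 0, 1, ..., maxRetries-1 and is otherwise identical to the loop it replaces, which also means this change pins the module to Go 1.22 or newer. For example:

package main

import "fmt"

func main() {
	const maxRetries = 3

	// Go 1.22+: ranging over an integer n yields 0 through n-1
	for try := range maxRetries {
		fmt.Println("attempt", try)
	}
	// prints: attempt 0, attempt 1, attempt 2
}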
@@ -391,7 +391,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryO
 		return err
 	}

-	// nolint: contextcheck
+	//nolint:contextcheck
 	go upload.Run(context.Background(), opts)
 }
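The whitespace change in the linter comment is functional, not stylistic: golangci-lint only honors the machine-readable //nolint:<linter> form with no space after the slashes; with a space it reads as an ordinary comment, and the nolintlint linter flags the malformed directive. Compare:

// nolint: contextcheck   <- plain comment; the suppression is ignored
//nolint:contextcheck     <- recognized directive; contextcheck is suppressed here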
1 templates/alfred.gotmpl Normal file
@@ -0,0 +1 @@
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
7 templates/alpaca.gotmpl Normal file
@@ -0,0 +1,7 @@
{{ if .System }}{{ .System }}

{{ end }}{{ if .Prompt }}### Instruction:
{{ .Prompt }}

{{ end }}### Response:
{{ .Response }}
6 templates/chatml.gotmpl Normal file
@@ -0,0 +1,6 @@
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
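Each of these .gotmpl files is a plain Go text/template rendered over a value with System, Prompt, and Response fields (the field names come from the templates themselves). A minimal sketch of rendering the chatml template above; Response is left empty when building a prompt for generation:

package main

import (
	"os"
	"text/template"
)

const chatml = `{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>`

func main() {
	tmpl := template.Must(template.New("chatml").Parse(chatml))

	vars := struct{ System, Prompt, Response string }{
		System: "You are a helpful assistant.",
		Prompt: "Why is the sky blue?",
	}
	_ = tmpl.Execute(os.Stdout, vars)
}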
5 templates/chatqa.gotmpl Normal file
@@ -0,0 +1,5 @@
{{ if .System }}System: {{ .System }}

{{ end }}{{ if .Prompt }}User: {{ .Prompt }}

{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
8 templates/codellama-70b-instruct.gotmpl Normal file
@@ -0,0 +1,8 @@
{{ if .System }} Source: system

{{ .System }} <step>{{ end }} Source: user

{{ .Prompt }} <step> Source: assistant
Destination: user

{{ .Response }}<step>
3 templates/falcon-instruct.gotmpl Normal file
@@ -0,0 +1,3 @@
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}
4 templates/gemma-instruct.gotmpl Normal file
@@ -0,0 +1,4 @@
<start_of_turn>user
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>
9 templates/granite-instruct.gotmpl Normal file
@@ -0,0 +1,9 @@
{{ if .System }}
System:
{{ .System }}

{{ end }}{{ if .Prompt }}Question:
{{ .Prompt }}

{{ end }}Answer:
{{ .Response }}
138 templates/index.json Normal file
@@ -0,0 +1,138 @@
[
  {
    "template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}",
    "name": "chatml"
  },
  {
    "template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    "name": "chatml"
  },
  {
    "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
    "name": "zephyr"
  },
  {
    "template": "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}",
    "name": "chatml"
  },
  {
    "template": "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
    "name": "openchat"
  },
  {
    "template": "{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    "name": "chatml"
  },
  {
    "template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    "name": "chatml"
  },
  {
    "template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    "name": "chatml"
  },
  {
    "template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    "name": "chatml"
  },
  {
    "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
    "name": "zephyr"
  },
  {
    "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
    "name": "mistral-instruct"
  },
  {
    "template": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'### Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response\n'}}",
    "name": "starcoder2-instruct"
  },
  {
    "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
    "name": "llama2-chat"
  },
  {
    "template": "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '<s>' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' <step> ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}",
    "name": "codellama-70b-instruct"
  },
  {
    "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
    "name": "mistral-instruct"
  },
  {
    "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'system' %}\n{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|im_start|>assistant' }}\n{% endif %}\n{% endfor %}",
    "name": "chatml"
  },
  {
    "template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    "name": "chatml"
  },
  {
    "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    "name": "chatml"
  },
  {
    "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif 'system' not in messages[0]['role'] %}{% set loop_messages = messages %}{% set system_message = 'You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks \u2014 remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER\\'S QUERY.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% if system_message != false %}{{ '<|im_start|>system\n' + system_message | trim + '<|im_end|>\n'}}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% else %}{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% endif %}{% if (add_generation_prompt == true and loop.last) %}{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}{% endif %}{% endfor %}",
    "name": "chatml"
  },
  {
    "template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
    "name": "alpaca"
  },
  {
    "template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
    "name": "chatqa"
  },
  {
    "template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
    "name": "gemma-instruct"
  },
  {
    "template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
    "name": "llama3-instruct"
  },
  {
    "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'Question:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'system' %}\n{{ 'System:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Answer:\n' + message['content'] + '\n\n' }}{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Answer:\n' }}{% endif %}{% endfor %}",
    "name": "granite-instruct"
  },
  {
    "template": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'@@ Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'@@ Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'@@ Response\n'}}",
    "name": "magicoder"
  },
  {
    "template": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_user>' + message['content'].strip() + '<end_message>' }}{% elif message['role'] == 'system' %}{{ '<start_system>' + message['content'].strip() + '<end_message>' }}{% elif message['role'] == 'assistant' %}{{ '<start_assistant>' + message['content'] + '<end_message>' }}{% else %}{{ raise_exception('Only system, user and assistant roles are supported.') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<start_assistant>' }}{% endif %}{% endfor %}",
    "name": "alfred"
  },
  {
    "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
    "name": "llama2-chat"
  },
  {
    "template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
    "name": "phi-3"
  },
  {
    "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
    "name": "phi-3"
  },
  {
    "template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
    "name": "phi-3"
  },
  {
    "template": "{{ bos_token }}{%- if messages[0]['role'] == 'system' -%}{% set loop_messages = messages[1:] %}{%- else -%}{% set loop_messages = messages %}{% endif %}System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context.\n\n{% for message in loop_messages %}{%- if message['role'] == 'user' -%}User: {{ message['content'].strip() + '\n\n' }}{%- else -%}Assistant: {{ message['content'].strip() + '\n\n' }}{%- endif %}{% if loop.last and message['role'] == 'user' %}Assistant:{% endif %}{% endfor %}",
    "name": "chatqa"
  },
  {
    "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'User: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'System: ' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'Falcon:\n' + message['content']}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Falcon:' }}\n{% endif %}\n{% endfor %}",
    "name": "falcon-instruct"
  },
  {
    "template": "{% for message in messages %}{% if not loop.first %}{{ '\n' }}{% endif %}{% if message['role'] == 'system' %}{{ 'System: ' }}{% elif message['role'] == 'user' %}{{ 'User: ' }}{% elif message['role'] == 'assistant' %}{{ 'Falcon: ' }}{% endif %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '\n' + 'Falcon:' }}{% endif %}",
    "name": "falcon-instruct"
  },
  {
    "template": "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}",
    "name": "solar-instruct"
  }
]
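Each entry in index.json pairs an upstream Jinja chat template (as shipped in model metadata) with the name of the equivalent .gotmpl in this directory; several Jinja variants map to the same Go template (chatml alone appears many times). Presumably this lets a known Go template be selected automatically when importing a model. A rough sketch of such a lookup, with the helper name and usage assumed rather than taken from the diff:

package main

import (
	"encoding/json"
	"fmt"
	"os"
)

// entry mirrors the objects in templates/index.json
type entry struct {
	Template string `json:"template"`
	Name     string `json:"name"`
}

// lookupTemplate is hypothetical: map a Jinja template string from model
// metadata to the named Go template, if the index knows it.
func lookupTemplate(indexPath, jinja string) (string, error) {
	b, err := os.ReadFile(indexPath)
	if err != nil {
		return "", err
	}

	var entries []entry
	if err := json.Unmarshal(b, &entries); err != nil {
		return "", err
	}

	for _, e := range entries {
		if e.Template == jinja {
			return e.Name, nil
		}
	}
	return "", fmt.Errorf("no matching template")
}

func main() {
	name, err := lookupTemplate("templates/index.json", "{{bos_token}}...")
	fmt.Println(name, err)
}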
3 templates/llama2-chat.gotmpl Normal file
@@ -0,0 +1,3 @@
[INST] <<SYS>>{{ .System }}<</SYS>>

{{ .Prompt }} [/INST] {{ .Response }}
Some files were not shown because too many files have changed in this diff.