Mirror of https://github.com/likelovewant/ollama-for-amd.git (synced 2025-12-22 14:53:56 +00:00)

Compare commits (44 commits)
Commits:

28832df4bd
8200c371ae
0a8d6ea86d
8e1050f366
eda8a32a09
a0a40aa20c
2697d7f5aa
1f32276178
4c4fe3f87f
feedf49c71
8b00a415ab
b7d38e2ccd
01b80e9ffc
bd5e432630
aec77d6a05
6ffb5cb017
f7e3b9190f
980dd15f81
01d544d373
1dc3ef3aa9
8aac22438e
15c2d8fe14
25906d72d1
023451ce47
9b53e39d8e
97fae2df95
160d9d4900
d4e6407464
b7f7d8cd15
2fa1db4345
71b0945fc6
5bca2e60a7
67472e0e89
e9aa5117c4
2473bdba5e
7d1c0047fa
7b61eba471
7edaf6e7e8
97ec8cfd4e
ce67706037
04210aa6dd
43f9d92008
ed6c8bfe57
df3802a65f
.gitattributes (vendored, 3 lines changed)

@@ -1,2 +1,3 @@
 llm/ext_server/* linguist-vendored
-* text eol=lf
+* text=auto
+*.go text eol=lf
.github/workflows/release.yaml (vendored, 10 lines changed)

@@ -31,7 +31,7 @@ jobs:
     security set-keychain-settings -lut 3600 build.keychain
 - uses: actions/setup-go@v5
   with:
-    go-version: "stable"
+    go-version-file: go.mod
     cache: true
 - name: Build Darwin
   env:
@@ -87,7 +87,7 @@ jobs:
     write-host "plugin installed"
 - uses: actions/setup-go@v5
   with:
-    go-version: "stable"
+    go-version-file: go.mod
     cache: true
 - run: go get ./...
 - run: |
@@ -141,7 +141,7 @@ jobs:
     write-host "plugin installed"
 - uses: actions/setup-go@v5
   with:
-    go-version: "stable"
+    go-version-file: go.mod
     cache: true
 - name: 'Install ROCm'
   run: |
@@ -218,7 +218,7 @@ jobs:
     write-host "plugin installed"
 - uses: actions/setup-go@v5
   with:
-    go-version: "stable"
+    go-version-file: go.mod
     cache: true
 - name: 'Install CUDA'
   run: |
@@ -306,7 +306,7 @@ jobs:
     write-host "plugin installed"
 - uses: actions/setup-go@v5
   with:
-    go-version: "stable"
+    go-version-file: go.mod
     cache: true
 - run: go get
 - uses: actions/download-artifact@v4
.github/workflows/test.yaml (vendored, 10 lines changed)

@@ -63,7 +63,7 @@ jobs:
 - uses: actions/checkout@v4
 - uses: actions/setup-go@v5
   with:
-    go-version: "stable"
+    go-version-file: go.mod
     cache: true
 - run: go get ./...
 - run: |
@@ -163,7 +163,7 @@ jobs:
 - uses: actions/checkout@v4
 - uses: actions/setup-go@v5
   with:
-    go-version: "stable"
+    go-version-file: go.mod
     cache: true
 - name: 'Install ROCm'
   run: |
@@ -200,7 +200,7 @@ jobs:
 - uses: actions/checkout@v4
 - uses: actions/setup-go@v5
   with:
-    go-version: "stable"
+    go-version-file: go.mod
     cache: true
 - name: 'Install CUDA'
   run: |
@@ -255,7 +255,7 @@ jobs:
     submodules: recursive
 - uses: actions/setup-go@v5
   with:
-    go-version: "stable"
+    go-version-file: go.mod
     cache: false
 - run: |
     case ${{ matrix.arch }} in
@@ -297,7 +297,7 @@ jobs:
     submodules: recursive
 - uses: actions/setup-go@v5
   with:
-    go-version: "stable"
+    go-version-file: go.mod
     cache: true
 - run: |
     case ${{ matrix.arch }} in
@@ -24,7 +24,6 @@ linters:
   - nosprintfhostport
   - staticcheck
   - tenv
-  - testifylint
   - unconvert
   - unused
   - usestdlibvars
CONTRIBUTING.md (new file, 37 lines)

@@ -0,0 +1,37 @@
+# Contributing to Ollama
+
+Thank you for your interest in contributing to Ollama! Here are a few guidelines to help get you started.
+
+## Set up
+
+See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally.
+
+## Pull requests
+
+### Ideal issues
+
+* [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error.
+* [Performance](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Aperformance): issues to make Ollama faster at model inference, downloading or uploading.
+* [Security](https://github.com/ollama/ollama/blob/main/SECURITY.md): issues that could lead to a security vulnerability. As mentioned in [SECURITY.md](https://github.com/ollama/ollama/blob/main/SECURITY.md), please do not disclose security vulnerabilities publicly.
+
+### Issues that are harder to review
+
+* New features: new features (e.g. API fields, environment variables) add surface area to Ollama and make it harder to maintain in the long run as they cannot be removed without potentially breaking users in the future.
+* Refactoring: large code improvements are important, but can be harder or take longer to review and merge.
+* Documentation: small updates to fill in or correct missing documentation are helpful, however large documentation additions can be hard to maintain over time.
+
+### Issues that may not be accepted
+
+* Changes that break backwards compatibility in Ollama's API (including the OpenAI-compatible API)
+* Changes that add significant friction to the user experience
+* Changes that create a large future maintenance burden for maintainers and contributors
+
+### Best practices
+
+* Commit messages: please leave both a title and a description in your commit messages. The title should be a short summary of the changes, with a leading word that explains the section of the code being changed (e.g. `api: fix parsing of prompt field`). In the description, leave a short 2-3 sentences that explain more about the change and its impact.
+* Tests: please add test coverage to changes where possible.
+* Minimize dependencies: avoid adding new dependencies unless absolutely necessary.
+
+## Need help?
+
+If you need help with anything, feel free to reach out to us on our [Discord server](https://discord.gg/ollama).
@@ -343,6 +343,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [tlm](https://github.com/yusufcanb/tlm)
 - [podman-ollama](https://github.com/ericcurtin/podman-ollama)
 - [gollama](https://github.com/sammcj/gollama)
+- [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/)

 ### Database

@@ -298,7 +298,7 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
     return &lr, nil
 }

-// List running models.
+// ListRunning lists running models.
 func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
     var lr ProcessResponse
     if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
@@ -333,7 +333,7 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err
     return &resp, nil
 }

-// Hearbeat checks if the server has started and is responsive; if yes, it
+// Heartbeat checks if the server has started and is responsive; if yes, it
 // returns nil, otherwise an error.
 func (c *Client) Heartbeat(ctx context.Context) error {
     if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
@@ -11,8 +11,8 @@ import (
 )

 const (
-    updatAvailableMenuID = 1
-    updateMenuID = updatAvailableMenuID + 1
+    updateAvailableMenuID = 1
+    updateMenuID = updateAvailableMenuID + 1
     separatorMenuID = updateMenuID + 1
     diagLogsMenuID = separatorMenuID + 1
     diagSeparatorMenuID = diagLogsMenuID + 1
@@ -35,7 +35,7 @@ func (t *winTray) initMenus() error {
 func (t *winTray) UpdateAvailable(ver string) error {
     if !t.updateNotified {
         slog.Debug("updating menu and sending notification for new update")
-        if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
+        if err := t.addOrUpdateMenuItem(updateAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
             return fmt.Errorf("unable to create menu entries %w", err)
         }
         if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
cmd/cmd.go (47 lines changed)

@@ -22,6 +22,7 @@ import (
     "runtime"
     "slices"
     "strings"
+    "sync/atomic"
     "syscall"
     "time"

@@ -78,6 +79,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
     status := "transferring model data"
     spinner := progress.NewSpinner(status)
     p.Add(status, spinner)
+    defer p.Stop()

     for i := range modelfile.Commands {
         switch modelfile.Commands[i].Name {
@@ -112,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
             path = tempfile
         }

-        digest, err := createBlob(cmd, client, path)
+        digest, err := createBlob(cmd, client, path, spinner)
         if err != nil {
             return err
         }
@@ -263,13 +265,20 @@ func tempZipFiles(path string) (string, error) {
     return tempfile.Name(), nil
 }

-func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
+func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
     bin, err := os.Open(path)
     if err != nil {
         return "", err
     }
     defer bin.Close()

+    // Get file info to retrieve the size
+    fileInfo, err := bin.Stat()
+    if err != nil {
+        return "", err
+    }
+    fileSize := fileInfo.Size()
+
     hash := sha256.New()
     if _, err := io.Copy(hash, bin); err != nil {
         return "", err
@@ -279,13 +288,43 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
         return "", err
     }

+    var pw progressWriter
+    status := "transferring model data 0%"
+    spinner.SetMessage(status)
+
+    done := make(chan struct{})
+    defer close(done)
+
+    go func() {
+        ticker := time.NewTicker(60 * time.Millisecond)
+        defer ticker.Stop()
+        for {
+            select {
+            case <-ticker.C:
+                spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize)))
+            case <-done:
+                spinner.SetMessage("transferring model data 100%")
+                return
+            }
+        }
+    }()
+
     digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
-    if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
+    if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
         return "", err
     }
     return digest, nil
 }

+type progressWriter struct {
+    n atomic.Int64
+}
+
+func (w *progressWriter) Write(p []byte) (n int, err error) {
+    w.n.Add(int64(len(p)))
+    return len(p), nil
+}
+
 func RunHandler(cmd *cobra.Command, args []string) error {
     interactive := true

@@ -1086,7 +1125,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
     return nil
 }

-func RunServer(cmd *cobra.Command, _ []string) error {
+func RunServer(_ *cobra.Command, _ []string) error {
     if err := initializeKeypair(); err != nil {
         return err
     }
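The createBlob change above boils down to one reusable pattern: wrap the upload reader with io.TeeReader so every byte read is also counted in an atomic counter, and let a ticker goroutine turn that counter into a progress message. A minimal standalone sketch of the same idea follows; the file name, helper names, and output formatting are illustrative, not part of the ollama code.

```go
package main

import (
	"fmt"
	"io"
	"os"
	"sync/atomic"
	"time"
)

// byteCounter counts bytes passed through it; concurrent reads of n are safe.
type byteCounter struct {
	n atomic.Int64
}

func (c *byteCounter) Write(p []byte) (int, error) {
	c.n.Add(int64(len(p)))
	return len(p), nil
}

// copyWithProgress copies src to dst and prints a percentage while it runs.
func copyWithProgress(dst io.Writer, src io.Reader, total int64) error {
	if total <= 0 {
		return fmt.Errorf("total size must be positive, got %d", total)
	}

	var counter byteCounter
	done := make(chan struct{})
	defer close(done)

	go func() {
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				fmt.Fprintf(os.Stderr, "\rtransferring %3d%%", counter.n.Load()*100/total)
			case <-done:
				fmt.Fprintln(os.Stderr, "\rtransferring 100%")
				return
			}
		}
	}()

	// TeeReader mirrors every read from src into counter, so counter.n
	// always reflects how much the copy below has consumed.
	_, err := io.Copy(dst, io.TeeReader(src, &counter))
	return err
}

func main() {
	f, err := os.Open("model.bin") // hypothetical input file
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		return
	}
	defer f.Close()

	info, err := f.Stat()
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		return
	}

	if err := copyWithProgress(io.Discard, f, info.Size()); err != nil {
		fmt.Fprintln(os.Stderr, err)
	}
}
```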
@@ -27,6 +27,10 @@ func (Parameters) KV(t *Tokenizer) llm.KV {
         "tokenizer.ggml.token_type": t.Vocabulary.Types,
     }

+    if len(t.Merges) > 0 {
+        kv["tokenizer.ggml.merges"] = t.Merges
+    }
+
     if t.Template != "" {
         kv["tokenizer.chat_template"] = t.Template
     }
@@ -89,6 +93,8 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
         conv = &mixtral{}
     case "GemmaForCausalLM":
         conv = &gemma{}
+    case "Phi3ForCausalLM":
+        conv = &phi3{}
     default:
         return errors.New("unsupported architecture")
     }
@@ -90,10 +90,6 @@ func (p *llama) KV(t *Tokenizer) llm.KV {
         kv["llama.attention.value_length"] = p.HeadDim
     }

-    if len(t.Merges) > 0 {
-        kv["tokenizer.ggml.merges"] = t.Merges
-    }
-
     return kv
 }

convert/convert_phi3.go (new file, 125 lines)

@@ -0,0 +1,125 @@
+package convert
+
+import (
+    "cmp"
+    "encoding/binary"
+    "io"
+    "math"
+    "strings"
+    "sync"
+
+    "github.com/ollama/ollama/llm"
+)
+
+type phi3 struct {
+    Parameters
+    NumHiddenLayers uint32 `json:"num_hidden_layers"`
+    NLayers uint32 `json:"n_layers"`
+    HiddenSize uint32 `json:"hidden_size"`
+    NEmbd uint32 `json:"n_embd"`
+    IntermediateSize uint32 `json:"intermediate_size"`
+    NumAttentionHeads uint32 `json:"num_attention_heads"`
+    NHead uint32 `json:"n_head"`
+    NumKeyValueHeads uint32 `json:"num_key_value_heads"`
+    NHeadKV uint32 `json:"n_head_kv"`
+    RopeTheta float32 `json:"rope_theta"`
+    RopeScaling struct {
+        Type string `json:"type"`
+        LongFactor ropeFactor `json:"long_factor"`
+        ShortFactor ropeFactor `json:"short_factor"`
+    } `json:"rope_scaling"`
+    RMSNormEPS float32 `json:"rms_norm_eps"`
+    NPositions uint32 `json:"n_positions"`
+    MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
+    OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
+    SlidingWindow uint32 `json:"sliding_window"`
+}
+
+var _ Converter = (*phi3)(nil)
+
+func (p *phi3) KV(t *Tokenizer) llm.KV {
+    kv := p.Parameters.KV(t)
+    kv["general.architecture"] = "phi3"
+    kv["general.name"] = "phi3"
+    kv["phi3.context_length"] = p.MaxPositionEmbeddings
+    kv["phi3.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd)
+    kv["phi3.feed_forward_length"] = p.IntermediateSize
+    kv["phi3.block_count"] = cmp.Or(p.NumHiddenLayers, p.NLayers)
+    kv["phi3.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead)
+    kv["phi3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NHeadKV)
+    kv["phi3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
+    kv["phi3.rope.dimension_count"] = p.HiddenSize / cmp.Or(p.NumAttentionHeads, p.NHead)
+    kv["phi3.rope.freq_base"] = p.RopeTheta
+    kv["phi3.rope.scaling.original_context_length"] = p.OriginalMaxPositionEmbeddings
+    kv["phi3.attention.sliding_window"] = p.SlidingWindow
+
+    scale := float64(p.MaxPositionEmbeddings) / float64(p.OriginalMaxPositionEmbeddings)
+
+    switch p.RopeScaling.Type {
+    case "":
+        // no scaling
+    case "su", "longrope":
+        kv["phi3.rope.scaling.attn_factor"] = float32(max(math.Sqrt(1+math.Log(scale)/math.Log(float64(p.OriginalMaxPositionEmbeddings))), 1.0))
+    case "yarn":
+        kv["phi3.rope.scaling.attn_factor"] = float32(max(0.1*math.Log(scale)+1.0, 1.0))
+    default:
+        panic("unknown rope scaling type")
+    }
+
+    return kv
+}
+
+func (p *phi3) Tensors(ts []Tensor) []llm.Tensor {
+    var addRopeFactors sync.Once
+
+    out := make([]llm.Tensor, 0, len(ts)+2)
+    for _, t := range ts {
+        name := p.tensorName(t.Name())
+        if strings.HasPrefix(name, "blk.0.") {
+            addRopeFactors.Do(func() {
+                out = append(out, llm.Tensor{
+                    Name: "rope_factors_long.weight",
+                    Kind: 0,
+                    Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
+                    WriterTo: p.RopeScaling.LongFactor,
+                }, llm.Tensor{
+                    Name: "rope_factors_short.weight",
+                    Kind: 0,
+                    Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
+                    WriterTo: p.RopeScaling.ShortFactor,
+                })
+            })
+        }
+
+        out = append(out, llm.Tensor{
+            Name: name,
+            Kind: t.Kind(),
+            Shape: t.Shape(),
+            WriterTo: t,
+        })
+    }
+
+    return out
+}
+
+func (p *phi3) tensorName(n string) string {
+    return strings.NewReplacer(
+        "lm_head", "output",
+        "model.embed_tokens", "token_embd",
+        "model.norm", "output_norm",
+        "model.layers", "blk",
+        "input_layernorm", "attn_norm",
+        "self_attn.qkv_proj", "attn_qkv",
+        "self_attn.o_proj", "attn_output",
+        "mlp.down_proj", "ffn_down",
+        "mlp.gate_up_proj", "ffn_up",
+        "post_attention_layernorm", "ffn_norm",
+    ).Replace(n)
+}
+
+type ropeFactor []float32
+
+func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
+    err := binary.Write(w, binary.LittleEndian, r)
+    return 0, err
+}
@@ -65,6 +65,8 @@ func TestConvertFull(t *testing.T) {
         "Mistral-7B-Instruct-v0.2",
         "Mixtral-8x7B-Instruct-v0.1",
         "gemma-2b-it",
+        // microsoft/Phi-3-mini-128-instruct@d548c233192db00165d842bf8edff054bb3212f8
+        "Phi-3-mini-128k-instruct",
     }

     for i := range cases {
convert/testdata/Phi-3-mini-128k-instruct.json (vendored, new file, 225 lines)

@@ -0,0 +1,225 @@
|
|||||||
|
{
|
||||||
|
"general.architecture": "phi3",
|
||||||
|
"general.file_type": "1",
|
||||||
|
"general.quantization_version": "2",
|
||||||
|
"phi3.block_count": "32",
|
||||||
|
"phi3.context_length": "131072",
|
||||||
|
"phi3.embedding_length": "3072",
|
||||||
|
"phi3.feed_forward_length": "8192",
|
||||||
|
"phi3.rope.scaling.original_context_length": "4096",
|
||||||
|
"phi3.rope.dimension_count": "96",
|
||||||
|
"phi3.rope.freq_base": "10000",
|
||||||
|
"phi3.rope.scaling.attn_factor": "1.1902381",
|
||||||
|
"phi3.attention.head_count": "32",
|
||||||
|
"phi3.attention.head_count_kv": "32",
|
||||||
|
"phi3.attention.layer_norm_rms_epsilon": "1e-05",
|
||||||
|
"phi3.attention.sliding_window": "262144",
|
||||||
|
"tokenizer.ggml.model": "llama",
|
||||||
|
"tokenizer.ggml.pre": "default",
|
||||||
|
"tokenizer.ggml.add_bos_token": "false",
|
||||||
|
"tokenizer.ggml.add_eos_token": "false",
|
||||||
|
"tokenizer.ggml.bos_token_id": "1",
|
||||||
|
"tokenizer.ggml.eos_token_id": "32000",
|
||||||
|
"tokenizer.ggml.unknown_token_id": "0",
|
||||||
|
"tokenizer.ggml.padding_token_id": "32000",
|
||||||
|
"tokenizer.ggml.scores": "6e37bcde2adc7e350e87c496eddd7a2124329c1dc66c5bf3ad3997253e4f7a62",
|
||||||
|
"tokenizer.ggml.token_type": "b6ecf55ec64ee67d87750bdb8d757a2c58bf78377e9f4219f5689a6c4dea57ce",
|
||||||
|
"tokenizer.ggml.tokens": "d168da3ddd3eee820916945fcb9baf24dd3cde42f606cffa2d19e7c8a8743918",
|
||||||
|
"blk.0.attn_norm.weight": "216aeb2c9e0c271f899e1ef2a63cceeb8f41e97642e84fada54b1d3c1c11cf25",
|
||||||
|
"blk.0.attn_output.weight": "b597d56f7188ffc1fafc273fadc59d41738cffd677ae98c61a62c3285b3a3099",
|
||||||
|
"blk.0.attn_qkv.weight": "d28a6b44e13f59be5483e4be2bedb544e346168d720aca27f47d1a5a722be91e",
|
||||||
|
"blk.0.ffn_down.weight": "4a691370e5a61fcbbf540fbcbf4c0f1d15dec0364528c0e916d0744f6262b63b",
|
||||||
|
"blk.0.ffn_norm.weight": "0c00af2b4a3128bec64a0cbb1084b042fdbe13d9ad0d03bd577f9449dfead338",
|
||||||
|
"blk.0.ffn_up.weight": "b32b52f790c1c083bfb8a3126dc1111cfeeb28dc8c584a930a1e5334cb176bf4",
|
||||||
|
"blk.1.attn_norm.weight": "68748011503c6c029e8e69a84a8e5a89338f378769627b6dbf7f93d715c292e1",
|
||||||
|
"blk.1.attn_output.weight": "2267344add13b048ca59e4377c86dc512be8046a57156901fa32a20fa74e4ee0",
|
||||||
|
"blk.1.attn_qkv.weight": "9109d2e3d7a2eacfda5226587b8be124a3bf44b972da7ebb17aa15795897eacc",
|
||||||
|
"blk.1.ffn_down.weight": "d675df4df4dd039c0c339ad6445d39eddd2004db6bf35bed6314c7497245a633",
|
||||||
|
"blk.1.ffn_norm.weight": "3b5767ae977bc8baaa06b06efdbea193b6b3ba605ce76d77a76ce317e935500c",
|
||||||
|
"blk.1.ffn_up.weight": "80dfd6d9d234b00334c89b8e0a02f81899c2efd377321c34ba5ba51a5f61b5ff",
|
||||||
|
"blk.2.attn_norm.weight": "6a6743b057e5088f145bc179e92c9bfb41163e7295d7b81c62e23dd89d2b59c4",
|
||||||
|
"blk.2.attn_output.weight": "bc5491ea54e0db81462d7d9b7d25cbdda380c2db8de041bd1c4ab7b76a1d19c3",
|
||||||
|
"blk.2.attn_qkv.weight": "a61287a9852e2f5aca9c100b471d98398b2913a3497c743de3c70ec9ddd7087f",
|
||||||
|
"blk.2.ffn_down.weight": "4fddcc382c8dceeab027fe43d8d44e67edb5e8ce4b9a1b7f773c87770380ade1",
|
||||||
|
"blk.2.ffn_norm.weight": "07e05f82b3f63f711db3b684ca79aed25c0657917e66f88af47348a82065c227",
|
||||||
|
"blk.2.ffn_up.weight": "4835a682ef1826c12df01ae7663fc45f9c82bc8e64b665f13fb7da8e201ec0fb",
|
||||||
|
"blk.3.attn_norm.weight": "f22aba7c03999ba7136f39cda747a39715e498699dc1716cd97fc5dfc58d1b1c",
|
||||||
|
"blk.3.attn_output.weight": "53b579855366fd786c5126b2b30aac4d583ca7bda56833c4865f5cadb5c18c6d",
|
||||||
|
"blk.3.attn_qkv.weight": "bb56aba78158123140fcea59c69ac562ca208f6d3086819417cdad8c50f333ad",
|
||||||
|
"blk.3.ffn_down.weight": "97280897a7cd86db2830c004bccc5bc094f50e293baded0189159a2019145a6e",
|
||||||
|
"blk.3.ffn_norm.weight": "10a8c99f8b57a960e8e0a1133c4a26f9148403d1b9bff2eff114917de996f3b5",
|
||||||
|
"blk.3.ffn_up.weight": "7324046c915e75d621b2043597a245a428d8eea31869135e6257a861491d8dcc",
|
||||||
|
"blk.4.attn_norm.weight": "507d8e164de94646edbfe33def8e8fbf7c9a6ee3fbaedb5000f72d9f51ec5e36",
|
||||||
|
"blk.4.attn_output.weight": "bbb3429e6efa98c150e0fdbf48c16180cbf0d0cbc1b3c253c6c319d78f4593a2",
|
||||||
|
"blk.4.attn_qkv.weight": "b95ee5be0786d3901273d806c339fe6c20e6bfffd2a20672a9f56af80921e8ab",
|
||||||
|
"blk.4.ffn_down.weight": "806bbf91df92a5a22bd5aa1ffb7fc2869f7293ffc7704771c290ecc583b27975",
|
||||||
|
"blk.4.ffn_norm.weight": "cfc2930a81df7aee3a5e7f726a15c1182233e868bf0d9d37f6b6ae6d8c15c234",
|
||||||
|
"blk.4.ffn_up.weight": "c3390c69533de2c8424e8069323ccc5d0c4543111535da04cf2c7d26745576aa",
|
||||||
|
"blk.5.attn_norm.weight": "0d71c4fbcefabbd021569442853d2fe90668b19409ae2805a718a829ca60beab",
|
||||||
|
"blk.5.attn_output.weight": "10ebd93629112bf2df5c30dd0953a4a5e9020306768283181ed426934d47e14f",
|
||||||
|
"blk.5.attn_qkv.weight": "5cb05633369f12d4b00e0ff787736bd846856682115720ebc6cce05270c334f6",
|
||||||
|
"blk.5.ffn_down.weight": "e28bcc5094212eafc7476dbc5b7a520d25b79578cbf4229d698e2655956a80ad",
|
||||||
|
"blk.5.ffn_norm.weight": "b6f2c4cf9f34bb4d59989f96165c14a67dc1e266ad0a6d0fcc49f1add929e6ff",
|
||||||
|
"blk.5.ffn_up.weight": "0f9ef99423cc07ebedc0e9cfa95809f2d7108d910bb4ef97ebc0b0309c440750",
|
||||||
|
"blk.6.attn_norm.weight": "b3edcc47a42218234f7564d7470611b49401a41ae8cd42123f86557c69f5d7f2",
|
||||||
|
"blk.6.attn_output.weight": "eb9b7d257b388bb5b8fe0515e5c6873317239cb94cda236e4b6ada2a6c57c65c",
|
||||||
|
"blk.6.attn_qkv.weight": "eb968081f478c52f07bd9c2761741e982dba33cc4eeadeea3557d391b9ac2106",
|
||||||
|
"blk.6.ffn_down.weight": "1b8588bb7463206290322695577dcfced300895d6e6f4b26966c53a9ae2f0f84",
|
||||||
|
"blk.6.ffn_norm.weight": "1219c04b7770983c77814200eefe743f46d15328ea2b12711e44f8103eab08d3",
|
||||||
|
"blk.6.ffn_up.weight": "197ef287239fec47c55677f0fbb66eaf0644f775bc382de843971730721394f6",
|
||||||
|
"blk.7.attn_norm.weight": "b630ad08c80d564ed1c024384818e9fd3f22a36cd7a14aa96e7e2759a8285099",
|
||||||
|
"blk.7.attn_output.weight": "970255aa750828a47d6b9d399f9612b5bf25aefe7dadbcba41fc416d0d4067c1",
|
||||||
|
"blk.7.attn_qkv.weight": "ebb157c880293e6de8d629f263ba8853ed1dbdc02c311d43432bb8cfbb310739",
|
||||||
|
"blk.7.ffn_down.weight": "24bcd4db4cba844c89f878b81843c373dbbc0675e889d32c5b12e63384a7b670",
|
||||||
|
"blk.7.ffn_norm.weight": "b9c6f71001808ee873ce7db8056e4b53fb4cccec8b7f0f312899b575fae39d39",
|
||||||
|
"blk.7.ffn_up.weight": "979f1828d227455c26015a2a11afe9dd05f2bb97a8ba6b38c8dab3f50e627401",
|
||||||
|
"blk.8.attn_norm.weight": "4e8e347e3775010b7112ee630f2f4f2383be7ff64e6ca6154b9b22566552eaa6",
|
||||||
|
"blk.8.attn_output.weight": "65a44babf44a435a1829945211b3168f9ec78ac3cb7a049a733e93d11f0d6659",
|
||||||
|
"blk.8.attn_qkv.weight": "343ed07671da400b040812a4058482fa38284b5d9af9becfed07417fe26ce747",
|
||||||
|
"blk.8.ffn_down.weight": "7fb7e073e3c2c503c4e9d60efa0988fed7398d900cc003695fe3fffd3e188b82",
|
||||||
|
"blk.8.ffn_norm.weight": "b07c1f655d8593e3892a2cf73f8a0c19ce8e5cb613fafbe7cbd430da8ce4c57d",
|
||||||
|
"blk.8.ffn_up.weight": "8b26e14de54b3fdc2e2d3ea41720f9d9c236a93688c3b7fd7bf43f5fbb327c9b",
|
||||||
|
"blk.9.attn_norm.weight": "46394d408a8e316916177e6aa261de32e137a82d729c0b1800b072f0c38c39b6",
|
||||||
|
"blk.9.attn_output.weight": "d57f3d46107947a7073373a0b35d6ecf7759b5df15406f4a3590a60666af6b16",
|
||||||
|
"blk.9.attn_qkv.weight": "14bb8ace8c5453148f4b536e9f4279c813f31136716947256f5cca333448639c",
|
||||||
|
"blk.9.ffn_down.weight": "2b8d98e2b5ed68338f6e4de43bf7de0c4858cc69103cd5177725f7444eec7694",
|
||||||
|
"blk.9.ffn_norm.weight": "41a499dfd418cc4c6b8c12313f673f7e2cd4a3f9c4065eb6c4feb5eed02fb542",
|
||||||
|
"blk.9.ffn_up.weight": "143aab7533a64b17fbe201490a6f674bc7f0bd370c094500b2e100419073d1c2",
|
||||||
|
"blk.10.attn_norm.weight": "ebb670aafd36816a794347287269d8f1a5b19c1e3c0a1e38023bc19fdba9b073",
|
||||||
|
"blk.10.attn_output.weight": "b5d65bbc0ed5e49fdd9d754bc18163cd042a285024d0cf6f954c503bc8c877cb",
|
||||||
|
"blk.10.attn_qkv.weight": "f06b15bac88da798fa34a62b03eaac0dbe8b846020516603c387541f2d8dd672",
|
||||||
|
"blk.10.ffn_down.weight": "fb091fcd1b4de25d1bea94d1755e255cb02914a030d23e3a234e57b8d46bde6e",
|
||||||
|
"blk.10.ffn_norm.weight": "eb347bdf9c40414af87e13a8e72e40b31f004b50f7cb366f1a219ced60a61355",
|
||||||
|
"blk.10.ffn_up.weight": "ed2d52fc881a173f404fe8a1067862c9856d6c3e0d2e90a330a7aa394e3f84d1",
|
||||||
|
"blk.11.attn_norm.weight": "64e252603cf010a0e502ca39fdf8d0a196a79aec67c0d2bb9213fc0cb80c47d4",
|
||||||
|
"blk.11.attn_output.weight": "228e33e21c69f52efc74fdfc831bc9af271e44b2a29a3dced1d64e667ce36eb5",
|
||||||
|
"blk.11.attn_qkv.weight": "ab9ce6d4ef9e42ee0da3f20a7708a3bbc5e79e967b05fa86ba946a05e2eb63eb",
|
||||||
|
"blk.11.ffn_down.weight": "0ca133b7835c98dc77c25d64e4eb7873778bdb5e4d22d8b80f920f46865b43bd",
|
||||||
|
"blk.11.ffn_norm.weight": "02455741a0dfd161c79aa1ecc381901721f229fdcda5615622a629631fb61cfd",
|
||||||
|
"blk.11.ffn_up.weight": "9fecdcc099fbb8e23c6b1ea9294702a027f4a58d265543ec5e7be79b8f63b354",
|
||||||
|
"blk.12.attn_norm.weight": "783bb459911b1b3609a9b2bdfe272f1670add73b5471da738e07ac47e2e07dfd",
|
||||||
|
"blk.12.attn_output.weight": "1e1a914c9e48b857206ac5a1f7cead994bc1ea91d5d4fff8c834d73f2e38ef5d",
|
||||||
|
"blk.12.attn_qkv.weight": "5953e7185ccb87fb4dae8f9426ec86315d4c7794326e8ab59b3a95d4af2189f0",
|
||||||
|
"blk.12.ffn_down.weight": "a3eecf0f394f86e2cfb48a5940a5c50ca86d71883b2f79fcc642a935fabce0d4",
|
||||||
|
"blk.12.ffn_norm.weight": "0a4272e41373c23bd72f10d2d82930aa3a1480aac75832bfbf01cebf0b86b6a4",
|
||||||
|
"blk.12.ffn_up.weight": "06f42776de3a7ceac3025f26a7a8bd20e062233cce2bdaa2183470dc4b30b87d",
|
||||||
|
"blk.13.attn_norm.weight": "5915da60fb03e201fa649faba780e5fdf1c761c262b206e5415cf83181f65780",
|
||||||
|
"blk.13.attn_output.weight": "4dbf6eab074fa3835fd32bd631a8208e511037d5056d2fd3015735cca7674ef7",
|
||||||
|
"blk.13.attn_qkv.weight": "d3d8339a1c4782d9e73d77fdebe154d3c5b83ac40c9175b3e91a4977d08f876b",
|
||||||
|
"blk.13.ffn_down.weight": "de6772b46a55e1fd42b007637dfbf68b6598e5d5b61622da0935002e1e192d3a",
|
||||||
|
"blk.13.ffn_norm.weight": "5a640ea3b8c7be49c95a58a2327e10d8e8d9d142504bde5c8091613e5b961d7a",
|
||||||
|
"blk.13.ffn_up.weight": "f35e3545e4bd3531b2e843b5efd31dee0c13c807ee6386e65473ba67bbec30d0",
|
||||||
|
"blk.14.attn_norm.weight": "9b34986450b7c98b4927e81e61a816f9e84b1addc7c14926402100037aad6678",
|
||||||
|
"blk.14.attn_output.weight": "155d52efb23d366016d861a251d4d1f4a0c13699188c50d50dba016a0d8bfcd9",
|
||||||
|
"blk.14.attn_qkv.weight": "8e1415084e1f33c73a777f19e752489f4dd312cca047733e5ea643cd4a955e04",
|
||||||
|
"blk.14.ffn_down.weight": "a2a142226b94baa01ccb65bdea2b7418e49085c1d9c3c63e544e3112c58a25da",
|
||||||
|
"blk.14.ffn_norm.weight": "8aecfd9b0ae6affaea31a80c5c9a4a14b31deaa0db7bd8f6da2a64d23447921c",
|
||||||
|
"blk.14.ffn_up.weight": "0c1407237b8c1bd02f193346b5681926fe698a5055eac6a7450451b0f991707c",
|
||||||
|
"blk.15.attn_norm.weight": "e037bd19880bfa83d983200fb0c7866f8ad16c3ff5cc4b4f3a37ca7373870ff6",
|
||||||
|
"blk.15.attn_output.weight": "045fe4fc95cc129a1b92771b179c11b12845c4c088786c607f17bd98857e68e1",
|
||||||
|
"blk.15.attn_qkv.weight": "7621b7559705cab1d4dea1c69f76dbf9dc1c8837a203b656f484703b9c1b70ce",
|
||||||
|
"blk.15.ffn_down.weight": "7e5ac20e290bc60761e1cd972354fde225b7fa861048d44d9a0dd9b046d55f58",
|
||||||
|
"blk.15.ffn_norm.weight": "b6d830d88f1db1825687973c8c2b1a24c6fa84f07af8d0e3ef9c86009baca0b2",
|
||||||
|
"blk.15.ffn_up.weight": "dcda0957cd04fc45476774dba2bbf9aa89d6b05d5ca7b10ae6f73ad2c49b1cd3",
|
||||||
|
"blk.16.attn_norm.weight": "4ee9b70ba15cb2a08240f93990e90f5068c48fceb481f8e2186bec8b7214eb3f",
|
||||||
|
"blk.16.attn_output.weight": "315cfe5536658d2498192b2980eade15b2c9a4ff220e4011911457b1727fa103",
|
||||||
|
"blk.16.attn_qkv.weight": "3c8122e3ad637583b9dcde8ff3a323267d3014bb1f0f9771e5322260ca9ecc8d",
|
||||||
|
"blk.16.ffn_down.weight": "3b5fbebd5ee2b86cad96fb8a9b45a8770d08f82c1c8b74d7061e866f7020a18d",
|
||||||
|
"blk.16.ffn_norm.weight": "ffab69f20bda372de6e5878f0539163e2fc6ba113621ded95705fc3b1465c9f0",
|
||||||
|
"blk.16.ffn_up.weight": "0935ea3d258da42d6258406365f39f58ddaabfe97ea5977580db3635188f24a1",
|
||||||
|
"blk.17.attn_norm.weight": "f030441733f3d147b4a06a1eb4aeb8465c7c24d9c53bf4c48fe7e134d3629803",
|
||||||
|
"blk.17.attn_output.weight": "07a955ef09e8dc766ac0df647d0b2c69f23c4c69a7137654b4aad80303ed0eda",
|
||||||
|
"blk.17.attn_qkv.weight": "1c10688061e21e2fe12ad0cb54bf03895c1f83c3b0df743a42f548b52cbca1b2",
|
||||||
|
"blk.17.ffn_down.weight": "ebb9cc9836f41d88fdae2aa9a4355514e4edaec8d1577ffeb947a35204e77f52",
|
||||||
|
"blk.17.ffn_norm.weight": "50aff44f6528b13db5389f2ddcdb7676244947610bd7ffbff3f881c968c2a0d4",
|
||||||
|
"blk.17.ffn_up.weight": "d716537949582be33bde6b02e38f5a70081c9642a9fb05a61312126718b8d148",
|
||||||
|
"blk.18.attn_norm.weight": "0ea695c4e53d637902f46663a6ee42adc493c36794476acc7dbddaa05b13840d",
|
||||||
|
"blk.18.attn_output.weight": "5fd35b500221a612eb4f4bddf0e9b6b7db4d7733032a75f8802fb2d884647c2e",
|
||||||
|
"blk.18.attn_qkv.weight": "b0da37fd030fe69581f990bf23bfd35467a1bbe558af6de7c0924f6b72e92317",
|
||||||
|
"blk.18.ffn_down.weight": "b355c33f44b328f4bb977567de8f7544db4b005d7a8fbded658518ecf3c5a153",
|
||||||
|
"blk.18.ffn_norm.weight": "58b3fe9094079989a86e0387143259e1cc35952d24dc3df290c4ba6df44f5c51",
|
||||||
|
"blk.18.ffn_up.weight": "2ce530954c342c30ed2ead5353f931960bfae1d278868504c0efb973560fabbe",
|
||||||
|
"blk.19.attn_norm.weight": "533e9aed66feea8f0392aa81f9e293240e1f009a5334253915fb60c2749b615d",
|
||||||
|
"blk.19.attn_output.weight": "84f2d00f98a4113a779d3b5d1c3e7c914eb47784d3ab13b290367c124c2994aa",
|
||||||
|
"blk.19.attn_qkv.weight": "fbe6b9f53b07fa7537d3b3d452d20a9bc666f9fd41ec2091dd28bc2f70fc668f",
|
||||||
|
"blk.19.ffn_down.weight": "b30199e098c8bb3f890183d8b18471e80b62b604729b277ad62488dd71e1206b",
|
||||||
|
"blk.19.ffn_norm.weight": "c81373e41cd340b7badb19f9517c77c4250b4eb9a02dc758b8b49b652487d7ff",
|
||||||
|
"blk.19.ffn_up.weight": "5a5cb083ca7725720e3a890f7fa46354760e8007a8188849a092e305694a75e3",
|
||||||
|
"blk.20.attn_norm.weight": "4953091b4477e354357a8e743ba0a1900633e52f1599ee082a0c9b0b2b5cd978",
|
||||||
|
"blk.20.attn_output.weight": "62d54f7749cd6856097b2632066a322b0296df915fe66f382c5b5981be0d4f23",
|
||||||
|
"blk.20.attn_qkv.weight": "406de9e35b0729ebe902d7a47905cc7fb29a921431ed35dbef0c03e5690a1329",
|
||||||
|
"blk.20.ffn_down.weight": "62fb678b0d1261e19a4903a2b347d67afcc8acff01feb33a687a35a2d1e6f9a5",
|
||||||
|
"blk.20.ffn_norm.weight": "cd9d36b7e71e55c8925b97bb09c28219f182626bcff094878ae39c3db887a14b",
|
||||||
|
"blk.20.ffn_up.weight": "b9276771d79d3e932e73ccc520c3f8476342b9ef312ed2ee1e0da822e6e3ad18",
|
||||||
|
"blk.21.attn_norm.weight": "66d8c8a35e13ce9c2a0e75b670150e2c31484a55c2316df46075312196178ed3",
|
||||||
|
"blk.21.attn_output.weight": "12ab46c9382648f9b3350fdd92a6be6352743d62d6b520d7e2024e0c838588f5",
|
||||||
|
"blk.21.attn_qkv.weight": "a7909676ee1675ca23cd29a5fdd226df8dd9d68f94c6c9bbb51dd9fd38504008",
|
||||||
|
"blk.21.ffn_down.weight": "6fb317279c6542e82f97d5a12a60fac1bd0fa0405154f9fbe265e2fe39bd49cc",
|
||||||
|
"blk.21.ffn_norm.weight": "c0f703eb3ff161b5ba4490d87d8684b8a6c47a8f433e12f418333b9db439010a",
|
||||||
|
"blk.21.ffn_up.weight": "6dbdb80ef0c35e364bbce12d40d5e74c7963c7b55d58d9579567a07ffce7b863",
|
||||||
|
"blk.22.attn_norm.weight": "f94237433bf03d675cb2f655b81ca91a1ce2447bc6b00b13d6b0ccfe2d411eff",
|
||||||
|
"blk.22.attn_output.weight": "e821f95995ce497c01e63ca64f737713b1b65f11df1903e51d444aa516f33f71",
|
||||||
|
"blk.22.attn_qkv.weight": "1b0f717c73afb5eb4c82a1708c4e85c969e8a2a8770d9ddb78b1870a2d8a781e",
|
||||||
|
"blk.22.ffn_down.weight": "0f33f7a3cdc685484be99aa0c03642b0b20850a27d1fddbe054b13a9382f3ccb",
|
||||||
|
"blk.22.ffn_norm.weight": "9df285cf211ddd7df2b36a50489af574755c7d4d98b29a05cd04566ae613c8dc",
|
||||||
|
"blk.22.ffn_up.weight": "63ac300e1efb34041dd0136cf43ea622fac6f0caccce1cd9262f5e08d2cf179c",
|
||||||
|
"blk.23.attn_norm.weight": "5f72d9e88689b4027b28f5f8f26cd3abb03635ceea7ec98a4c91a9fc691f6707",
|
||||||
|
"blk.23.attn_output.weight": "6ecf04ff61125c5fc768f8656497152149373daf321ee9c957e8f7245a1184d1",
|
||||||
|
"blk.23.attn_qkv.weight": "a9d9978806724c2959f2cf386c233831f08e1e933dbf2b32665e788d9d512ea4",
|
||||||
|
"blk.23.ffn_down.weight": "72c7d17886a3da17fa0daa456aa5e877b2ef5b8b403182b870d9ca5ca9c70347",
|
||||||
|
"blk.23.ffn_norm.weight": "971e4b712e3025a13419b5b57d674b5e4ab7f18f74b57b9afc4671623da90c4b",
|
||||||
|
"blk.23.ffn_up.weight": "df2b5c7dbd5834545b815073af0c7355b065124e6d6f0fee78d8fa5b2076dc3e",
|
||||||
|
"blk.24.attn_norm.weight": "c41957c4a79ad3b16f6e11daec1c7f530b9f3f4b618e1e4367c3b67787ac4ab6",
|
||||||
|
"blk.24.attn_output.weight": "ef7d61f5fc88ac6f31bf60cb5f4d2d6b8df42d38825807112361a7224b0dee3b",
|
||||||
|
"blk.24.attn_qkv.weight": "3e6a58fe7d49c90bb6971efbad3371c32256881173ea5aee4b0c296cb206490f",
|
||||||
|
"blk.24.ffn_down.weight": "f43619144047de42fed81dfa495f1815d3cb771330e574043e2b67620819292c",
|
||||||
|
"blk.24.ffn_norm.weight": "5501d4a2a98c8ca6b42e77b53b221dbc08f530f6a067256d787534ec6fe028bd",
|
||||||
|
"blk.24.ffn_up.weight": "d64c8b0e509e2b1118f6000176f8956cacecdbb200c7e95ed93fb78b6e26c84a",
|
||||||
|
"blk.25.attn_norm.weight": "502fa3c302d371f61c5791f4615b73018ffb1daa09b6499b227116581244c5d4",
|
||||||
|
"blk.25.attn_output.weight": "ad8391d4e9c980856f2547aa945b2b6a407a6382158dc1ddd4f08d94ecc24be6",
|
||||||
|
"blk.25.attn_qkv.weight": "42e8983780d4a01a02c54ad23d4df21eea437f119a10af5a9c12a76a42d308c1",
|
||||||
|
"blk.25.ffn_down.weight": "302dd010d4e0ab4eeaee89090409ea0dddeeeed3236415eb8f97c942497eea91",
|
||||||
|
"blk.25.ffn_norm.weight": "fb34c1ee5bca96986c08834df0a0c047ba041c1123ac1f563e9d64312bf82d6a",
|
||||||
|
"blk.25.ffn_up.weight": "10739a8de156816d93c92b935386540bfa976bdbef204f0312960f6fc657582f",
|
||||||
|
"blk.26.attn_norm.weight": "7036c711609128c4e55968ff3681d3043338879a5737efd6c2ac9e1a2a61f1a0",
|
||||||
|
"blk.26.attn_output.weight": "db5db45dead5cb911fa01da59832f121b7c18b2d167bf53741c40819f24d346c",
|
||||||
|
"blk.26.attn_qkv.weight": "cae34c6b7f82ed14348d5ed30a79919c383737c1694a9cb9c0de609d3b0c1d0a",
|
||||||
|
"blk.26.ffn_down.weight": "491ec3a4da9b4f49f8ebc6be658ce397a9b801ae9fb35e82177e47808c65e5d0",
|
||||||
|
"blk.26.ffn_norm.weight": "fd7059d75d7f0e5288511ddeeb0f772eb3cae3ccfe4226b877015834edc3c386",
|
||||||
|
"blk.26.ffn_up.weight": "ea1ee1274c56458ce056d2205e5bb6e5422ce4cb0ad58006b8141749b97a0c39",
|
||||||
|
"blk.27.attn_norm.weight": "cc362c9a937609265052cd38544af17a1a7448cea086d4c801139e1fc865832d",
|
||||||
|
"blk.27.attn_output.weight": "ba757a81dabde9cb1b069d1bb616fe79649a1724f756567ec61caed1304fe6cf",
|
||||||
|
"blk.27.attn_qkv.weight": "1ab8d7d02d87756c12c2275636823aa5ede3d683178225c4cac4bd892c319bd4",
|
||||||
|
"blk.27.ffn_down.weight": "deb1c711c8a66acf4dcd2d088e1548f8e08f296f755e4067d6557fa55afde88c",
|
||||||
|
"blk.27.ffn_norm.weight": "fc6242d8cb8a4a37a8ddb7e41e7e60a63d4a89edf36acb35df052f10b9c91ece",
|
||||||
|
"blk.27.ffn_up.weight": "8df39b09c4801f343aca78f2918a1f6db78c8c55e591eda4c69eadb74c26e180",
|
||||||
|
"blk.28.attn_norm.weight": "75b539308f77e3cefdc6d98484d8b5cbf0538f0c2869a77b7373a145a18bc850",
|
||||||
|
"blk.28.attn_output.weight": "ae128940eb60a6d2e121762ef4b3e9dcf9eb3e105b249507fa7f12de0e19822c",
|
||||||
|
"blk.28.attn_qkv.weight": "bdda781c288e9326c240e33905f8e621b6a2ad902e620739d34f93fcd6f933de",
|
||||||
|
"blk.28.ffn_down.weight": "f1d6e6d1c286b1138bfd7e53fe477f399ae93bc2c04e35416f84218ed7247965",
|
||||||
|
"blk.28.ffn_norm.weight": "3f837ce82c8b9bde0d61d08b6f5fe5574886ea5328dbdc53f2929f18da8b4087",
|
||||||
|
"blk.28.ffn_up.weight": "2af027002e31d1b6cfedbdb30a2b9d7213f3aa691167c353913adfd48fda31e4",
|
||||||
|
"blk.29.attn_norm.weight": "61e8003b5329462ffe0fe172f2b160260de006aed858332d49d75504b6b6aa7a",
|
||||||
|
"blk.29.attn_output.weight": "ca44542a72a37476dc73dbdcc01f5b7497cb3ebc4ea230a55c9634ccd8e56ad4",
|
||||||
|
"blk.29.attn_qkv.weight": "abb3d9d6abe57872ae3daa51935d43264093ded5ce63b49d1e280ee5758be0e4",
|
||||||
|
"blk.29.ffn_down.weight": "6764b895fce881df097489c263446f0106de36217997660c15984b3ee22a5a06",
|
||||||
|
"blk.29.ffn_norm.weight": "89e03e9a33fc0e6e31ba9f0c2bd7c5734a118c5602bb90148793e08a80e8d0ae",
|
||||||
|
"blk.29.ffn_up.weight": "fa7ad57a84954f4121653152efed1a871d8adb20a1ea9086e3e849ce359d7d2e",
|
||||||
|
"blk.30.attn_norm.weight": "91a697aca1e42af54f806a20211031c3369e8d0bd58df1b0147fe24954e1f5a4",
|
||||||
|
"blk.30.attn_output.weight": "36063fcf766c89ac75be56f688cc63cefe5f2c733fbf4378ea9956ad386fa148",
|
||||||
|
"blk.30.attn_qkv.weight": "2cacd1161f1121a2c0b979930134f4666f73fb8d7237b3b0659ae091b15955a6",
|
||||||
|
"blk.30.ffn_down.weight": "9f3fcb6217100595850c05dc98f9ab2a263afdb6ab28df2fcb08aeff512057d7",
|
||||||
|
"blk.30.ffn_norm.weight": "6c600bc1fc7de39d4f8917b81fc7d1d5ed2a9b56492234c13a4bd6028c30d880",
|
||||||
|
"blk.30.ffn_up.weight": "73cabd1bb011956b2689ea3338bb76642ef3a57c197377d666d2ab5f56317668",
|
||||||
|
"blk.31.attn_norm.weight": "72d3e1cc771380645fa75a899858c95f39857a4f3f1ed60fe1578df383b8bc53",
|
||||||
|
"blk.31.attn_output.weight": "40089cdd29994dc19a1d89fa15902a89cfeca3540f12dc9bf4d00ef82506e456",
|
||||||
|
"blk.31.attn_qkv.weight": "1d0bb40e9258071ae14290a53c619a8e331dda07354d2a02ef45766c029ae5e4",
|
||||||
|
"blk.31.ffn_down.weight": "8defa0e06335b793fa8be03883f0a322d6c5b33f52c69c943c35c60d16e42c0a",
|
||||||
|
"blk.31.ffn_norm.weight": "33c55d9d0c496ccfb130361fe131649346e098abaaac39c0519507e5d846721d",
|
||||||
|
"blk.31.ffn_up.weight": "599f6503f61c692c1f82001973d35119f9688db5e6be9d9c298411491c93f09b",
|
||||||
|
"output.weight": "14b8dc662bfa3308ebb2e102c562d8e52c15670e538f20f3216a9c310ca9dd41",
|
||||||
|
"output_norm.weight": "7f2294ba94ce65681df6c7ddd8698799199b9d77dc83c10bdad5c3999f0fdb82",
|
||||||
|
"rope_factors_long.weight": "e34d378664e354652c38f47d10dafb0498ccc2fb042d39ff7fef768146fff22b",
|
||||||
|
"rope_factors_short.weight": "9379146a4988f373d362fe47b06c75e7fe7c54aa4dc9558758df79b7a87471fd",
|
||||||
|
"token_embd.weight": "19a03c1fb5ac0baee93b0a7d8b0f26e9a9b011e229b694afc50ebfc13d84f8bf"
|
||||||
|
}
|
||||||
@@ -16,7 +16,9 @@ If the model being imported is one of these architectures, it can be imported di

 - LlamaForCausalLM
 - MistralForCausalLM
+- MixtralForCausalLM
 - GemmaForCausalLM
+- Phi3ForCausalLM

 ```dockerfile
 FROM /path/to/safetensors/directory
@@ -182,7 +182,6 @@ curl http://localhost:11434/v1/embeddings \
 - [x] Reproducible outputs
 - [x] Vision
 - [x] Tools (streaming support coming soon)
-- [ ] Vision
 - [ ] Logprobs

 #### Supported request fields
@@ -112,15 +112,9 @@ Keep the following tips and best practices in mind when working with Go template
 ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2.

 ```gotmpl
-{{- if .System }}<|im_start|>system
-{{ .System }}<|im_end|>
-{{ end }}
 {{- range .Messages }}<|im_start|>{{ .Role }}
 {{ .Content }}<|im_end|>
 {{ end }}<|im_start|>assistant
-{{ else }}
-{{ if .System }}<|im_start|>system
-{{ .System }}<|im_end|>
 ```

 ### Example Tools
go.mod (2 lines changed)

@@ -1,6 +1,6 @@
 module github.com/ollama/ollama

-go 1.22.0
+go 1.22.5

 require (
     github.com/containerd/console v1.0.3
@@ -49,13 +49,9 @@ func PayloadsDir() (string, error) {
     }

     // Track our pid so we can clean up orphaned tmpdirs
-    pidFilePath := filepath.Join(tmpDir, "ollama.pid")
-    pidFile, err := os.OpenFile(pidFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.ModePerm)
-    if err != nil {
-        return "", err
-    }
-    if _, err := pidFile.Write([]byte(strconv.Itoa(os.Getpid()))); err != nil {
-        return "", err
+    n := filepath.Join(tmpDir, "ollama.pid")
+    if err := os.WriteFile(n, []byte(strconv.Itoa(os.Getpid())), 0o644); err != nil {
+        return "", fmt.Errorf("failed to write pid file %s: %w", n, err)
     }

     // We create a distinct subdirectory for payloads within the tmpdir
@@ -67,37 +63,44 @@ func PayloadsDir() (string, error) {

 // Best effort to clean up prior tmpdirs
 func cleanupTmpDirs() {
-    dirs, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*"))
+    matches, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*", "ollama.pid"))
     if err != nil {
         return
     }
-    for _, d := range dirs {
-        info, err := os.Stat(d)
-        if err != nil || !info.IsDir() {
-            continue
-        }
-        raw, err := os.ReadFile(filepath.Join(d, "ollama.pid"))
-        if err != nil {
-            slog.Warn("failed to read ollama.pid", "path", d, "error", err)
-            // No pid, ignore this tmpdir
+
+    for _, match := range matches {
+        raw, err := os.ReadFile(match)
+        if errors.Is(err, os.ErrNotExist) {
+            slog.Debug("not a ollama runtime directory, skipping", "path", match)
+            continue
+        } else if err != nil {
+            slog.Warn("could not read ollama.pid, skipping", "path", match, "error", err)
             continue
         }

         pid, err := strconv.Atoi(string(raw))
         if err != nil {
-            slog.Warn("failed to parse pid", "path", d, "error", err)
+            slog.Warn("invalid pid, skipping", "path", match, "error", err)
             continue
         }

-        proc, err := os.FindProcess(pid)
-        if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
-            slog.Warn("found running ollama", "pid", pid, "path", d)
-            // Another running ollama, ignore this tmpdir
+        p, err := os.FindProcess(pid)
+        if err == nil && !errors.Is(p.Signal(syscall.Signal(0)), os.ErrProcessDone) {
+            slog.Warn("process still running, skipping", "pid", pid, "path", match)
             continue
         }

-        if err := os.Remove(d); err != nil {
-            slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err)
+        if err := os.Remove(match); err != nil {
+            slog.Warn("could not cleanup stale pidfile", "path", match, "error", err)
+        }
+
+        runners := filepath.Join(filepath.Dir(match), "runners")
+        if err := os.RemoveAll(runners); err != nil {
+            slog.Warn("could not cleanup stale runners", "path", runners, "error", err)
+        }
+
+        if err := os.Remove(filepath.Dir(match)); err != nil {
+            slog.Warn("could not cleanup stale tmpdir", "path", filepath.Dir(match), "error", err)
         }
     }
 }
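The rewritten cleanup above keys everything on the ollama.pid files: read the recorded pid and only remove the directory when that process is gone, using signal 0 as a Unix-style liveness probe. A rough standalone sketch of just that check is below; the helper name and example path are made up for illustration.

```go
package main

import (
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
	"syscall"
)

// pidFileAlive reports whether the process named in a pid file still exists.
// Signal 0 delivers nothing, but signaling an exited process yields
// os.ErrProcessDone, which is exactly the stale/live distinction needed here.
func pidFileAlive(path string) bool {
	raw, err := os.ReadFile(path)
	if err != nil {
		return false // no readable pid file: treat as not running
	}
	pid, err := strconv.Atoi(strings.TrimSpace(string(raw)))
	if err != nil {
		return false // garbage in the file: treat as not running
	}
	proc, err := os.FindProcess(pid)
	if err != nil {
		return false
	}
	return !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone)
}

func main() {
	fmt.Println(pidFileAlive("/tmp/ollama123456/ollama.pid")) // hypothetical path
}
```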
@@ -305,6 +305,8 @@ func GetGPUInfo() GpuInfoList {
     // Intel
     if envconfig.IntelGPU() {
         oHandles = initOneAPIHandles()
+        if oHandles != nil && oHandles.oneapi != nil {
+
         // On windows we bundle the oneapi library one level above the runner dir
         depPath = ""
         if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
@@ -340,6 +342,7 @@ func GetGPUInfo() GpuInfoList {
             }
         }
     }
+    }

     rocmGPUs = AMDGetGPUInfo()
     bootstrapped = true
llm/ext_server/server.cpp (vendored, 38 lines changed)

@@ -1223,9 +1223,7 @@ struct llama_server_context

             res.result_json = json
             {
-                {"id", res.id},
                 {"embedding", std::vector<float>(embd, embd + n_embd)},
-                {"timings", slot.get_formated_timings()},
             };
         }
     }
@@ -3194,41 +3192,17 @@ int main(int argc, char **argv) {
                 prompt = "";
             }

-            if (prompt.size() == 1) {
-                prompt = prompt[0];
-            }
-
             // create and queue the task
-            json responses;
-            {
-                const int id_task = llama.queue_tasks.get_new_id();
-                llama.queue_results.add_waiting_task_id(id_task);
-                llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
-
-                // get the result
-                task_result result = llama.queue_results.recv(id_task);
-                llama.queue_results.remove_waiting_task_id(id_task);
-                if (result.error) {
-                    return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
-                }
-
-                responses = result.result_json.value("results", std::vector<json>{result.result_json});
-                std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {
-                    return a["id"] < b["id"];
-                });
-
-                json embeddings = json::array();
-
-                int prompt_n = 0;
-                for (auto & elem : responses) {
-                    embeddings.push_back(elem.at("embedding"));
-                    prompt_n += elem.at("timings").at("prompt_n").get<int>();
-                }
-
-                // send the result
-                json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
-                return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
-            }
+            const int task_id = llama.queue_tasks.get_new_id();
+            llama.queue_results.add_waiting_task_id(task_id);
+            llama.request_completion(task_id, {{"prompt", prompt}}, true, -1);
+
+            // get the result
+            task_result result = llama.queue_results.recv(task_id);
+            llama.queue_results.remove_waiting_task_id(task_id);
+
+            // send the result
+            return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
         });

         // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
@@ -157,6 +157,14 @@ type Tensor struct {
     io.WriterTo `json:"-"`
 }

+func (t Tensor) block() (n int) {
+    if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
+        return -1
+    }
+
+    return
+}
+
 func (t Tensor) blockSize() uint64 {
     switch t.Kind {
     case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
llm/gguf.go (15 lines changed)

@@ -532,15 +532,14 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
         }
     }

-    slices.SortFunc(ts, func(a, b Tensor) int {
-        var i, j int
-        if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
-            return cmp.Compare(a.Name, b.Name)
-        } else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
-            return cmp.Compare(a.Name, b.Name)
-        }
-
-        return cmp.Compare(i, j)
+    slices.SortStableFunc(ts, func(a, b Tensor) int {
+        if i, j := a.block(), b.block(); i < 0 && j > 0 {
+            return 1
+        } else if i > 0 && j < 0 {
+            return -1
+        } else {
+            return cmp.Compare(i, j)
+        }
     })

     var s uint64
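The new comparator leans on the blk.N naming convention together with a stable sort, so tensors with a block index come out in numeric order while the rest keep their relative positions instead of being compared lexically. Below is a simplified toy of that ordering on plain strings; it reimplements a block() parser like the one added to llm/ggml.go and omits the extra special-casing the real comparator applies to names without a block index.

```go
package main

import (
	"cmp"
	"fmt"
	"slices"
)

// block extracts N from names like "blk.N.attn_norm.weight", or -1 if absent.
func block(name string) int {
	var n int
	if _, err := fmt.Sscanf(name, "blk.%d.", &n); err != nil {
		return -1
	}
	return n
}

func main() {
	names := []string{
		"blk.10.ffn_up.weight",
		"token_embd.weight",
		"blk.2.attn_qkv.weight",
		"output_norm.weight",
		"blk.0.attn_norm.weight",
	}

	// Stable sort by block index: blk.* tensors end up in numeric order
	// (0, 2, 10 rather than the lexical 0, 10, 2), and names without a
	// block index keep their original relative order.
	slices.SortStableFunc(names, func(a, b string) int {
		return cmp.Compare(block(a), block(b))
	})

	fmt.Println(names)
}
```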
@@ -33,7 +33,7 @@ type LlamaServer interface {
|
|||||||
Ping(ctx context.Context) error
|
Ping(ctx context.Context) error
|
||||||
WaitUntilRunning(ctx context.Context) error
|
WaitUntilRunning(ctx context.Context) error
|
||||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||||
Embed(ctx context.Context, input []string) (*EmbedResponse, error)
|
Embedding(ctx context.Context, input string) ([]float32, error)
|
||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
@@ -125,8 +125,9 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
}
}

- // On linux, over-allocating CPU memory will almost always result in an error
- if runtime.GOOS == "linux" {
+ // On linux and windows, over-allocating CPU memory will almost always result in an error
+ // Darwin has fully dynamic swap so has no direct concept of free swap space
+ if runtime.GOOS != "darwin" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := systemFreeMemory + systemSwapFreeMemory
if systemMemoryRequired > available {
@@ -882,24 +883,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil
}

- type EmbedRequest struct {
- Content []string `json:"content"`
+ type EmbeddingRequest struct {
+ Content string `json:"content"`
}

- type EmbedResponse struct {
- Embedding [][]float32 `json:"embedding"`
- PromptEvalCount int `json:"prompt_n"`
+ type EmbeddingResponse struct {
+ Embedding []float32 `json:"embedding"`
}

- func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
- // each input will use a slot, so we need to acquire the semaphore for
- // the number of inputs up to numParallel
- slots := int64(min(len(input), s.numParallel))
- if err := s.sem.Acquire(ctx, slots); err != nil {
+ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
+ if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
}
- defer s.sem.Release(slots)
+ defer s.sem.Release(1)

// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
@@ -909,18 +906,18 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}

- data, err := json.Marshal(EmbedRequest{Content: input})
+ data, err := json.Marshal(EmbeddingRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}

- req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
+ r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err)
}
- req.Header.Set("Content-Type", "application/json")
+ r.Header.Set("Content-Type", "application/json")

- resp, err := http.DefaultClient.Do(req)
+ resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, fmt.Errorf("do embedding request: %w", err)
}
@@ -936,12 +933,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
return nil, fmt.Errorf("%s", body)
}

- var e EmbedResponse
+ var e EmbeddingResponse
if err := json.Unmarshal(body, &e); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}

- return &e, nil
+ return e.Embedding, nil
}

type TokenizeRequest struct {
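For reference, the single-prompt wire format implied by the EmbeddingRequest and EmbeddingResponse structs above looks like this; a minimal Go sketch with an invented prompt and vector, not output captured from a running server.

package main

import (
	"encoding/json"
	"fmt"
)

// Same field layout as the llm server's EmbeddingRequest/EmbeddingResponse above.
type embeddingRequest struct {
	Content string `json:"content"`
}

type embeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}

func main() {
	// Request body POSTed to the runner's /embedding endpoint.
	body, _ := json.Marshal(embeddingRequest{Content: "why is the sky blue?"})
	fmt.Println(string(body)) // {"content":"why is the sky blue?"}

	// Response body: a single vector for the single prompt.
	var resp embeddingResponse
	_ = json.Unmarshal([]byte(`{"embedding":[0.1,0.2,0.3]}`), &resp)
	fmt.Println(len(resp.Embedding)) // 3
}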
@@ -26,6 +26,7 @@ var errorPrefixes = []string{
"cudaMalloc failed",
"\"ERR\"",
"error loading model",
+ "GGML_ASSERT",
}

func (w *StatusWriter) Write(b []byte) (int, error) {
@@ -7,12 +7,12 @@ import (
"io"
"net/http"
"net/http/httptest"
+ "reflect"
"strings"
"testing"
"time"

"github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"

"github.com/ollama/ollama/api"
)
@@ -20,14 +20,9 @@ import (
const (
prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
- imageURL = prefix + image
)

- func prepareRequest(req *http.Request, body any) {
- bodyBytes, _ := json.Marshal(body)
- req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
- req.Header.Set("Content-Type", "application/json")
- }
+ var False = false

func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
@@ -43,134 +38,136 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {

func TestChatMiddleware(t *testing.T) {
type testCase struct {
- Name string
- Setup func(t *testing.T, req *http.Request)
- Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
+ name string
+ body string
+ req api.ChatRequest
+ err ErrorResponse
}

var capturedRequest *api.ChatRequest

testCases := []testCase{
{
- Name: "chat handler",
- Setup: func(t *testing.T, req *http.Request) {
- body := ChatCompletionRequest{
+ name: "chat handler",
+ body: `{
+ "model": "test-model",
+ "messages": [
+ {"role": "user", "content": "Hello"}
+ ]
+ }`,
+ req: api.ChatRequest{
Model: "test-model",
- Messages: []Message{{Role: "user", Content: "Hello"}},
- }
- prepareRequest(req, body)
+ Messages: []api.Message{
+ {
+ Role: "user",
+ Content: "Hello",
},
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != http.StatusOK {
- t.Fatalf("expected 200, got %d", resp.Code)
- }
- if req.Messages[0].Role != "user" {
- t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
- }
-
- if req.Messages[0].Content != "Hello" {
- t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
- }
+ },
+ Options: map[string]any{
+ "temperature": 1.0,
+ "top_p": 1.0,
+ },
+ Stream: &False,
},
},
{
- Name: "chat handler with image content",
- Setup: func(t *testing.T, req *http.Request) {
- body := ChatCompletionRequest{
- Model: "test-model",
- Messages: []Message{
+ name: "chat handler with image content",
+ body: `{
+ "model": "test-model",
+ "messages": [
{
- Role: "user", Content: []map[string]any{
- {"type": "text", "text": "Hello"},
- {"type": "image_url", "image_url": map[string]string{"url": imageURL}},
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Hello"
+ },
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "` + prefix + image + `"
+ }
+ }
+ ]
+ }
+ ]
+ }`,
+ req: api.ChatRequest{
+ Model: "test-model",
+ Messages: []api.Message{
+ {
+ Role: "user",
+ Content: "Hello",
+ },
+ {
+ Role: "user",
+ Images: []api.ImageData{
+ func() []byte {
+ img, _ := base64.StdEncoding.DecodeString(image)
+ return img
+ }(),
},
},
},
- }
- prepareRequest(req, body)
+ Options: map[string]any{
+ "temperature": 1.0,
+ "top_p": 1.0,
},
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != http.StatusOK {
- t.Fatalf("expected 200, got %d", resp.Code)
- }
-
- if req.Messages[0].Role != "user" {
- t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
- }
-
- if req.Messages[0].Content != "Hello" {
- t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
- }
-
- img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
-
- if req.Messages[1].Role != "user" {
- t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
- }
-
- if !bytes.Equal(req.Messages[1].Images[0], img) {
- t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
- }
+ Stream: &False,
},
},
{
- Name: "chat handler with tools",
- Setup: func(t *testing.T, req *http.Request) {
- body := ChatCompletionRequest{
+ name: "chat handler with tools",
+ body: `{
+ "model": "test-model",
+ "messages": [
+ {"role": "user", "content": "What's the weather like in Paris Today?"},
+ {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
+ ]
+ }`,
+ req: api.ChatRequest{
Model: "test-model",
- Messages: []Message{
- {Role: "user", Content: "What's the weather like in Paris Today?"},
- {Role: "assistant", ToolCalls: []ToolCall{{
- ID: "id",
- Type: "function",
- Function: struct {
- Name string `json:"name"`
- Arguments string `json:"arguments"`
- }{
+ Messages: []api.Message{
+ {
+ Role: "user",
+ Content: "What's the weather like in Paris Today?",
+ },
+ {
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
Name: "get_current_weather",
- Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
+ Arguments: map[string]interface{}{
+ "location": "Paris, France",
+ "format": "celsius",
},
},
- }}},
},
- }
- prepareRequest(req, body)
},
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != 200 {
- t.Fatalf("expected 200, got %d", resp.Code)
- }
+ },
+ },
+ Options: map[string]any{
+ "temperature": 1.0,
+ "top_p": 1.0,
+ },
+ Stream: &False,
+ },
+ },

- if req.Messages[0].Content != "What's the weather like in Paris Today?" {
- t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
- }
-
- if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
- t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
- }
-
- if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
- t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
- }
- },
- },
{
- Name: "chat handler error forwarding",
- Setup: func(t *testing.T, req *http.Request) {
- body := ChatCompletionRequest{
- Model: "test-model",
- Messages: []Message{{Role: "user", Content: 2}},
- }
- prepareRequest(req, body)
+ name: "chat handler error forwarding",
+ body: `{
+ "model": "test-model",
+ "messages": [
+ {"role": "user", "content": 2}
+ ]
+ }`,
+ err: ErrorResponse{
+ Error: Error{
+ Message: "invalid message content type: float64",
+ Type: "invalid_request_error",
},
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != http.StatusBadRequest {
- t.Fatalf("expected 400, got %d", resp.Code)
- }
-
- if !strings.Contains(resp.Body.String(), "invalid message content type") {
- t.Fatalf("error was not forwarded")
- }
},
},
}
@@ -185,16 +182,26 @@ func TestChatMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/chat", endpoint)

for _, tc := range testCases {
- t.Run(tc.Name, func(t *testing.T) {
- req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
- tc.Setup(t, req)
+ t.Run(tc.name, func(t *testing.T) {
+ req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
+ req.Header.Set("Content-Type", "application/json")

resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)

- tc.Expected(t, capturedRequest, resp)
+ var errResp ErrorResponse
+ if resp.Code != http.StatusOK {
+ if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+ t.Fatal(err)
+ }
+ }
+ if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
+ t.Fatal("requests did not match")
+ }
+
+ if !reflect.DeepEqual(tc.err, errResp) {
+ t.Fatal("errors did not match")
+ }
capturedRequest = nil
})
}
@@ -202,71 +209,52 @@ func TestChatMiddleware(t *testing.T) {

func TestCompletionsMiddleware(t *testing.T) {
type testCase struct {
- Name string
- Setup func(t *testing.T, req *http.Request)
- Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
+ name string
+ body string
+ req api.GenerateRequest
+ err ErrorResponse
}

var capturedRequest *api.GenerateRequest

testCases := []testCase{
{
- Name: "completions handler",
- Setup: func(t *testing.T, req *http.Request) {
- temp := float32(0.8)
- body := CompletionRequest{
+ name: "completions handler",
+ body: `{
+ "model": "test-model",
+ "prompt": "Hello",
+ "temperature": 0.8,
+ "stop": ["\n", "stop"],
+ "suffix": "suffix"
+ }`,
+ req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
- Temperature: &temp,
- Stop: []string{"\n", "stop"},
- Suffix: "suffix",
- }
- prepareRequest(req, body)
+ Options: map[string]any{
+ "frequency_penalty": 0.0,
+ "presence_penalty": 0.0,
+ "temperature": 1.6,
+ "top_p": 1.0,
+ "stop": []any{"\n", "stop"},
},
- Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
- if req.Prompt != "Hello" {
- t.Fatalf("expected 'Hello', got %s", req.Prompt)
- }
-
- if req.Options["temperature"] != 1.6 {
- t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
- }
-
- stopTokens, ok := req.Options["stop"].([]any)
-
- if !ok {
- t.Fatalf("expected stop tokens to be a list")
- }
-
- if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
- t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
- }
-
- if req.Suffix != "suffix" {
- t.Fatalf("expected 'suffix', got %s", req.Suffix)
- }
+ Suffix: "suffix",
+ Stream: &False,
},
},
{
- Name: "completions handler error forwarding",
- Setup: func(t *testing.T, req *http.Request) {
- body := CompletionRequest{
- Model: "test-model",
- Prompt: "Hello",
- Temperature: nil,
- Stop: []int{1, 2},
- Suffix: "suffix",
- }
- prepareRequest(req, body)
+ name: "completions handler error forwarding",
+ body: `{
+ "model": "test-model",
+ "prompt": "Hello",
+ "temperature": null,
+ "stop": [1, 2],
+ "suffix": "suffix"
+ }`,
+ err: ErrorResponse{
+ Error: Error{
+ Message: "invalid type for 'stop' field: float64",
+ Type: "invalid_request_error",
},
- Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != http.StatusBadRequest {
- t.Fatalf("expected 400, got %d", resp.Code)
- }
-
- if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
- t.Fatalf("error was not forwarded")
- }
},
},
}
@@ -281,15 +269,27 @@ func TestCompletionsMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/generate", endpoint)

for _, tc := range testCases {
- t.Run(tc.Name, func(t *testing.T) {
- req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
- tc.Setup(t, req)
+ t.Run(tc.name, func(t *testing.T) {
+ req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
+ req.Header.Set("Content-Type", "application/json")

resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)

- tc.Expected(t, capturedRequest, resp)
+ var errResp ErrorResponse
+ if resp.Code != http.StatusOK {
+ if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
+ t.Fatal("requests did not match")
+ }
+
+ if !reflect.DeepEqual(tc.err, errResp) {
+ t.Fatal("errors did not match")
+ }
+
capturedRequest = nil
})
@@ -298,78 +298,47 @@ func TestCompletionsMiddleware(t *testing.T) {

func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct {
- Name string
- Setup func(t *testing.T, req *http.Request)
- Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
+ name string
+ body string
+ req api.EmbedRequest
+ err ErrorResponse
}

var capturedRequest *api.EmbedRequest

testCases := []testCase{
{
- Name: "embed handler single input",
- Setup: func(t *testing.T, req *http.Request) {
- body := EmbedRequest{
+ name: "embed handler single input",
+ body: `{
+ "input": "Hello",
+ "model": "test-model"
+ }`,
+ req: api.EmbedRequest{
Input: "Hello",
Model: "test-model",
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
- if req.Input != "Hello" {
- t.Fatalf("expected 'Hello', got %s", req.Input)
- }
-
- if req.Model != "test-model" {
- t.Fatalf("expected 'test-model', got %s", req.Model)
- }
},
},
{
- Name: "embed handler batch input",
- Setup: func(t *testing.T, req *http.Request) {
- body := EmbedRequest{
- Input: []string{"Hello", "World"},
+ name: "embed handler batch input",
+ body: `{
+ "input": ["Hello", "World"],
+ "model": "test-model"
+ }`,
+ req: api.EmbedRequest{
+ Input: []any{"Hello", "World"},
Model: "test-model",
- }
- prepareRequest(req, body)
- },
- Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
- input, ok := req.Input.([]any)
-
- if !ok {
- t.Fatalf("expected input to be a list")
- }
-
- if input[0].(string) != "Hello" {
- t.Fatalf("expected 'Hello', got %s", input[0])
- }
-
- if input[1].(string) != "World" {
- t.Fatalf("expected 'World', got %s", input[1])
- }
-
- if req.Model != "test-model" {
- t.Fatalf("expected 'test-model', got %s", req.Model)
- }
},
},
{
- Name: "embed handler error forwarding",
- Setup: func(t *testing.T, req *http.Request) {
- body := EmbedRequest{
- Model: "test-model",
- }
- prepareRequest(req, body)
+ name: "embed handler error forwarding",
+ body: `{
+ "model": "test-model"
+ }`,
+ err: ErrorResponse{
+ Error: Error{
+ Message: "invalid input",
+ Type: "invalid_request_error",
},
- Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
- if resp.Code != http.StatusBadRequest {
- t.Fatalf("expected 400, got %d", resp.Code)
- }
-
- if !strings.Contains(resp.Body.String(), "invalid input") {
- t.Fatalf("error was not forwarded")
- }
},
},
}
@@ -384,116 +353,167 @@ func TestEmbeddingsMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/embed", endpoint)

for _, tc := range testCases {
- t.Run(tc.Name, func(t *testing.T) {
- req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
- tc.Setup(t, req)
+ t.Run(tc.name, func(t *testing.T) {
+ req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
+ req.Header.Set("Content-Type", "application/json")

resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)

- tc.Expected(t, capturedRequest, resp)
+ var errResp ErrorResponse
+ if resp.Code != http.StatusOK {
+ if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
+ t.Fatal("requests did not match")
+ }
+
+ if !reflect.DeepEqual(tc.err, errResp) {
+ t.Fatal("errors did not match")
+ }
+
capturedRequest = nil
})
}
}

- func TestMiddlewareResponses(t *testing.T) {
+ func TestListMiddleware(t *testing.T) {
type testCase struct {
- Name string
- Method string
- Path string
- TestPath string
- Handler func() gin.HandlerFunc
- Endpoint func(c *gin.Context)
- Setup func(t *testing.T, req *http.Request)
- Expected func(t *testing.T, resp *httptest.ResponseRecorder)
+ name string
+ endpoint func(c *gin.Context)
+ resp string
}

testCases := []testCase{
{
- Name: "list handler",
- Method: http.MethodGet,
- Path: "/api/tags",
- TestPath: "/api/tags",
- Handler: ListMiddleware,
- Endpoint: func(c *gin.Context) {
+ name: "list handler",
+ endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{
{
- Name: "Test Model",
+ Name: "test-model",
+ ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
},
},
})
},
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
- var listResp ListCompletion
- if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
- t.Fatal(err)
+ resp: `{
+ "object": "list",
+ "data": [
+ {
+ "id": "test-model",
+ "object": "model",
+ "created": 1686935002,
+ "owned_by": "library"
}
+ ]
- if listResp.Object != "list" {
- t.Fatalf("expected list, got %s", listResp.Object)
- }
-
- if len(listResp.Data) != 1 {
- t t.Fatalf("expected 1, got %d", len(listResp.Data))
- }
-
- if listResp.Data[0].Id != "Test Model" {
- t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
- }
- },
+ }`,
},
{
- Name: "retrieve model",
- Method: http.MethodGet,
- Path: "/api/show/:model",
- TestPath: "/api/show/test-model",
- Handler: RetrieveMiddleware,
- Endpoint: func(c *gin.Context) {
- c.JSON(http.StatusOK, api.ShowResponse{
- ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
- })
- },
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
- var retrieveResp Model
- if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
- t.Fatal(err)
- }
-
- if retrieveResp.Object != "model" {
- t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
- }
-
- if retrieveResp.Id != "test-model" {
- t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
- }
+ name: "list handler empty output",
+ endpoint: func(c *gin.Context) {
+ c.JSON(http.StatusOK, api.ListResponse{})
},
+ resp: `{
+ "object": "list",
+ "data": null
+ }`,
},
}

gin.SetMode(gin.TestMode)
- router := gin.New()

for _, tc := range testCases {
- t.Run(tc.Name, func(t *testing.T) {
- router = gin.New()
- router.Use(tc.Handler())
- router.Handle(tc.Method, tc.Path, tc.Endpoint)
- req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
-
- if tc.Setup != nil {
- tc.Setup(t, req)
- }
+ router := gin.New()
+ router.Use(ListMiddleware())
+ router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
+ req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)

resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)

- assert.Equal(t, http.StatusOK, resp.Code)
+ var expected, actual map[string]any
+ err := json.Unmarshal([]byte(tc.resp), &expected)
+ if err != nil {
+ t.Fatalf("failed to unmarshal expected response: %v", err)
+ }

- tc.Expected(t, resp)
- })
+ err = json.Unmarshal(resp.Body.Bytes(), &actual)
+ if err != nil {
+ t.Fatalf("failed to unmarshal actual response: %v", err)
+ }
+
+ if !reflect.DeepEqual(expected, actual) {
+ t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+ }
+ }
+ }
+
+ func TestRetrieveMiddleware(t *testing.T) {
+ type testCase struct {
+ name string
+ endpoint func(c *gin.Context)
+ resp string
+ }
+
+ testCases := []testCase{
+ {
+ name: "retrieve handler",
+ endpoint: func(c *gin.Context) {
+ c.JSON(http.StatusOK, api.ShowResponse{
+ ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
+ })
+ },
+ resp: `{
+ "id":"test-model",
+ "object":"model",
+ "created":1686935002,
+ "owned_by":"library"}
+ `,
+ },
+ {
+ name: "retrieve handler error forwarding",
+ endpoint: func(c *gin.Context) {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
+ },
+ resp: `{
+ "error": {
+ "code": null,
+ "message": "model not found",
+ "param": null,
+ "type": "api_error"
+ }
+ }`,
+ },
+ }
+
+ gin.SetMode(gin.TestMode)
+
+ for _, tc := range testCases {
+ router := gin.New()
+ router.Use(RetrieveMiddleware())
+ router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
+ req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ var expected, actual map[string]any
+ err := json.Unmarshal([]byte(tc.resp), &expected)
+ if err != nil {
+ t.Fatalf("failed to unmarshal expected response: %v", err)
+ }
+
+ err = json.Unmarshal(resp.Body.Bytes(), &actual)
+ if err != nil {
+ t.Fatalf("failed to unmarshal actual response: %v", err)
+ }
+
+ if !reflect.DeepEqual(expected, actual) {
+ t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+ }
}
}
@@ -3,11 +3,12 @@ package progress
import (
"fmt"
"strings"
+ "sync/atomic"
"time"
)

type Spinner struct {
- message string
+ message atomic.Value
messageWidth int

parts []string
@@ -21,20 +22,25 @@ type Spinner struct {

func NewSpinner(message string) *Spinner {
s := &Spinner{
- message: message,
parts: []string{
"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
},
started: time.Now(),
}
+ s.SetMessage(message)
go s.start()
return s
}

+ func (s *Spinner) SetMessage(message string) {
+ s.message.Store(message)
+ }
+
func (s *Spinner) String() string {
var sb strings.Builder
- if len(s.message) > 0 {
- message := strings.TrimSpace(s.message)
+ if message, ok := s.message.Load().(string); ok && len(message) > 0 {
+ message := strings.TrimSpace(message)
if s.messageWidth > 0 && len(message) > s.messageWidth {
message = message[:s.messageWidth]
}
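The Spinner change above moves the message into an atomic.Value so it can be updated while the render goroutine is reading it. A trimmed-down Go sketch of that pattern, not the full progress package:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type spinner struct {
	message atomic.Value
}

// SetMessage can be called from any goroutine; Store/Load make the handoff race-free.
func (s *spinner) SetMessage(m string) { s.message.Store(m) }

func (s *spinner) String() string {
	if m, ok := s.message.Load().(string); ok {
		return m
	}
	return ""
}

func main() {
	s := &spinner{}
	s.SetMessage("pulling manifest")

	// Another goroutine updates the message while the loop below keeps reading it.
	go func() {
		time.Sleep(10 * time.Millisecond)
		s.SetMessage("verifying sha256 digest")
	}()

	for i := 0; i < 3; i++ {
		fmt.Println(s.String())
		time.Sleep(10 * time.Millisecond)
	}
}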
@@ -62,7 +62,7 @@ func (b *Buffer) MoveLeft() {
rLength := runewidth.RuneWidth(r)

if b.DisplayPos%b.LineWidth == 0 {
- fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
+ fmt.Print(CursorUp + CursorBOL + CursorRightN(b.Width))
if rLength == 2 {
fmt.Print(CursorLeft)
}
@@ -74,7 +74,7 @@ func (b *Buffer) MoveLeft() {
fmt.Print(CursorLeft)
}
} else {
- fmt.Print(cursorLeftN(rLength))
+ fmt.Print(CursorLeftN(rLength))
}

b.Pos -= 1
@@ -115,15 +115,15 @@ func (b *Buffer) MoveRight() {
b.DisplayPos += rLength

if b.DisplayPos%b.LineWidth == 0 {
- fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
+ fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())))
} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
- fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength))
+ fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())+rLength))
b.DisplayPos += 1
} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
- fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
+ fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())))
b.DisplayPos += 1
} else {
- fmt.Print(cursorRightN(rLength))
+ fmt.Print(CursorRightN(rLength))
}
}
}
@@ -154,7 +154,7 @@ func (b *Buffer) MoveToStart() {
fmt.Print(CursorUp)
}
}
- fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())))
+ fmt.Print(CursorBOL + CursorRightN(len(b.Prompt.prompt())))
b.Pos = 0
b.DisplayPos = 0
}
@@ -169,9 +169,9 @@ func (b *Buffer) MoveToEnd() {
fmt.Print(CursorDown)
}
remainder := b.DisplaySize() % b.LineWidth
- fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())+remainder))
+ fmt.Print(CursorBOL + CursorRightN(len(b.Prompt.prompt())+remainder))
} else {
- fmt.Print(cursorRightN(b.DisplaySize() - b.DisplayPos))
+ fmt.Print(CursorRightN(b.DisplaySize() - b.DisplayPos))
}

b.Pos = b.Buf.Size()
@@ -286,8 +286,7 @@ func (b *Buffer) drawRemaining() {
remLength := runewidth.StringWidth(remainingText)

if len(currLine) > 0 {
- fmt.Printf(ClearToEOL + currLine)
- fmt.Print(cursorLeftN(currLineSpace))
+ fmt.Print(ClearToEOL + currLine + CursorLeftN(currLineSpace))
} else {
fmt.Print(ClearToEOL)
}
@@ -301,9 +300,9 @@ func (b *Buffer) drawRemaining() {
}

if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText {
- fmt.Print(cursorRightN(currLineSpace))
+ fmt.Print(CursorRightN(currLineSpace))
fmt.Printf("\n%s", b.Prompt.AltPrompt)
- fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width-currLineSpace))
+ fmt.Print(CursorUp + CursorBOL + CursorRightN(b.Width-currLineSpace))
}

// render the other lines
@@ -333,9 +332,7 @@ func (b *Buffer) drawRemaining() {
lineLength += runewidth.RuneWidth(c)
fmt.Printf("%c", c)
}
- fmt.Print(ClearToEOL)
- fmt.Print(cursorUpN(totalLines))
- fmt.Printf(CursorBOL + cursorRightN(b.Width-currLineSpace))
+ fmt.Print(ClearToEOL + CursorUpN(totalLines) + CursorBOL + CursorRightN(b.Width-currLineSpace))

hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)

@@ -357,8 +354,7 @@ func (b *Buffer) Remove() {
if b.DisplayPos%b.LineWidth == 0 {
// if the user backspaces over the word boundary, do this magic to clear the line
// and move to the end of the previous line
- fmt.Printf(CursorBOL + ClearToEOL)
- fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
+ fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + CursorRightN(b.Width))

if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth {
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
@@ -370,24 +366,23 @@ func (b *Buffer) Remove() {
}

if rLength == 2 {
- fmt.Print(CursorLeft + " " + cursorLeftN(2))
+ fmt.Print(CursorLeft + " " + CursorLeftN(2))
} else {
fmt.Print(" " + CursorLeft)
}
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
- fmt.Printf(CursorBOL + ClearToEOL)
- fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
+ fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + CursorRightN(b.Width))

if b.Pos == b.Buf.Size() {
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
}
b.DisplayPos -= 1
} else {
- fmt.Print(cursorLeftN(rLength))
+ fmt.Print(CursorLeftN(rLength))
for range rLength {
fmt.Print(" ")
}
- fmt.Print(cursorLeftN(rLength))
+ fmt.Print(CursorLeftN(rLength))
}

var eraseExtraLine bool
@@ -405,9 +400,9 @@ func (b *Buffer) Remove() {
// are trailing characters which go over the line width boundary
if eraseExtraLine {
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
- fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
+ fmt.Print(CursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
place := b.DisplayPos % b.LineWidth
- fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.prompt())))
+ fmt.Print(CursorUpN(remainingLines+1) + CursorRightN(place+len(b.Prompt.prompt())))
}
}
}
@@ -422,9 +417,9 @@ func (b *Buffer) Delete() {
if b.DisplaySize()%b.LineWidth == 0 {
if b.DisplayPos != b.DisplaySize() {
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
- fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
+ fmt.Print(CursorDownN(remainingLines) + CursorBOL + ClearToEOL)
place := b.DisplayPos % b.LineWidth
- fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.prompt())))
+ fmt.Print(CursorUpN(remainingLines) + CursorRightN(place+len(b.Prompt.prompt())))
}
}
}
@@ -471,17 +466,17 @@ func (b *Buffer) DeleteWord() {
}

func (b *Buffer) ClearScreen() {
- fmt.Printf(ClearScreen + CursorReset + b.Prompt.prompt())
+ fmt.Print(ClearScreen + CursorReset + b.Prompt.prompt())
if b.IsEmpty() {
ph := b.Prompt.placeholder()
- fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
+ fmt.Print(ColorGrey + ph + CursorLeftN(len(ph)) + ColorDefault)
} else {
currPos := b.DisplayPos
currIndex := b.Pos
b.Pos = 0
b.DisplayPos = 0
b.drawRemaining()
- fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.prompt())))
+ fmt.Print(CursorReset + CursorRightN(len(b.Prompt.prompt())))
if currPos > 0 {
targetLine := currPos / b.LineWidth
if targetLine > 0 {
@@ -491,10 +486,10 @@ func (b *Buffer) ClearScreen() {
}
remainder := currPos % b.LineWidth
if remainder > 0 {
- fmt.Print(cursorRightN(remainder))
+ fmt.Print(CursorRightN(remainder))
}
if currPos%b.LineWidth == 0 {
- fmt.Printf(CursorBOL + b.Prompt.AltPrompt)
+ fmt.Print(CursorBOL + b.Prompt.AltPrompt)
}
}
b.Pos = currIndex
@@ -513,13 +508,13 @@ func (b *Buffer) Replace(r []rune) {

b.Buf.Clear()

- fmt.Printf(CursorBOL + ClearToEOL)
+ fmt.Print(CursorBOL + ClearToEOL)

for range lineNums {
fmt.Print(CursorUp + CursorBOL + ClearToEOL)
}

- fmt.Printf(CursorBOL + b.Prompt.prompt())
+ fmt.Print(CursorBOL + b.Prompt.prompt())

for _, c := range r {
b.Add(c)
@@ -545,19 +540,3 @@ func (b *Buffer) StringNM(n, m int) string {
}
return s
}
-
- func cursorLeftN(n int) string {
- return fmt.Sprintf(CursorLeftN, n)
- }
-
- func cursorRightN(n int) string {
- return fmt.Sprintf(CursorRightN, n)
- }
-
- func cursorUpN(n int) string {
- return fmt.Sprintf(CursorUpN, n)
- }
-
- func cursorDownN(n int) string {
- return fmt.Sprintf(CursorDownN, n)
- }
@@ -98,7 +98,7 @@ func (i *Instance) Readline() (string, error) {
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
if buf.IsEmpty() && showPlaceholder {
ph := i.Prompt.placeholder()
- fmt.Printf(ColorGrey + ph + fmt.Sprintf(CursorLeftN, len(ph)) + ColorDefault)
+ fmt.Print(ColorGrey + ph + CursorLeftN(len(ph)) + ColorDefault)
}

r, err := i.Terminal.Read()
@@ -1,5 +1,7 @@
package readline

+ import "strconv"
+
const (
CharNull = 0
CharLineStart = 1
@@ -41,34 +43,49 @@
)

const (
- CursorUp = "\033[1A"
- CursorDown = "\033[1B"
- CursorRight = "\033[1C"
- CursorLeft = "\033[1D"
+ Esc = "\x1b"

- CursorSave = "\033[s"
- CursorRestore = "\033[u"
+ CursorSave = Esc + "[s"
+ CursorRestore = Esc + "[u"

- CursorUpN = "\033[%dA"
- CursorDownN = "\033[%dB"
- CursorRightN = "\033[%dC"
- CursorLeftN = "\033[%dD"
+ CursorEOL = Esc + "[E"
+ CursorBOL = Esc + "[1G"
+ CursorHide = Esc + "[?25l"
+ CursorShow = Esc + "[?25h"

- CursorEOL = "\033[E"
- CursorBOL = "\033[1G"
- CursorHide = "\033[?25l"
- CursorShow = "\033[?25h"
+ ClearToEOL = Esc + "[K"
+ ClearLine = Esc + "[2K"
+ ClearScreen = Esc + "[2J"
+ CursorReset = Esc + "[0;0f"

- ClearToEOL = "\033[K"
- ClearLine = "\033[2K"
- ClearScreen = "\033[2J"
- CursorReset = "\033[0;0f"
+ ColorGrey = Esc + "[38;5;245m"
+ ColorDefault = Esc + "[0m"

- ColorGrey = "\033[38;5;245m"
- ColorDefault = "\033[0m"
-
- StartBracketedPaste = "\033[?2004h"
- EndBracketedPaste = "\033[?2004l"
+ StartBracketedPaste = Esc + "[?2004h"
+ EndBracketedPaste = Esc + "[?2004l"
+ )
+
+ func CursorUpN(n int) string {
+ return Esc + "[" + strconv.Itoa(n) + "A"
+ }
+
+ func CursorDownN(n int) string {
+ return Esc + "[" + strconv.Itoa(n) + "B"
+ }
+
+ func CursorRightN(n int) string {
+ return Esc + "[" + strconv.Itoa(n) + "C"
+ }
+
+ func CursorLeftN(n int) string {
+ return Esc + "[" + strconv.Itoa(n) + "D"
+ }
+
+ var (
+ CursorUp = CursorUpN(1)
+ CursorDown = CursorDownN(1)
+ CursorRight = CursorRightN(1)
+ CursorLeft = CursorLeftN(1)
)

const (
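The readline constants above switch from fmt.Sprintf format-string escapes to small helper functions. A minimal Go sketch of the same construction, showing the raw bytes such helpers produce:

package main

import (
	"fmt"
	"strconv"
)

const esc = "\x1b"

// Parameterised ANSI cursor-movement escapes, built the same way as the new
// CursorUpN/CursorLeftN helpers in the diff.
func cursorUpN(n int) string   { return esc + "[" + strconv.Itoa(n) + "A" }
func cursorLeftN(n int) string { return esc + "[" + strconv.Itoa(n) + "D" }

func main() {
	// %q shows the escape bytes; a terminal would interpret them as cursor movement.
	fmt.Printf("%q\n", cursorUpN(3))   // "\x1b[3A"
	fmt.Printf("%q\n", cursorLeftN(5)) // "\x1b[5D"
}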
@@ -209,15 +209,15 @@ install_cuda_driver_yum() {
case $PACKAGE_MANAGER in
yum)
$SUDO $PACKAGE_MANAGER -y install yum-utils
- if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
- $SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
+ if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
+ $SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
else
error $CUDA_REPO_ERR_MSG
fi
;;
dnf)
- if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
- $SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
+ if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
+ $SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
else
error $CUDA_REPO_ERR_MSG
fi
@@ -245,8 +245,8 @@ install_cuda_driver_yum() {
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
install_cuda_driver_apt() {
status 'Installing NVIDIA repository...'
- if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
- curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
+ if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
+ curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb
else
error $CUDA_REPO_ERR_MSG
fi
@@ -94,7 +94,7 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
}

const (
- numDownloadParts = 64
+ numDownloadParts = 16
minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte
)
@@ -216,9 +216,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
defer file.Close()
- if err := setSparse(file); err != nil {
- return err
- }
+ setSparse(file)

_ = file.Truncate(b.Total)

@@ -235,7 +233,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis

newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > 10 {
- return errors.New("maxium redirects exceeded (10) for directURL")
+ return errors.New("maximum redirects exceeded (10) for directURL")
}

// if the hostname is the same, allow the redirect
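The redirect cap shown in the download.go hunk above is ordinary net/http CheckRedirect usage; a small, standalone Go sketch of the same policy:

package main

import (
	"errors"
	"fmt"
	"net/http"
)

func main() {
	// Abort once a request has been redirected more than ten times,
	// mirroring the CheckRedirect hook in the blob download code.
	client := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) > 10 {
				return errors.New("maximum redirects exceeded (10) for directURL")
			}
			return nil
		},
	}

	fmt.Println(client.CheckRedirect != nil) // true
}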
@@ -373,7 +373,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
var messages []*api.Message
parameters := make(map[string]any)

- var layers []*Layer
+ var layers []Layer
for _, c := range modelfile.Commands {
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)

@@ -499,7 +499,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio

if c.Name != "license" {
// replace
- layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+ layers = slices.DeleteFunc(layers, func(layer Layer) bool {
if layer.MediaType != mediatype {
return false
}
@@ -545,7 +545,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
}

var err2 error
- layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+ layers = slices.DeleteFunc(layers, func(layer Layer) bool {
switch layer.MediaType {
case "application/vnd.ollama.image.message":
// if there are new messages, remove the inherited ones
@@ -625,12 +625,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return err
}

- layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+ configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return err
}

- for _, layer := range append(layers, layer) {
+ for _, layer := range append(layers, configLayer) {
if layer.status != "" {
fn(api.ProgressResponse{Status: layer.status})
}
@@ -639,7 +639,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
old, _ := ParseNamedManifest(name)

fn(api.ProgressResponse{Status: "writing manifest"})
- if err := WriteManifest(name, layer, layers); err != nil {
+ if err := WriteManifest(name, configLayer, layers); err != nil {
return err
}

@@ -839,10 +839,10 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return err
}

- var layers []*Layer
+ var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
- layers = append(layers, &manifest.Config)
+ layers = append(layers, manifest.Config)
}

for _, layer := range layers {
@@ -911,10 +911,10 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return fmt.Errorf("pull model manifest: %s", err)
}

- var layers []*Layer
+ var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
- layers = append(layers, &manifest.Config)
+ layers = append(layers, manifest.Config)
}

skipVerify := make(map[string]bool)
@@ -16,15 +16,15 @@ type Layer struct {
    status string
 }
 
-func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
+func NewLayer(r io.Reader, mediatype string) (Layer, error) {
    blobs, err := GetBlobsPath("")
    if err != nil {
-       return nil, err
+       return Layer{}, err
    }
 
    temp, err := os.CreateTemp(blobs, "sha256-")
    if err != nil {
-       return nil, err
+       return Layer{}, err
    }
    defer temp.Close()
    defer os.Remove(temp.Name())
@@ -32,28 +32,28 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
    sha256sum := sha256.New()
    n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
    if err != nil {
-       return nil, err
+       return Layer{}, err
    }
 
    if err := temp.Close(); err != nil {
-       return nil, err
+       return Layer{}, err
    }
 
    digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
    blob, err := GetBlobsPath(digest)
    if err != nil {
-       return nil, err
+       return Layer{}, err
    }
 
    status := "using existing layer"
    if _, err := os.Stat(blob); err != nil {
        status = "creating new layer"
        if err := os.Rename(temp.Name(), blob); err != nil {
-           return nil, err
+           return Layer{}, err
        }
    }
 
-   return &Layer{
+   return Layer{
        MediaType: mediatype,
        Digest:    digest,
        Size:      n,
@@ -61,22 +61,22 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
    }, nil
 }
 
-func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
+func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
    if digest == "" {
-       return nil, errors.New("creating new layer from layer with empty digest")
+       return Layer{}, errors.New("creating new layer from layer with empty digest")
    }
 
    blob, err := GetBlobsPath(digest)
    if err != nil {
-       return nil, err
+       return Layer{}, err
    }
 
    fi, err := os.Stat(blob)
    if err != nil {
-       return nil, err
+       return Layer{}, err
    }
 
-   return &Layer{
+   return Layer{
        MediaType: mediatype,
        Digest:    digest,
        Size:      fi.Size(),
@@ -109,7 +109,7 @@ func (l *Layer) Remove() error {
    }
 
    for _, m := range ms {
-       for _, layer := range append(m.Layers, &m.Config) {
+       for _, layer := range append(m.Layers, m.Config) {
            if layer.Digest == l.Digest {
                // something is using this layer
                return nil
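As a rough sketch (not taken from the patch itself) of how the value-returning layer API above reads at a call site: the helper name and the buffer contents below are invented for illustration, while NewLayer, Layer, and the media type string appear in the hunks above.

    // buildConfigLayer is a hypothetical helper showing the new value semantics:
    // NewLayer returns a Layer value, and failures come back as Layer{} plus an error,
    // so callers no longer juggle *Layer or nil checks.
    func buildConfigLayer(configJSON []byte) (Layer, error) {
        var b bytes.Buffer
        if _, err := b.Write(configJSON); err != nil {
            return Layer{}, err
        }
        return NewLayer(&b, "application/vnd.docker.container.image.v1+json")
    }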
@@ -17,7 +17,7 @@ type Manifest struct {
    SchemaVersion int    `json:"schemaVersion"`
    MediaType     string `json:"mediaType"`
    Config        Layer  `json:"config"`
-   Layers        []*Layer `json:"layers"`
+   Layers        []Layer `json:"layers"`
 
    filepath string
    fi       os.FileInfo
@@ -25,7 +25,7 @@ type Manifest struct {
 }
 
 func (m *Manifest) Size() (size int64) {
-   for _, layer := range append(m.Layers, &m.Config) {
+   for _, layer := range append(m.Layers, m.Config) {
        size += layer.Size
    }
 
@@ -46,7 +46,7 @@ func (m *Manifest) Remove() error {
 }
 
 func (m *Manifest) RemoveLayers() error {
-   for _, layer := range append(m.Layers, &m.Config) {
+   for _, layer := range append(m.Layers, m.Config) {
        if layer.Digest != "" {
            if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
                slog.Debug("layer does not exist", "digest", layer.Digest)
@@ -95,7 +95,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
    return &m, nil
 }
 
-func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
+func WriteManifest(name model.Name, config Layer, layers []Layer) error {
    manifests, err := GetManifestPath()
    if err != nil {
        return err
@@ -115,7 +115,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
    m := Manifest{
        SchemaVersion: 2,
        MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
-       Config:        *config,
+       Config:        config,
        Layers:        layers,
    }
 
@@ -26,7 +26,7 @@ import (
 var intermediateBlobs map[string]string = make(map[string]string)
 
 type layerGGML struct {
-   *Layer
+   Layer
    *llm.GGML
 }
 
@@ -176,10 +176,21 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
            mediatype = "application/vnd.ollama.image.projector"
        }
 
-       layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
+       var layer Layer
+       if digest != "" && n == stat.Size() && offset == 0 {
+           layer, err = NewLayerFromLayer(digest, mediatype, file.Name())
+           if err != nil {
+               slog.Debug("could not create new layer from layer", "error", err)
+           }
+       }
+
+       // Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
+       if layer.Digest == "" {
+           layer, err = NewLayer(io.NewSectionReader(file, offset, n), mediatype)
            if err != nil {
                return nil, err
            }
+       }
 
        layers = append(layers, &layerGGML{layer, ggml})
        offset = n
@@ -2,8 +2,10 @@ package server
 
 import (
    "bytes"
+   "context"
    "encoding/json"
    "fmt"
+   "io"
    "os"
    "path/filepath"
    "testing"
@@ -11,6 +13,7 @@ import (
    "github.com/google/go-cmp/cmp"
 
    "github.com/ollama/ollama/api"
+   "github.com/ollama/ollama/llm"
    "github.com/ollama/ollama/template"
 )
 
@@ -133,3 +136,82 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
        })
    }
 }
+
+func TestParseFromFileFromLayer(t *testing.T) {
+   tempModels := t.TempDir()
+
+   file, err := os.CreateTemp(tempModels, "")
+   if err != nil {
+       t.Fatalf("failed to open file: %v", err)
+   }
+   defer file.Close()
+   if err := llm.WriteGGUF(file, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
+       t.Fatalf("failed to write gguf: %v", err)
+   }
+
+   if _, err := file.Seek(0, io.SeekStart); err != nil {
+       t.Fatalf("failed to seek to start: %v", err)
+   }
+
+   layers, err := parseFromFile(context.Background(), file, "", func(api.ProgressResponse) {})
+   if err != nil {
+       t.Fatalf("failed to parse from file: %v", err)
+   }
+
+   if len(layers) != 1 {
+       t.Fatalf("got %d != want 1", len(layers))
+   }
+
+   if _, err := file.Seek(0, io.SeekStart); err != nil {
+       t.Fatalf("failed to seek to start: %v", err)
+   }
+
+   layers2, err := parseFromFile(context.Background(), file, layers[0].Digest, func(api.ProgressResponse) {})
+   if err != nil {
+       t.Fatalf("failed to parse from file: %v", err)
+   }
+   if len(layers2) != 1 {
+       t.Fatalf("got %d != want 1", len(layers2))
+   }
+
+   if layers[0].Digest != layers2[0].Digest {
+       t.Fatalf("got %s != want %s", layers[0].Digest, layers2[0].Digest)
+   }
+
+   if layers[0].Size != layers2[0].Size {
+       t.Fatalf("got %d != want %d", layers[0].Size, layers2[0].Size)
+   }
+
+   if layers[0].MediaType != layers2[0].MediaType {
+       t.Fatalf("got %v != want %v", layers[0].MediaType, layers2[0].MediaType)
+   }
+}
+
+func TestParseLayerFromCopy(t *testing.T) {
+   tempModels := t.TempDir()
+
+   file2, err := os.CreateTemp(tempModels, "")
+   if err != nil {
+       t.Fatalf("failed to open file: %v", err)
+   }
+   defer file2.Close()
+
+   for range 5 {
+       if err := llm.WriteGGUF(file2, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
+           t.Fatalf("failed to write gguf: %v", err)
+       }
+   }
+
+   if _, err := file2.Seek(0, io.SeekStart); err != nil {
+       t.Fatalf("failed to seek to start: %v", err)
+   }
+
+   layers, err := parseFromFile(context.Background(), file2, "", func(api.ProgressResponse) {})
+   if err != nil {
+       t.Fatalf("failed to parse from file: %v", err)
+   }
+
+   if len(layers) != 5 {
+       t.Fatalf("got %d != want 5", len(layers))
+   }
+}
@@ -23,6 +23,7 @@ import (
 
    "github.com/gin-contrib/cors"
    "github.com/gin-gonic/gin"
+   "golang.org/x/sync/errgroup"
 
    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/envconfig"
@@ -323,13 +324,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
            input = append(input, v.(string))
        }
    default:
-       c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
-       return
-   }
-
-   if len(input) == 0 {
-       c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
-       return
+       if req.Input != nil {
+           c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+           return
+       }
    }
 
    r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
@@ -340,12 +338,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 
    checkpointLoaded := time.Now()
 
+   if len(input) == 0 {
+       c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
+       return
+   }
+
    kvData, err := getKVData(m.ModelPath, false)
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }
 
+   var count int
    for i, s := range input {
        tokens, err := r.Tokenize(c.Request.Context(), s)
        if err != nil {
@@ -368,25 +372,36 @@ func (s *Server) EmbedHandler(c *gin.Context) {
            }
        }
 
+       count += len(tokens)
+
        input[i] = s
    }
-   embeddings, err := r.Embed(c.Request.Context(), input)
-   if err != nil {
-       slog.Error("embedding generation failed", "error", err)
-       c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
-       return
-   }
 
-   for i, e := range embeddings.Embedding {
-       embeddings.Embedding[i] = normalize(e)
+   var g errgroup.Group
+   embeddings := make([][]float32, len(input))
+   for i, text := range input {
+       g.Go(func() error {
+           embedding, err := r.Embedding(c.Request.Context(), text)
+           if err != nil {
+               return err
+           }
+           embeddings[i] = normalize(embedding)
+           return nil
+       })
+   }
+
+   if err := g.Wait(); err != nil {
+       slog.Error("embedding generation failed", "error", err)
+       c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
+       return
    }
 
    resp := api.EmbedResponse{
        Model:           req.Model,
-       Embeddings:      embeddings.Embedding,
+       Embeddings:      embeddings,
        TotalDuration:   time.Since(checkpointStart),
        LoadDuration:    checkpointLoaded.Sub(checkpointStart),
-       PromptEvalCount: embeddings.PromptEvalCount,
+       PromptEvalCount: count,
    }
    c.JSON(http.StatusOK, resp)
 }
@@ -430,21 +445,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
        return
    }
 
-   embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
+   embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
    if err != nil {
        slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
        c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
        return
    }
 
-   embedding := make([]float64, len(embeddings.Embedding[0]))
-
-   for i, v := range embeddings.Embedding[0] {
-       embedding[i] = float64(v)
+   var e []float64
+   for _, v := range embedding {
+       e = append(e, float64(v))
    }
 
    resp := api.EmbeddingResponse{
-       Embedding: embedding,
+       Embedding: e,
    }
    c.JSON(http.StatusOK, resp)
 }
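For readers unfamiliar with the fan-out pattern the embed handler above switches to, here is a rough, self-contained sketch of golang.org/x/sync/errgroup collecting per-item results into an indexed slice. The embedText function and the inputs are invented stand-ins for the runner's per-input Embedding call; they are not part of this change.

    package main

    import (
        "fmt"

        "golang.org/x/sync/errgroup"
    )

    // embedText is a hypothetical stand-in for a per-input embedding call.
    func embedText(s string) ([]float32, error) {
        return []float32{float32(len(s))}, nil
    }

    func main() {
        inputs := []string{"a", "bb", "ccc"}

        var g errgroup.Group
        results := make([][]float32, len(inputs))
        for i, text := range inputs {
            g.Go(func() error {
                // each goroutine writes only to its own index, so no mutex is needed;
                // capturing i and text directly is safe on Go 1.22+, where loop
                // variables are per iteration
                e, err := embedText(text)
                if err != nil {
                    return err
                }
                results[i] = e
                return nil
            })
        }

        // Wait blocks until all goroutines finish and returns the first non-nil error
        if err := g.Wait(); err != nil {
            fmt.Println("error:", err)
            return
        }
        fmt.Println(results)
    }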
@@ -98,7 +98,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
    }
 
    // create a manifest with duplicate layers
-   if err := WriteManifest(n, config, []*Layer{config}); err != nil {
+   if err := WriteManifest(n, config, []Layer{config}); err != nil {
        t.Fatal(err)
    }
 
@@ -272,76 +272,6 @@ func Test_Routes(t *testing.T) {
                assert.Equal(t, "library", retrieveResp.OwnedBy)
            },
        },
-       {
-           Name:   "Embed Handler Empty Input",
-           Method: http.MethodPost,
-           Path:   "/api/embed",
-           Setup: func(t *testing.T, req *http.Request) {
-               embedReq := api.EmbedRequest{
-                   Model: "t-bone",
-                   Input: "",
-               }
-               jsonData, err := json.Marshal(embedReq)
-               require.NoError(t, err)
-               req.Body = io.NopCloser(bytes.NewReader(jsonData))
-           },
-           Expected: func(t *testing.T, resp *http.Response) {
-               contentType := resp.Header.Get("Content-Type")
-               if contentType != "application/json; charset=utf-8" {
-                   t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
-               }
-               body, err := io.ReadAll(resp.Body)
-               if err != nil {
-                   t.Fatal(err)
-               }
-
-               var embedResp api.EmbedResponse
-               err = json.Unmarshal(body, &embedResp)
-               if err != nil {
-                   t.Fatal(err)
-               }
-
-               if embedResp.Model != "t-bone" {
-                   t.Fatalf("expected model t-bone, got %s", embedResp.Model)
-               }
-
-               if embedResp.Embeddings == nil {
-                   t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
-               }
-
-               if len(embedResp.Embeddings) != 0 {
-                   t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
-               }
-           },
-       },
-       {
-           Name:   "Embed Handler Invalid Input",
-           Method: http.MethodPost,
-           Path:   "/api/embed",
-           Setup: func(t *testing.T, req *http.Request) {
-               embedReq := api.EmbedRequest{
-                   Model: "t-bone",
-                   Input: 2,
-               }
-               jsonData, err := json.Marshal(embedReq)
-               require.NoError(t, err)
-               req.Body = io.NopCloser(bytes.NewReader(jsonData))
-           },
-           Expected: func(t *testing.T, resp *http.Response) {
-               contentType := resp.Header.Get("Content-Type")
-               if contentType != "application/json; charset=utf-8" {
-                   t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
-               }
-               _, err := io.ReadAll(resp.Body)
-               if err != nil {
-                   t.Fatal(err)
-               }
-
-               if resp.StatusCode != http.StatusBadRequest {
-                   t.Fatalf("expected status code 400, got %d", resp.StatusCode)
-               }
-           },
-       },
    }
 
    t.Setenv("OLLAMA_MODELS", t.TempDir())
@@ -418,7 +418,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
            // some older models are not compatible with newer versions of llama.cpp
            // show a generalized compatibility error until there is a better way to
            // check for model compatibility
-           if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
+           if errors.Is(err, llm.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
                err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
            }
            slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
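A quick aside on the errors.Is fix above: the standard-library signature is errors.Is(err, target), so the error under inspection comes first and the sentinel second; with the arguments reversed the call still compiles but generally reports false. A tiny illustration, with an invented sentinel standing in for llm.ErrUnsupportedFormat:

    package main

    import (
        "errors"
        "fmt"
    )

    // errUnsupported stands in for a sentinel such as llm.ErrUnsupportedFormat.
    var errUnsupported = errors.New("unsupported format")

    func main() {
        err := fmt.Errorf("loading model: %w", errUnsupported)

        fmt.Println(errors.Is(err, errUnsupported)) // true: err's chain contains the sentinel
        fmt.Println(errors.Is(errUnsupported, err)) // false: the sentinel does not wrap err
    }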
@@ -708,8 +708,8 @@ type mockLlm struct {
    pingResp         error
    waitResp         error
    completionResp   error
-   embedResp        *llm.EmbedResponse
-   embedRespErr     error
+   embeddingResp    []float32
+   embeddingRespErr error
    tokenizeResp     []int
    tokenizeRespErr  error
    detokenizeResp   string
@@ -727,8 +727,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
    return s.completionResp
 }
 
-func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
-   return s.embedResp, s.embedRespErr
+func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
+   return s.embeddingResp, s.embeddingRespErr
 }
 
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
@@ -4,6 +4,5 @@ package server
 
 import "os"
 
-func setSparse(file *os.File) error {
-   return nil
+func setSparse(*os.File) {
 }
@@ -6,8 +6,9 @@ import (
    "golang.org/x/sys/windows"
 )
 
-func setSparse(file *os.File) error {
-   return windows.DeviceIoControl(
+func setSparse(file *os.File) {
+   // exFat (and other FS types) don't support sparse files, so ignore errors
+   windows.DeviceIoControl( //nolint:errcheck
        windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
        nil, 0,
        nil, 0,
@@ -26,7 +26,7 @@ import (
 var blobUploadManager sync.Map
 
 type blobUpload struct {
-   *Layer
+   Layer
 
    Total     int64
    Completed atomic.Int64
@@ -362,7 +362,7 @@ func (p *progressWriter) Rollback() {
    p.written = 0
 }
 
-func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
+func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
    requestURL := mp.BaseURL()
    requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
 
@@ -219,7 +219,7 @@ func (n Name) String() string {
    return b.String()
 }
 
-// DisplayShort returns a short string version of the name.
+// DisplayShortest returns a short string version of the name.
 func (n Name) DisplayShortest() string {
    var sb strings.Builder
 