diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 049a97ed..dbb6c2fd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -269,9 +269,9 @@ jobs: mkdir -p llm/build/darwin/$ARCH/stub/bin touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server if: ${{ startsWith(matrix.os, 'macos-') }} - - uses: golangci/golangci-lint-action@v4 + - uses: golangci/golangci-lint-action@v6 with: - args: --timeout 8m0s -v + args: --timeout 8m0s -v ${{ startsWith(matrix.os, 'windows-') && '' || '--disable gofmt --disable goimports' }} test: strategy: matrix: diff --git a/.golangci.yaml b/.golangci.yaml index 7dec49de..cfe06e07 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -9,9 +9,26 @@ linters: - contextcheck - exportloopref - gocheckcompilerdirectives - # FIXME: for some reason this errors on windows + # conditionally enable this on linux/macos # - gofmt # - goimports + - intrange - misspell - nilerr + - nolintlint + - nosprintfhostport + - testifylint + - unconvert - unused + - wastedassign + - whitespace + - usestdlibvars +severity: + default-severity: error + rules: + - linters: + - gofmt + - goimports + - intrange + - usestdlibvars + severity: info diff --git a/README.md b/README.md index 124db1c7..10623107 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![Discord](https://dcbadge.vercel.app/api/server/ollama?style=flat&compact=true)](https://discord.gg/ollama) -Get up and running with large language models locally. +Get up and running with large language models. ### macOS @@ -303,6 +303,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends) - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama) - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) +- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) ### Terminal @@ -325,6 +326,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [ShellOracle](https://github.com/djcopley/ShellOracle) - [tlm](https://github.com/yusufcanb/tlm) - [podman-ollama](https://github.com/ericcurtin/podman-ollama) +- [gollama](https://github.com/sammcj/gollama) ### Database @@ -364,6 +366,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama) - [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) with an [example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama) - [LlamaScript](https://github.com/Project-Llama/llamascript) + ### Mobile - [Enchanted](https://github.com/AugustDev/enchanted) @@ -397,6 +400,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation) - [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities. -### Supported backends +### Supported backends + - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. diff --git a/api/client.go b/api/client.go index d50b397d..dc099e95 100644 --- a/api/client.go +++ b/api/client.go @@ -355,8 +355,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) { } // List running models. -func (c *Client) ListRunning(ctx context.Context) (*ListResponse, error) { - var lr ListResponse +func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) { + var lr ProcessResponse if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil { return nil, err } diff --git a/api/types.go b/api/types.go index 4195a7c5..caf2ad70 100644 --- a/api/types.go +++ b/api/types.go @@ -282,19 +282,33 @@ type PushRequest struct { // ListResponse is the response from [Client.List]. type ListResponse struct { - Models []ModelResponse `json:"models"` + Models []ListModelResponse `json:"models"` } -// ModelResponse is a single model description in [ListResponse]. -type ModelResponse struct { +// ProcessResponse is the response from [Client.Process]. +type ProcessResponse struct { + Models []ProcessModelResponse `json:"models"` +} + +// ListModelResponse is a single model description in [ListResponse]. +type ListModelResponse struct { Name string `json:"name"` Model string `json:"model"` - ModifiedAt time.Time `json:"modified_at,omitempty"` + ModifiedAt time.Time `json:"modified_at"` Size int64 `json:"size"` Digest string `json:"digest"` Details ModelDetails `json:"details,omitempty"` - ExpiresAt time.Time `json:"expires_at,omitempty"` - SizeVRAM int64 `json:"size_vram,omitempty"` +} + +// ProcessModelResponse is a single model description in [ProcessResponse]. +type ProcessModelResponse struct { + Name string `json:"name"` + Model string `json:"model"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Details ModelDetails `json:"details,omitempty"` + ExpiresAt time.Time `json:"expires_at"` + SizeVRAM int64 `json:"size_vram"` } type TokenResponse struct { @@ -306,7 +320,7 @@ type GenerateResponse struct { // Model is the model name that generated the response. Model string `json:"model"` - //CreatedAt is the timestamp of the response. + // CreatedAt is the timestamp of the response. CreatedAt time.Time `json:"created_at"` // Response is the textual response itself. diff --git a/api/types_test.go b/api/types_test.go index cfe1331f..211385c7 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -72,13 +72,13 @@ func TestDurationMarshalUnmarshal(t *testing.T) { }, { "positive duration", - time.Duration(42 * time.Second), - time.Duration(42 * time.Second), + 42 * time.Second, + 42 * time.Second, }, { "another positive duration", - time.Duration(42 * time.Minute), - time.Duration(42 * time.Minute), + 42 * time.Minute, + 42 * time.Minute, }, { "zero duration", diff --git a/app/lifecycle/paths.go b/app/lifecycle/paths.go index e4f2dbd9..fe07bce1 100644 --- a/app/lifecycle/paths.go +++ b/app/lifecycle/paths.go @@ -69,7 +69,6 @@ func init() { slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err)) } } - } else if runtime.GOOS == "darwin" { // TODO AppName += ".app" diff --git a/app/lifecycle/server.go b/app/lifecycle/server.go index 3c11edb8..0152ccd1 100644 --- a/app/lifecycle/server.go +++ b/app/lifecycle/server.go @@ -15,7 +15,7 @@ import ( ) func getCLIFullPath(command string) string { - cmdPath := "" + var cmdPath string appExe, err := os.Executable() if err == nil { cmdPath = filepath.Join(filepath.Dir(appExe), command) @@ -65,7 +65,6 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) { if err != nil { if !errors.Is(err, os.ErrNotExist) { return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err) - } if err := os.MkdirAll(logDir, 0o755); err != nil { diff --git a/app/lifecycle/server_windows.go b/app/lifecycle/server_windows.go index cd4244ff..5f9fe124 100644 --- a/app/lifecycle/server_windows.go +++ b/app/lifecycle/server_windows.go @@ -24,7 +24,8 @@ func terminate(cmd *exec.Cmd) error { if err != nil { return err } - defer dll.Release() // nolint: errcheck + //nolint:errcheck + defer dll.Release() pid := cmd.Process.Pid @@ -73,7 +74,8 @@ func isProcessExited(pid int) (bool, error) { if err != nil { return false, fmt.Errorf("failed to open process: %v", err) } - defer windows.CloseHandle(hProcess) // nolint: errcheck + //nolint:errcheck + defer windows.CloseHandle(hProcess) var exitCode uint32 err = windows.GetExitCodeProcess(hProcess, &exitCode) diff --git a/app/lifecycle/updater.go b/app/lifecycle/updater.go index 243bbf22..b6d95330 100644 --- a/app/lifecycle/updater.go +++ b/app/lifecycle/updater.go @@ -78,7 +78,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) { } defer resp.Body.Close() - if resp.StatusCode == 204 { + if resp.StatusCode == http.StatusNoContent { slog.Debug("check update response 204 (current version is up to date)") return false, updateResp } @@ -87,7 +87,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) { slog.Warn(fmt.Sprintf("failed to read body response: %s", err)) } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body))) return false, updateResp } @@ -114,7 +114,7 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error { if err != nil { return fmt.Errorf("error checking update: %w", err) } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode) } resp.Body.Close() diff --git a/app/store/store.go b/app/store/store.go index 13a75a60..b743e8a8 100644 --- a/app/store/store.go +++ b/app/store/store.go @@ -29,7 +29,6 @@ func GetID() string { initStore() } return store.ID - } func GetFirstTimeRun() bool { diff --git a/app/tray/wintray/eventloop.go b/app/tray/wintray/eventloop.go index a0af9787..0f944894 100644 --- a/app/tray/wintray/eventloop.go +++ b/app/tray/wintray/eventloop.go @@ -47,7 +47,6 @@ func nativeLoop() { default: pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck - } } } @@ -160,8 +159,8 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui lResult, _, _ = pDefWindowProc.Call( uintptr(hWnd), uintptr(message), - uintptr(wParam), - uintptr(lParam), + wParam, + lParam, ) } return diff --git a/app/tray/wintray/tray.go b/app/tray/wintray/tray.go index 69d4487d..027ec5a5 100644 --- a/app/tray/wintray/tray.go +++ b/app/tray/wintray/tray.go @@ -186,7 +186,7 @@ func (t *winTray) initInstance() error { t.muNID.Lock() defer t.muNID.Unlock() t.nid = ¬ifyIconData{ - Wnd: windows.Handle(t.window), + Wnd: t.window, ID: 100, Flags: NIF_MESSAGE, CallbackMessage: t.wmSystrayMessage, @@ -197,7 +197,6 @@ func (t *winTray) initInstance() error { } func (t *winTray) createMenu() error { - menuHandle, _, err := pCreatePopupMenu.Call() if menuHandle == 0 { return err @@ -246,7 +245,7 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title mi := menuItemInfo{ Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE, Type: MFT_STRING, - ID: uint32(menuItemId), + ID: menuItemId, TypeData: titlePtr, Cch: uint32(len(title)), } @@ -302,11 +301,10 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title } func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error { - mi := menuItemInfo{ Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE, Type: MFT_SEPARATOR, - ID: uint32(menuItemId), + ID: menuItemId, } mi.Size = uint32(unsafe.Sizeof(mi)) @@ -426,7 +424,6 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) { // Loads an image from file and shows it in tray. // Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx func (t *winTray) setIcon(src string) error { - h, err := t.loadIconFrom(src) if err != nil { return err @@ -444,7 +441,6 @@ func (t *winTray) setIcon(src string) error { // Loads an image from file to be shown in tray or menu item. // LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx func (t *winTray) loadIconFrom(src string) (windows.Handle, error) { - // Save and reuse handles of loaded images t.muLoadedImages.RLock() h, ok := t.loadedImages[src] diff --git a/cmd/cmd.go b/cmd/cmd.go index b285f83c..b5747543 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -20,6 +20,7 @@ import ( "path/filepath" "regexp" "runtime" + "slices" "strings" "syscall" "time" @@ -29,7 +30,6 @@ import ( "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" - "golang.org/x/exp/slices" "golang.org/x/term" "github.com/ollama/ollama/api" @@ -746,7 +746,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState) if wordWrap && termWidth >= 10 { for _, ch := range content { if state.lineLength+1 > termWidth-5 { - if runewidth.StringWidth(state.wordBuffer) > termWidth-10 { fmt.Printf("%s%c", state.wordBuffer, ch) state.wordBuffer = "" @@ -1030,24 +1029,6 @@ func initializeKeypair() error { return nil } -//nolint:unused -func waitForServer(ctx context.Context, client *api.Client) error { - // wait for the server to start - timeout := time.After(5 * time.Second) - tick := time.Tick(500 * time.Millisecond) - for { - select { - case <-timeout: - return errors.New("timed out waiting for server to start") - case <-tick: - if err := client.Heartbeat(ctx); err == nil { - return nil // server has started - } - } - } - -} - func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { client, err := api.ClientFromEnvironment() if err != nil { diff --git a/cmd/interactive.go b/cmd/interactive.go index c055df0e..80a91547 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -8,11 +8,11 @@ import ( "os" "path/filepath" "regexp" + "slices" "sort" "strings" "github.com/spf13/cobra" - "golang.org/x/exp/slices" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" diff --git a/cmd/interactive_test.go b/cmd/interactive_test.go index 8eedf729..d9af01eb 100644 --- a/cmd/interactive_test.go +++ b/cmd/interactive_test.go @@ -6,6 +6,7 @@ import ( "text/template" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ollama/ollama/api" ) @@ -85,11 +86,11 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark.""" ` tmpl, err := template.New("").Parse(expectedModelfile) - assert.Nil(t, err) + require.NoError(t, err) var buf bytes.Buffer err = tmpl.Execute(&buf, opts) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, buf.String(), mf) opts.ParentModel = "horseshark" @@ -107,10 +108,10 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark.""" ` tmpl, err = template.New("").Parse(expectedModelfile) - assert.Nil(t, err) + require.NoError(t, err) var parentBuf bytes.Buffer err = tmpl.Execute(&parentBuf, opts) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, parentBuf.String(), mf) } diff --git a/cmd/start.go b/cmd/start.go new file mode 100644 index 00000000..0c4eed08 --- /dev/null +++ b/cmd/start.go @@ -0,0 +1,27 @@ +//go:build darwin || windows + +package cmd + +import ( + "context" + "errors" + "time" + + "github.com/ollama/ollama/api" +) + +func waitForServer(ctx context.Context, client *api.Client) error { + // wait for the server to start + timeout := time.After(5 * time.Second) + tick := time.Tick(500 * time.Millisecond) + for { + select { + case <-timeout: + return errors.New("timed out waiting for server to start") + case <-tick: + if err := client.Heartbeat(ctx); err == nil { + return nil // server has started + } + } + } +} diff --git a/convert/convert.go b/convert/convert.go index e71a0ff3..103de457 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -189,7 +189,7 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) { if params.VocabSize > len(v.Tokens) { missingTokens := params.VocabSize - len(v.Tokens) slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens)) - for cnt := 0; cnt < missingTokens; cnt++ { + for cnt := range missingTokens { v.Tokens = append(v.Tokens, fmt.Sprintf("", cnt+1)) v.Scores = append(v.Scores, -1) v.Types = append(v.Types, tokenTypeUserDefined) diff --git a/convert/gemma.go b/convert/gemma.go index 9dc406e0..d01ffedf 100644 --- a/convert/gemma.go +++ b/convert/gemma.go @@ -35,7 +35,6 @@ func addOnes(data []float32, vectorSize int) ([]float32, error) { f32s = append(f32s, t...) } - return f32s, nil } diff --git a/convert/llama.go b/convert/llama.go index 7853c4cf..b4211b02 100644 --- a/convert/llama.go +++ b/convert/llama.go @@ -119,11 +119,12 @@ func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([ } var heads int - if strings.HasSuffix(name, "attn_q.weight") { + switch { + case strings.HasSuffix(name, "attn_q.weight"): heads = params.AttentionHeads - } else if strings.HasSuffix(name, "attn_k.weight") { + case strings.HasSuffix(name, "attn_k.weight"): heads = cmp.Or(params.KeyValHeads, params.AttentionHeads) - } else { + default: return nil, fmt.Errorf("unknown tensor name: %s", name) } diff --git a/convert/safetensors.go b/convert/safetensors.go index 69270b87..f45687f1 100644 --- a/convert/safetensors.go +++ b/convert/safetensors.go @@ -120,7 +120,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) Name: name, Kind: kind, Offset: offset, - Shape: shape[:], + Shape: shape, } t.WriterTo = safetensorWriterTo{ diff --git a/convert/tokenizer.go b/convert/tokenizer.go index efeb5491..fd6df5f5 100644 --- a/convert/tokenizer.go +++ b/convert/tokenizer.go @@ -85,11 +85,8 @@ func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, e sha256sum := sha256.New() for _, pt := range t.PreTokenizer.PreTokenizers { - switch pt.Type { - case "Split": - if pt.Pattern.Regex != "" { - sha256sum.Write([]byte(pt.Pattern.Regex)) - } + if pt.Type == "Split" && pt.Pattern.Regex != "" { + sha256sum.Write([]byte(pt.Pattern.Regex)) } } diff --git a/convert/torch.go b/convert/torch.go index b7ae0f76..55414adc 100644 --- a/convert/torch.go +++ b/convert/torch.go @@ -88,7 +88,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, Name: ggufName, Kind: kind, Offset: offset, // calculate the offset - Shape: shape[:], + Shape: shape, } tensor.WriterTo = torchWriterTo{ @@ -104,7 +104,6 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, } return tensors, nil - } func getAltParams(dirpath string) (*Params, error) { diff --git a/docs/api.md b/docs/api.md index 0f11c388..64bfbed8 100644 --- a/docs/api.md +++ b/docs/api.md @@ -12,6 +12,7 @@ - [Pull a Model](#pull-a-model) - [Push a Model](#push-a-model) - [Generate Embeddings](#generate-embeddings) +- [List Running Models](#list-running-models) ## Conventions @@ -1035,3 +1036,48 @@ curl http://localhost:11434/api/embeddings -d '{ ] } ``` + +## List Running Models +```shell +GET /api/ps +``` + +List models that are currently loaded into memory. + +\* If a model is loaded completely into system memory, `size_vram` is omitted from the response. + +#### Examples + +### Request +```shell +curl http://localhost:11434/api/ps +``` + +#### Response + +A single JSON object will be returned. + +```json +{ + "models": [ + { + "name": "mistral:latest", + "model": "mistral:latest", + "size": 5137025024, + "digest": "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8", + "details": { + "parent_model": "", + "format": "gguf", + "family": "llama", + "families": [ + "llama" + ], + "parameter_size": "7.2B", + "quantization_level": "Q4_0" + }, + "expires_at": "2024-06-04T14:38:31.83753-07:00", + "size_vram": 5137025024 + } + ] +} +``` \ No newline at end of file diff --git a/docs/tutorials/langchainpy.md b/docs/tutorials/langchainpy.md index 9a1bca0d..06543a07 100644 --- a/docs/tutorials/langchainpy.md +++ b/docs/tutorials/langchainpy.md @@ -45,7 +45,7 @@ all_splits = text_splitter.split_documents(data) ``` It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb` - +We also need to pull embedding model: `ollama pull nomic-embed-text` ```python from langchain.embeddings import OllamaEmbeddings from langchain.vectorstores import Chroma @@ -68,7 +68,8 @@ The next thing is to send the question and the relevant parts of the docs to the ```python from langchain.chains import RetrievalQA qachain=RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever()) -qachain.invoke({"query": question}) +res = qachain.invoke({"query": question}) +print(res['result']) ``` The answer received from this chain was: diff --git a/envconfig/config.go b/envconfig/config.go index d6699451..ae4e9939 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -3,6 +3,7 @@ package envconfig import ( "fmt" "log/slog" + "net" "os" "path/filepath" "runtime" @@ -126,7 +127,7 @@ func LoadConfig() { var paths []string for _, root := range []string{filepath.Dir(appExe), cwd} { paths = append(paths, - filepath.Join(root), + root, filepath.Join(root, "windows-"+runtime.GOARCH), filepath.Join(root, "dist", "windows-"+runtime.GOARCH), ) @@ -184,11 +185,17 @@ func LoadConfig() { AllowOrigins = append(AllowOrigins, fmt.Sprintf("http://%s", allowOrigin), fmt.Sprintf("https://%s", allowOrigin), - fmt.Sprintf("http://%s:*", allowOrigin), - fmt.Sprintf("https://%s:*", allowOrigin), + fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")), + fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")), ) } + AllowOrigins = append(AllowOrigins, + "app://*", + "file://*", + "tauri://*", + ) + maxRunners := clean("OLLAMA_MAX_LOADED_MODELS") if maxRunners != "" { m, err := strconv.Atoi(maxRunners) diff --git a/format/format_test.go b/format/format_test.go index 1d73c80b..bff32780 100644 --- a/format/format_test.go +++ b/format/format_test.go @@ -5,7 +5,6 @@ import ( ) func TestHumanNumber(t *testing.T) { - type testCase struct { input uint64 expected string diff --git a/go.mod b/go.mod index 2f3d4ca3..6807b9b4 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( ) require ( + github.com/agnivade/levenshtein v1.1.1 github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/mattn/go-runewidth v0.0.14 github.com/nlpodyssey/gopickle v0.3.0 diff --git a/go.sum b/go.sum index 9e1baebe..926ed26d 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,14 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7 gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= +github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6ICHXqG5hm0ZW5IHyeEJXoIJSOZeBLmWPNeIQ= github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= +github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= +github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= @@ -36,6 +40,8 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g= +github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= diff --git a/gpu/amd_windows.go b/gpu/amd_windows.go index 66979c9b..290f8677 100644 --- a/gpu/amd_windows.go +++ b/gpu/amd_windows.go @@ -65,7 +65,7 @@ func AMDGetGPUInfo() []GpuInfo { slog.Debug("detected hip devices", "count", count) // TODO how to determine the underlying device ID when visible devices is causing this to subset? - for i := 0; i < count; i++ { + for i := range count { err = hl.HipSetDevice(i) if err != nil { slog.Warn("set device", "id", i, "error", err) diff --git a/gpu/assets.go b/gpu/assets.go index e3fbe47c..f2adcf3e 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -80,7 +80,7 @@ func cleanupTmpDirs() { if err == nil { pid, err := strconv.Atoi(string(raw)) if err == nil { - if proc, err := os.FindProcess(int(pid)); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { + if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { // Another running ollama, ignore this tmpdir continue } diff --git a/gpu/cuda_common.go b/gpu/cuda_common.go index 03c1a25b..c90a644c 100644 --- a/gpu/cuda_common.go +++ b/gpu/cuda_common.go @@ -18,5 +18,4 @@ func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { ids = append(ids, info.ID) } return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",") - } diff --git a/gpu/gpu.go b/gpu/gpu.go index 03e16702..73ef1358 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -187,7 +187,7 @@ func GetGPUInfo() GpuInfoList { resp := []GpuInfo{} // NVIDIA first - for i := 0; i < gpuHandles.deviceCount; i++ { + for i := range gpuHandles.deviceCount { // TODO once we support CPU compilation variants of GPU libraries refine this... if cpuVariant == "" && runtime.GOARCH == "amd64" { continue @@ -221,8 +221,8 @@ func GetGPUInfo() GpuInfoList { gpuInfo.MinimumMemory = cudaMinimumMemory gpuInfo.DependencyPath = depPath gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.DriverMajor = int(driverMajor) - gpuInfo.DriverMinor = int(driverMinor) + gpuInfo.DriverMajor = driverMajor + gpuInfo.DriverMinor = driverMinor // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... resp = append(resp, gpuInfo) diff --git a/gpu/gpu_test.go b/gpu/gpu_test.go index a28cbe8c..46d3201e 100644 --- a/gpu/gpu_test.go +++ b/gpu/gpu_test.go @@ -5,11 +5,12 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBasicGetGPUInfo(t *testing.T) { info := GetGPUInfo() - assert.Greater(t, len(info), 0) + assert.NotEmpty(t, len(info)) assert.Contains(t, "cuda rocm cpu metal", info[0].Library) if info[0].Library != "cpu" { assert.Greater(t, info[0].TotalMemory, uint64(0)) @@ -19,7 +20,7 @@ func TestBasicGetGPUInfo(t *testing.T) { func TestCPUMemInfo(t *testing.T) { info, err := GetCPUMem() - assert.NoError(t, err) + require.NoError(t, err) switch runtime.GOOS { case "darwin": t.Skip("CPU memory not populated on darwin") diff --git a/llm/ggml.go b/llm/ggml.go index 878800f3..645447d5 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -81,6 +81,11 @@ func (kv KV) ContextLength() uint64 { return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture())) } +func (kv KV) ChatTemplate() string { + s, _ := kv["tokenizer.chat_template"].(string) + return s +} + type Tensors []*Tensor func (ts Tensors) Layers() map[string]Layer { diff --git a/llm/gguf.go b/llm/gguf.go index 0ba48f76..ca7e340d 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -592,8 +592,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error { return err } - dims := 0 - for cnt := 0; cnt < len(tensor.Shape); cnt++ { + var dims int + for cnt := range len(tensor.Shape) { if tensor.Shape[cnt] > 0 { dims++ } @@ -603,8 +603,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error { return err } - for i := 0; i < dims; i++ { - if err := binary.Write(ws, llm.ByteOrder, uint64(tensor.Shape[dims-1-i])); err != nil { + for i := range dims { + if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil { return err } } diff --git a/llm/memory.go b/llm/memory.go index ff64baf1..1c2e476b 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -5,9 +5,9 @@ import ( "log/slog" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" - "github.com/ollama/ollama/envconfig" ) // This algorithm looks for a complete fit to determine if we need to unload other models @@ -103,7 +103,7 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts } var layerCount int - for i := 0; i < int(ggml.KV().BlockCount()); i++ { + for i := range int(ggml.KV().BlockCount()) { if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok { memoryLayer := blk.size() diff --git a/llm/patches/06-qwen2.diff b/llm/patches/06-qwen2.diff new file mode 100644 index 00000000..d7b0c155 --- /dev/null +++ b/llm/patches/06-qwen2.diff @@ -0,0 +1,13 @@ +diff --git a/llama.cpp b/llama.cpp +index 40d2ec2c..f34eb79a 100644 +--- a/llama.cpp ++++ b/llama.cpp +@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv( + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + +- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { ++ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); diff --git a/llm/payload.go b/llm/payload.go index abe3d263..a025ee34 100644 --- a/llm/payload.go +++ b/llm/payload.go @@ -10,9 +10,9 @@ import ( "os" "path/filepath" "runtime" + "slices" "strings" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "github.com/ollama/ollama/gpu" diff --git a/llm/server.go b/llm/server.go index 3af8a329..6cb01fa0 100644 --- a/llm/server.go +++ b/llm/server.go @@ -85,7 +85,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr var systemMemory uint64 gpuCount := len(gpus) if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 { - // TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner cpuRunner = serverForCpu() @@ -104,21 +103,22 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr var layers int layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts) - if gpus[0].Library == "metal" && estimatedVRAM > systemMemory { + switch { + case gpus[0].Library == "metal" && estimatedVRAM > systemMemory: // disable partial offloading when model is greater than total system memory as this // can lead to locking up the system opts.NumGPU = 0 - } else if gpus[0].Library != "metal" && layers == 0 { + case gpus[0].Library != "metal" && layers == 0: // Don't bother loading into the GPU if no layers can fit cpuRunner = serverForCpu() gpuCount = 0 - } else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" { + case opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu": opts.NumGPU = layers } } // Loop through potential servers - finalErr := fmt.Errorf("no suitable llama servers found") + finalErr := errors.New("no suitable llama servers found") if len(adapters) > 1 { return nil, errors.New("ollama supports only one lora adapter, but multiple were provided") @@ -232,7 +232,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr params = append(params, "--parallel", fmt.Sprintf("%d", numParallel)) - for i := 0; i < len(servers); i++ { + for i := range len(servers) { dir := availableServers[servers[i]] if dir == "" { // Shouldn't happen @@ -284,7 +284,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr server := filepath.Join(dir, "ollama_llama_server") if runtime.GOOS == "windows" { - server = server + ".exe" + server += ".exe" } // Detect tmp cleaners wiping out the file @@ -315,7 +315,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr s.cmd.Stdout = os.Stdout s.cmd.Stderr = s.status - visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv() + visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv() pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) // Update or add the path and visible devices variable with our adjusted version @@ -459,7 +459,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { resp, err := http.DefaultClient.Do(req) if err != nil { if errors.Is(err, context.DeadlineExceeded) { - return ServerStatusNotResponding, fmt.Errorf("server not responding") + return ServerStatusNotResponding, errors.New("server not responding") } return ServerStatusError, fmt.Errorf("health resp: %w", err) } diff --git a/openai/openai.go b/openai/openai.go index 7ce29e9f..310051a5 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -245,7 +245,6 @@ func (w *writer) writeResponse(data []byte) (int, error) { d, err := json.Marshal(toChunk(w.id, chatResponse)) if err != nil { return 0, err - } w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") diff --git a/parser/parser_test.go b/parser/parser_test.go index 21223cb1..55660590 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -10,6 +10,7 @@ import ( "unicode/utf16" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParseFileFile(t *testing.T) { @@ -25,7 +26,7 @@ TEMPLATE template1 reader := strings.NewReader(input) modelfile, err := ParseFile(reader) - assert.NoError(t, err) + require.NoError(t, err) expectedCommands := []Command{ {Name: "model", Args: "model1"}, @@ -88,7 +89,7 @@ func TestParseFileFrom(t *testing.T) { for _, c := range cases { t.Run("", func(t *testing.T) { modelfile, err := ParseFile(strings.NewReader(c.input)) - assert.ErrorIs(t, err, c.err) + require.ErrorIs(t, err, c.err) if modelfile != nil { assert.Equal(t, c.expected, modelfile.Commands) } @@ -105,7 +106,7 @@ PARAMETER param1 reader := strings.NewReader(input) _, err := ParseFile(reader) - assert.ErrorIs(t, err, io.ErrUnexpectedEOF) + require.ErrorIs(t, err, io.ErrUnexpectedEOF) } func TestParseFileBadCommand(t *testing.T) { @@ -114,8 +115,7 @@ FROM foo BADCOMMAND param1 value1 ` _, err := ParseFile(strings.NewReader(input)) - assert.ErrorIs(t, err, errInvalidCommand) - + require.ErrorIs(t, err, errInvalidCommand) } func TestParseFileMessages(t *testing.T) { @@ -201,7 +201,7 @@ MESSAGE system`, for _, c := range cases { t.Run("", func(t *testing.T) { modelfile, err := ParseFile(strings.NewReader(c.input)) - assert.ErrorIs(t, err, c.err) + require.ErrorIs(t, err, c.err) if modelfile != nil { assert.Equal(t, c.expected, modelfile.Commands) } @@ -355,7 +355,7 @@ TEMPLATE """ for _, c := range cases { t.Run("", func(t *testing.T) { modelfile, err := ParseFile(strings.NewReader(c.multiline)) - assert.ErrorIs(t, err, c.err) + require.ErrorIs(t, err, c.err) if modelfile != nil { assert.Equal(t, c.expected, modelfile.Commands) } @@ -413,7 +413,7 @@ func TestParseFileParameters(t *testing.T) { fmt.Fprintln(&b, "FROM foo") fmt.Fprintln(&b, "PARAMETER", k) modelfile, err := ParseFile(&b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, []Command{ {Name: "model", Args: "foo"}, @@ -442,7 +442,7 @@ FROM foo for _, c := range cases { t.Run("", func(t *testing.T) { modelfile, err := ParseFile(strings.NewReader(c.input)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, c.expected, modelfile.Commands) }) } @@ -501,15 +501,14 @@ SYSTEM "" for _, c := range cases { t.Run("", func(t *testing.T) { modelfile, err := ParseFile(strings.NewReader(c)) - assert.NoError(t, err) + require.NoError(t, err) modelfile2, err := ParseFile(strings.NewReader(modelfile.String())) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, modelfile, modelfile2) }) } - } func TestParseFileUTF16ParseFile(t *testing.T) { @@ -522,10 +521,10 @@ SYSTEM You are a utf16 file. utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...)) buf := new(bytes.Buffer) err := binary.Write(buf, binary.LittleEndian, utf16File) - assert.NoError(t, err) + require.NoError(t, err) actual, err := ParseFile(buf) - assert.NoError(t, err) + require.NoError(t, err) expected := []Command{ {Name: "model", Args: "bob"}, @@ -539,9 +538,9 @@ SYSTEM You are a utf16 file. // simulate a utf16 be file buf = new(bytes.Buffer) err = binary.Write(buf, binary.BigEndian, utf16File) - assert.NoError(t, err) + require.NoError(t, err) actual, err = ParseFile(buf) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, expected, actual.Commands) } diff --git a/progress/progress.go b/progress/progress.go index 556ba00f..102830a8 100644 --- a/progress/progress.go +++ b/progress/progress.go @@ -59,7 +59,7 @@ func (p *Progress) StopAndClear() bool { stopped := p.stop() if stopped { // clear all progress lines - for i := 0; i < p.pos; i++ { + for i := range p.pos { if i > 0 { fmt.Fprint(p.w, "\033[A") } @@ -85,7 +85,7 @@ func (p *Progress) render() { defer fmt.Fprint(p.w, "\033[?25h") // clear already rendered progress lines - for i := 0; i < p.pos; i++ { + for i := range p.pos { if i > 0 { fmt.Fprint(p.w, "\033[A") } diff --git a/readline/buffer.go b/readline/buffer.go index 2c3bfec9..b7cf9b13 100644 --- a/readline/buffer.go +++ b/readline/buffer.go @@ -52,7 +52,6 @@ func (b *Buffer) GetLineSpacing(line int) bool { } return hasSpace.(bool) - } func (b *Buffer) MoveLeft() { @@ -117,15 +116,12 @@ func (b *Buffer) MoveRight() { if b.DisplayPos%b.LineWidth == 0 { fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt()))) - } else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace { fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength)) b.DisplayPos += 1 - } else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace { fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt()))) b.DisplayPos += 1 - } else { fmt.Print(cursorRightN(rLength)) } @@ -154,7 +150,7 @@ func (b *Buffer) MoveToStart() { if b.Pos > 0 { currLine := b.DisplayPos / b.LineWidth if currLine > 0 { - for cnt := 0; cnt < currLine; cnt++ { + for range currLine { fmt.Print(CursorUp) } } @@ -169,7 +165,7 @@ func (b *Buffer) MoveToEnd() { currLine := b.DisplayPos / b.LineWidth totalLines := b.DisplaySize() / b.LineWidth if currLine < totalLines { - for cnt := 0; cnt < totalLines-currLine; cnt++ { + for range totalLines - currLine { fmt.Print(CursorDown) } remainder := b.DisplaySize() % b.LineWidth @@ -185,7 +181,7 @@ func (b *Buffer) MoveToEnd() { func (b *Buffer) DisplaySize() int { sum := 0 - for i := 0; i < b.Buf.Size(); i++ { + for i := range b.Buf.Size() { if e, ok := b.Buf.Get(i); ok { if r, ok := e.(rune); ok { sum += runewidth.RuneWidth(r) @@ -197,7 +193,6 @@ func (b *Buffer) DisplaySize() int { } func (b *Buffer) Add(r rune) { - if b.Pos == b.Buf.Size() { b.AddChar(r, false) } else { @@ -210,7 +205,6 @@ func (b *Buffer) AddChar(r rune, insert bool) { b.DisplayPos += rLength if b.Pos > 0 { - if b.DisplayPos%b.LineWidth == 0 { fmt.Printf("%c", r) fmt.Printf("\n%s", b.Prompt.AltPrompt) @@ -235,7 +229,6 @@ func (b *Buffer) AddChar(r rune, insert bool) { } else { b.LineHasSpace.Add(true) } - } else { fmt.Printf("%c", r) } @@ -356,7 +349,6 @@ func (b *Buffer) drawRemaining() { func (b *Buffer) Remove() { if b.Buf.Size() > 0 && b.Pos > 0 { - if e, ok := b.Buf.Get(b.Pos - 1); ok { if r, ok := e.(rune); ok { rLength := runewidth.RuneWidth(r) @@ -382,7 +374,6 @@ func (b *Buffer) Remove() { } else { fmt.Print(" " + CursorLeft) } - } else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace { fmt.Printf(CursorBOL + ClearToEOL) fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width)) @@ -391,10 +382,9 @@ func (b *Buffer) Remove() { b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1) } b.DisplayPos -= 1 - } else { fmt.Print(cursorLeftN(rLength)) - for i := 0; i < rLength; i++ { + for range rLength { fmt.Print(" ") } fmt.Print(cursorLeftN(rLength)) @@ -451,7 +441,7 @@ func (b *Buffer) DeleteBefore() { func (b *Buffer) DeleteRemaining() { if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() { charsToDel := b.Buf.Size() - b.Pos - for cnt := 0; cnt < charsToDel; cnt++ { + for range charsToDel { b.Delete() } } @@ -495,7 +485,7 @@ func (b *Buffer) ClearScreen() { if currPos > 0 { targetLine := currPos / b.LineWidth if targetLine > 0 { - for cnt := 0; cnt < targetLine; cnt++ { + for range targetLine { fmt.Print(CursorDown) } } @@ -525,7 +515,7 @@ func (b *Buffer) Replace(r []rune) { fmt.Printf(CursorBOL + ClearToEOL) - for i := 0; i < lineNums; i++ { + for range lineNums { fmt.Print(CursorUp + CursorBOL + ClearToEOL) } diff --git a/readline/history.go b/readline/history.go index 670a0722..9c6b930b 100644 --- a/readline/history.go +++ b/readline/history.go @@ -91,7 +91,7 @@ func (h *History) Add(l []rune) { func (h *History) Compact() { s := h.Buf.Size() if s > h.Limit { - for cnt := 0; cnt < s-h.Limit; cnt++ { + for range s - h.Limit { h.Buf.Remove(0) } } @@ -139,7 +139,7 @@ func (h *History) Save() error { defer f.Close() buf := bufio.NewWriter(f) - for cnt := 0; cnt < h.Size(); cnt++ { + for cnt := range h.Size() { v, _ := h.Buf.Get(cnt) line, _ := v.([]rune) if _, err := buf.WriteString(string(line) + "\n"); err != nil { diff --git a/readline/readline.go b/readline/readline.go index ee461ae4..e90a5e01 100644 --- a/readline/readline.go +++ b/readline/readline.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "os" - "syscall" ) type Prompt struct { @@ -63,7 +62,7 @@ func New(prompt Prompt) (*Instance, error) { func (i *Instance) Readline() (string, error) { if !i.Terminal.rawmode { - fd := int(syscall.Stdin) + fd := os.Stdin.Fd() termios, err := SetRawMode(fd) if err != nil { return "", err @@ -80,8 +79,8 @@ func (i *Instance) Readline() (string, error) { fmt.Print(prompt) defer func() { - fd := int(syscall.Stdin) - // nolint: errcheck + fd := os.Stdin.Fd() + //nolint:errcheck UnsetRawMode(fd, i.Terminal.termios) i.Terminal.rawmode = false }() @@ -136,7 +135,7 @@ func (i *Instance) Readline() (string, error) { buf.MoveRight() case CharBracketedPaste: var code string - for cnt := 0; cnt < 3; cnt++ { + for range 3 { r, err = i.Terminal.Read() if err != nil { return "", io.EOF @@ -198,7 +197,7 @@ func (i *Instance) Readline() (string, error) { buf.Remove() case CharTab: // todo: convert back to real tabs - for cnt := 0; cnt < 8; cnt++ { + for range 8 { buf.Add(' ') } case CharDelete: @@ -216,7 +215,7 @@ func (i *Instance) Readline() (string, error) { case CharCtrlW: buf.DeleteWord() case CharCtrlZ: - fd := int(syscall.Stdin) + fd := os.Stdin.Fd() return handleCharCtrlZ(fd, i.Terminal.termios) case CharEnter, CharCtrlJ: output := buf.String() @@ -248,7 +247,7 @@ func (i *Instance) HistoryDisable() { } func NewTerminal() (*Terminal, error) { - fd := int(syscall.Stdin) + fd := os.Stdin.Fd() termios, err := SetRawMode(fd) if err != nil { return nil, err diff --git a/readline/readline_unix.go b/readline/readline_unix.go index 76cff8c8..d48b9176 100644 --- a/readline/readline_unix.go +++ b/readline/readline_unix.go @@ -6,7 +6,7 @@ import ( "syscall" ) -func handleCharCtrlZ(fd int, termios any) (string, error) { +func handleCharCtrlZ(fd uintptr, termios any) (string, error) { t := termios.(*Termios) if err := UnsetRawMode(fd, t); err != nil { return "", err diff --git a/readline/readline_windows.go b/readline/readline_windows.go index b4e96b25..a131d0ef 100644 --- a/readline/readline_windows.go +++ b/readline/readline_windows.go @@ -1,6 +1,6 @@ package readline -func handleCharCtrlZ(fd int, state any) (string, error) { +func handleCharCtrlZ(fd uintptr, state any) (string, error) { // not supported return "", nil } diff --git a/readline/term.go b/readline/term.go index 9d747162..5584cd25 100644 --- a/readline/term.go +++ b/readline/term.go @@ -8,7 +8,7 @@ import ( type Termios syscall.Termios -func SetRawMode(fd int) (*Termios, error) { +func SetRawMode(fd uintptr) (*Termios, error) { termios, err := getTermios(fd) if err != nil { return nil, err @@ -25,13 +25,13 @@ func SetRawMode(fd int) (*Termios, error) { return termios, setTermios(fd, &newTermios) } -func UnsetRawMode(fd int, termios any) error { +func UnsetRawMode(fd uintptr, termios any) error { t := termios.(*Termios) return setTermios(fd, t) } // IsTerminal returns true if the given file descriptor is a terminal. -func IsTerminal(fd int) bool { +func IsTerminal(fd uintptr) bool { _, err := getTermios(fd) return err == nil } diff --git a/readline/term_bsd.go b/readline/term_bsd.go index 04b912f0..80bee6b3 100644 --- a/readline/term_bsd.go +++ b/readline/term_bsd.go @@ -7,17 +7,17 @@ import ( "unsafe" ) -func getTermios(fd int) (*Termios, error) { +func getTermios(fd uintptr) (*Termios, error) { termios := new(Termios) - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) if err != 0 { return nil, err } return termios, nil } -func setTermios(fd int, termios *Termios) error { - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) +func setTermios(fd uintptr, termios *Termios) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) if err != 0 { return err } diff --git a/readline/term_linux.go b/readline/term_linux.go index 2d6211dd..e9ed0745 100644 --- a/readline/term_linux.go +++ b/readline/term_linux.go @@ -10,17 +10,17 @@ import ( const tcgets = 0x5401 const tcsets = 0x5402 -func getTermios(fd int) (*Termios, error) { +func getTermios(fd uintptr) (*Termios, error) { termios := new(Termios) - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) if err != 0 { return nil, err } return termios, nil } -func setTermios(fd int, termios *Termios) error { - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) +func setTermios(fd uintptr, termios *Termios) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) if err != 0 { return err } diff --git a/readline/term_windows.go b/readline/term_windows.go index cfdfd672..3b35149b 100644 --- a/readline/term_windows.go +++ b/readline/term_windows.go @@ -9,13 +9,13 @@ type State struct { } // IsTerminal checks if the given file descriptor is associated with a terminal -func IsTerminal(fd int) bool { +func IsTerminal(fd uintptr) bool { var st uint32 err := windows.GetConsoleMode(windows.Handle(fd), &st) return err == nil } -func SetRawMode(fd int) (*State, error) { +func SetRawMode(fd uintptr) (*State, error) { var st uint32 if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil { return nil, err @@ -32,7 +32,7 @@ func SetRawMode(fd int) (*State, error) { return &State{st}, nil } -func UnsetRawMode(fd int, state any) error { +func UnsetRawMode(fd uintptr, state any) error { s := state.(*State) return windows.SetConsoleMode(windows.Handle(fd), s.mode) } diff --git a/server/download.go b/server/download.go index 5a735abf..d93cd3b4 100644 --- a/server/download.go +++ b/server/download.go @@ -340,17 +340,17 @@ type downloadOpts struct { } // downloadBlob downloads a blob from the registry and stores it in the blobs directory -func downloadBlob(ctx context.Context, opts downloadOpts) error { +func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) { fp, err := GetBlobsPath(opts.digest) if err != nil { - return err + return false, err } fi, err := os.Stat(fp) switch { case errors.Is(err, os.ErrNotExist): case err != nil: - return err + return false, err default: opts.fn(api.ProgressResponse{ Status: fmt.Sprintf("pulling %s", opts.digest[7:19]), @@ -359,7 +359,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error { Completed: fi.Size(), }) - return nil + return true, nil } data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest}) @@ -369,12 +369,12 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error { requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest) if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil { blobDownloadManager.Delete(opts.digest) - return err + return false, err } - // nolint: contextcheck + //nolint:contextcheck go download.Run(context.Background(), requestURL, opts.regOpts) } - return download.Wait(ctx, opts.fn) + return false, download.Wait(ctx, opts.fn) } diff --git a/server/images.go b/server/images.go index 9254671c..32207f20 100644 --- a/server/images.go +++ b/server/images.go @@ -18,17 +18,17 @@ import ( "os" "path/filepath" "runtime" + "slices" "strconv" "strings" - "golang.org/x/exp/slices" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/auth" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/parser" - "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/templates" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -315,7 +315,7 @@ func realpath(rel, from string) string { return abspath } -func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) { +func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) { config := ConfigV2{ OS: "linux", Architecture: "amd64", @@ -435,24 +435,47 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount())) config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String()) config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture()) + + if s := baseLayer.GGML.KV().ChatTemplate(); s != "" { + t, err := templates.NamedTemplate(s) + if err != nil { + return err + } + + layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template") + if err != nil { + return err + } + + layer.status = fmt.Sprintf("using autodetected template %s", t.Name) + layers = append(layers, layer) + } } layers = append(layers, baseLayer.Layer) } case "license", "template", "system": + if c.Name != "license" { + // replace + layers = slices.DeleteFunc(layers, func(layer *Layer) bool { + if layer.MediaType != mediatype { + return false + } + + if err := layer.Remove(); err != nil { + return false + } + + return true + }) + } + blob := strings.NewReader(c.Args) layer, err := NewLayer(blob, mediatype) if err != nil { return err } - if c.Name != "license" { - // replace - layers = slices.DeleteFunc(layers, func(layer *Layer) bool { - return layer.MediaType == mediatype - }) - } - layers = append(layers, layer) case "message": role, content, ok := strings.Cut(c.Args, ": ") @@ -571,26 +594,15 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m } } - unref := make(map[string]struct{}) - if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil { - for _, layer := range manifest.Layers { - if !slices.Contains(digests, layer.Digest) { - unref[layer.Digest] = struct{}{} - } - } - - if manifest.Config.Digest != layer.Digest { - unref[manifest.Config.Digest] = struct{}{} - } - } + old, _ := ParseNamedManifest(name) fn(api.ProgressResponse{Status: "writing manifest"}) if err := WriteManifest(name, layer, layers); err != nil { return err } - if !envconfig.NoPrune { - if err := deleteUnusedLayers(nil, unref); err != nil { + if !envconfig.NoPrune && old != nil { + if err := old.RemoveLayers(); err != nil { return err } } @@ -662,7 +674,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}) // save (i.e. delete from the deleteMap) any files used in other manifests manifest, _, err := GetManifest(fmp) if err != nil { - // nolint: nilerr + //nolint:nilerr return nil } @@ -857,23 +869,27 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu layers = append(layers, manifest.Layers...) layers = append(layers, manifest.Config) + skipVerify := make(map[string]bool) for _, layer := range layers { - if err := downloadBlob( - ctx, - downloadOpts{ - mp: mp, - digest: layer.Digest, - regOpts: regOpts, - fn: fn, - }); err != nil { + cacheHit, err := downloadBlob(ctx, downloadOpts{ + mp: mp, + digest: layer.Digest, + regOpts: regOpts, + fn: fn, + }) + if err != nil { return err } + skipVerify[layer.Digest] = cacheHit delete(deleteMap, layer.Digest) } delete(deleteMap, manifest.Config.Digest) fn(api.ProgressResponse{Status: "verifying sha256 digest"}) for _, layer := range layers { + if skipVerify[layer.Digest] { + continue + } if err := verifyBlob(layer.Digest); err != nil { if errors.Is(err, errDigestMismatch) { // something went wrong, delete the blob @@ -988,7 +1004,7 @@ func getTokenSubject(token string) string { func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) { anonymous := true // access will default to anonymous if no user is found associated with the public key - for i := 0; i < 2; i++ { + for range 2 { resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) if err != nil { if !errors.Is(err, context.Canceled) { diff --git a/server/manifest.go b/server/manifest.go index a5251298..d0675724 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -1,7 +1,6 @@ package server import ( - "bytes" "crypto/sha256" "encoding/json" "fmt" @@ -34,12 +33,6 @@ func (m *Manifest) Remove() error { return err } - for _, layer := range append(m.Layers, m.Config) { - if err := layer.Remove(); err != nil { - return err - } - } - manifests, err := GetManifestPath() if err != nil { return err @@ -48,6 +41,16 @@ func (m *Manifest) Remove() error { return PruneDirectory(manifests) } +func (m *Manifest) RemoveLayers() error { + for _, layer := range append(m.Layers, m.Config) { + if err := layer.Remove(); err != nil { + return err + } + } + + return nil +} + func ParseNamedManifest(n model.Name) (*Manifest, error) { if !n.IsFullyQualified() { return nil, model.Unqualified(n) @@ -85,30 +88,31 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { }, nil } -func WriteManifest(name string, config *Layer, layers []*Layer) error { - manifest := ManifestV2{ +func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { + manifests, err := GetManifestPath() + if err != nil { + return err + } + + p := filepath.Join(manifests, name.Filepath()) + if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil { + return err + } + + f, err := os.Create(p) + if err != nil { + return err + } + defer f.Close() + + m := ManifestV2{ SchemaVersion: 2, MediaType: "application/vnd.docker.distribution.manifest.v2+json", Config: config, Layers: layers, } - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(manifest); err != nil { - return err - } - - modelpath := ParseModelPath(name) - manifestPath, err := modelpath.GetManifestPath() - if err != nil { - return err - } - - if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil { - return err - } - - return os.WriteFile(manifestPath, b.Bytes(), 0o644) + return json.NewEncoder(f).Encode(m) } func Manifests() (map[model.Name]*Manifest, error) { diff --git a/server/model.go b/server/model.go index fcf406f6..ee2ae080 100644 --- a/server/model.go +++ b/server/model.go @@ -25,16 +25,14 @@ type layerWithGGML struct { } func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { - modelpath := ParseModelPath(name.String()) - manifest, _, err := GetManifest(modelpath) + m, err := ParseNamedManifest(name) switch { case errors.Is(err, os.ErrNotExist): if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil { return nil, err } - modelpath = ParseModelPath(name.String()) - manifest, _, err = GetManifest(modelpath) + m, err = ParseNamedManifest(name) if err != nil { return nil, err } @@ -42,8 +40,8 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return nil, err } - for _, layer := range manifest.Layers { - layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) + for _, layer := range m.Layers { + layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest()) if err != nil { return nil, err } @@ -72,7 +70,6 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe default: layers = append(layers, &layerWithGGML{layer, nil}) } - } return layers, nil diff --git a/server/modelpath_test.go b/server/modelpath_test.go index 30741d87..849e0fa7 100644 --- a/server/modelpath_test.go +++ b/server/modelpath_test.go @@ -6,12 +6,13 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetBlobsPath(t *testing.T) { // GetBlobsPath expects an actual directory to exist dir, err := os.MkdirTemp("", "ollama-test") - assert.Nil(t, err) + require.NoError(t, err) defer os.RemoveAll(dir) tests := []struct { @@ -63,7 +64,7 @@ func TestGetBlobsPath(t *testing.T) { got, err := GetBlobsPath(tc.digest) - assert.ErrorIs(t, tc.err, err, tc.name) + require.ErrorIs(t, tc.err, err, tc.name) assert.Equal(t, tc.expected, got, tc.name) }) } diff --git a/server/routes.go b/server/routes.go index 7a6dfd1f..188fe974 100644 --- a/server/routes.go +++ b/server/routes.go @@ -16,6 +16,7 @@ import ( "os" "os/signal" "path/filepath" + "slices" "strconv" "strings" "syscall" @@ -23,7 +24,6 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" - "golang.org/x/exp/slices" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" @@ -77,7 +77,6 @@ func isSupportedImageType(image []byte) bool { } func (s *Server) GenerateHandler(c *gin.Context) { - checkpointStart := time.Now() var req api.GenerateRequest err := c.ShouldBindJSON(&req) @@ -524,8 +523,8 @@ func checkNameExists(name model.Name) error { } func (s *Server) CreateModelHandler(c *gin.Context) { - var req api.CreateRequest - if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { + var r api.CreateRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return } else if err != nil { @@ -533,7 +532,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) { return } - name := model.ParseName(cmp.Or(req.Model, req.Name)) + name := model.ParseName(cmp.Or(r.Model, r.Name)) if !name.IsValid() { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) return @@ -544,24 +543,24 @@ func (s *Server) CreateModelHandler(c *gin.Context) { return } - if req.Path == "" && req.Modelfile == "" { + if r.Path == "" && r.Modelfile == "" { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) return } - var r io.Reader = strings.NewReader(req.Modelfile) - if req.Path != "" && req.Modelfile == "" { - f, err := os.Open(req.Path) + var sr io.Reader = strings.NewReader(r.Modelfile) + if r.Path != "" && r.Modelfile == "" { + f, err := os.Open(r.Path) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) return } defer f.Close() - r = f + sr = f } - modelfile, err := parser.ParseFile(r) + f, err := parser.ParseFile(sr) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -577,17 +576,13 @@ func (s *Server) CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - quantization := req.Quantization - if req.Quantize != "" { - quantization = req.Quantize - } - - if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil { + quantization := cmp.Or(r.Quantize, r.Quantization) + if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() - if req.Stream != nil && !*req.Stream { + if r.Stream != nil && !*r.Stream { waitForStream(c, ch) return } @@ -621,6 +616,11 @@ func (s *Server) DeleteModelHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + + if err := m.RemoveLayers(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } } func (s *Server) ShowModelHandler(c *gin.Context) { @@ -730,7 +730,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) { return } - models := []api.ModelResponse{} + models := []api.ListModelResponse{} for n, m := range ms { f, err := m.Config.Open() if err != nil { @@ -746,7 +746,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) { } // tag should never be masked - models = append(models, api.ModelResponse{ + models = append(models, api.ListModelResponse{ Model: n.DisplayShortest(), Name: n.DisplayShortest(), Size: m.Size(), @@ -762,7 +762,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) { }) } - slices.SortStableFunc(models, func(i, j api.ModelResponse) int { + slices.SortStableFunc(models, func(i, j api.ListModelResponse) int { // most recently modified first return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix()) }) @@ -942,7 +942,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc { } if allowedHost(host) { - if c.Request.Method == "OPTIONS" { + if c.Request.Method == http.MethodOptions { c.AbortWithStatus(http.StatusNoContent) return } @@ -960,6 +960,10 @@ func (s *Server) GenerateRoutes() http.Handler { config.AllowWildcard = true config.AllowBrowserExtensions = true config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"} + openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"} + for _, prop := range openAIProperties { + config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop) + } config.AllowOrigins = envconfig.AllowOrigins r := gin.Default() @@ -1139,7 +1143,7 @@ func streamResponse(c *gin.Context, ch chan any) { } func (s *Server) ProcessHandler(c *gin.Context) { - models := []api.ModelResponse{} + models := []api.ProcessModelResponse{} for _, v := range s.sched.loaded { model := v.model @@ -1151,7 +1155,7 @@ func (s *Server) ProcessHandler(c *gin.Context) { QuantizationLevel: model.Config.FileType, } - mr := api.ModelResponse{ + mr := api.ProcessModelResponse{ Model: model.ShortName, Name: model.ShortName, Size: int64(v.estimatedTotal), @@ -1171,7 +1175,7 @@ func (s *Server) ProcessHandler(c *gin.Context) { models = append(models, mr) } - c.JSON(http.StatusOK, api.ListResponse{Models: models}) + c.JSON(http.StatusOK, api.ProcessResponse{Models: models}) } // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model @@ -1306,7 +1310,6 @@ func (s *Server) ChatHandler(c *gin.Context) { defer close(ch) fn := func(r llm.CompletionResponse) { - resp := api.ChatResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), diff --git a/server/routes_create_test.go b/server/routes_create_test.go index e5af1ded..19bf19ed 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -158,3 +158,371 @@ func TestCreateFromModel(t *testing.T) { filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"), }) } + +func TestCreateRemovesLayers(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-b507b9c2f6ca642bffcd06665ea7c91f235fd32daeefdf875a0f938db05fb315"), + filepath.Join(p, "blobs", "sha256-bc80b03733773e0728011b2f4adf34c458b400e1aad48cb28d61170f3a2ad2d6"), + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"), + }) +} + +func TestCreateUnsetsSystem(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-8585df945d1069bc78b79bd10bb73ba07fbc29b0f5479a31a601c0d12731416e"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-f29e82a8284dbdf5910b1555580ff60b04238b8da9d5e51159ada67a4d0d5851"), + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-67d4b8d106af2a5b100a46e9bdc038c71eef2a35c9abac784092654212f97cf5"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), + }) + + bts, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")) + if err != nil { + t.Fatal(err) + } + + if string(bts) != "" { + t.Fatalf("expected empty string, actual %s", string(bts)) + } +} + +func TestCreateMergeParameters(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"), + filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + }) + + // in order to merge parameters, the second model must be created FROM the first + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test2", + Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"), + filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"), + filepath.Join(p, "blobs", "sha256-4cd9d4ba6b734d9b4cbd1e5caa60374c00722e993fce5e1e2d15a33698f71187"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"), + }) + + actual, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba")) + if err != nil { + t.Fatal(err) + } + + expect, err := json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"USER:", "ASSISTANT:"}}) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) { + t.Errorf("expected %s, actual %s", string(expect), string(actual)) + } + + // slices are replaced + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test2", + Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7\nPARAMETER stop <|endoftext|>", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"), + filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"), + filepath.Join(p, "blobs", "sha256-257aa726584f24970a4f240765e75a7169bfbe7f4966c1f04513d6b6c860583a"), + filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + }) + + actual, err = os.ReadFile(filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35")) + if err != nil { + t.Fatal(err) + } + + expect, err = json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"<|endoftext|>"}}) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) { + t.Errorf("expected %s, actual %s", string(expect), string(actual)) + } +} + +func TestCreateReplacesMessages(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"), + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test2", + Modelfile: "FROM test\nMESSAGE assistant \"You're a test, Harry.\"\nMESSAGE user \"I-I'm a what?\"\nMESSAGE assistant \"A test. And a thumping good one at that, I'd wager.\"", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"), + filepath.Join(p, "blobs", "sha256-4f48b25fe9969564c82f58eb1cedbdff6484cc0baf474bc6c2a9b37c8da3362a"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"), + filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"), + }) + + type message struct { + Role string `json:"role"` + Content string `json:"content"` + } + + f, err := os.Open(filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db")) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + var actual []message + if err := json.NewDecoder(f).Decode(&actual); err != nil { + t.Fatal(err) + } + + expect := []message{ + {Role: "assistant", Content: "You're a test, Harry."}, + {Role: "user", Content: "I-I'm a what?"}, + {Role: "assistant", Content: "A test. And a thumping good one at that, I'd wager."}, + } + + if !slices.Equal(actual, expect) { + t.Errorf("expected %s, actual %s", expect, actual) + } +} + +func TestCreateTemplateSystem(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-2b5e330885117c82f3fd75169ea323e141070a2947c11ddb9f79ee0b01c589c1"), + filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"), + }) + + template, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed")) + if err != nil { + t.Fatal(err) + } + + if string(template) != "{{ .System }} {{ .Prompt }}" { + t.Errorf("expected \"{{ .System }} {{ .Prompt }}\", actual %s", template) + } + + system, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86")) + if err != nil { + t.Fatal(err) + } + + if string(system) != "Say bye!" { + t.Errorf("expected \"Say bye!\", actual %s", system) + } +} + +func TestCreateLicenses(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"), + filepath.Join(p, "blobs", "sha256-79a39c37536ddee29cbadd5d5e2dcba8ed7f03e431f626ff38432c1c866bb7e2"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"), + }) + + mit, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7")) + if err != nil { + t.Fatal(err) + } + + if string(mit) != "MIT" { + t.Errorf("expected MIT, actual %s", mit) + } + + apache, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608")) + if err != nil { + t.Fatal(err) + } + + if string(apache) != "Apache-2.0" { + t.Errorf("expected Apache-2.0, actual %s", apache) + } +} diff --git a/server/routes_test.go b/server/routes_test.go index 74933b1f..91ef625b 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -15,9 +15,11 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ollama/ollama/api" "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -25,20 +27,20 @@ func createTestFile(t *testing.T, name string) string { t.Helper() f, err := os.CreateTemp(t.TempDir(), name) - assert.Nil(t, err) + require.NoError(t, err) defer f.Close() err = binary.Write(f, binary.LittleEndian, []byte("GGUF")) - assert.Nil(t, err) + require.NoError(t, err) err = binary.Write(f, binary.LittleEndian, uint32(3)) - assert.Nil(t, err) + require.NoError(t, err) err = binary.Write(f, binary.LittleEndian, uint64(0)) - assert.Nil(t, err) + require.NoError(t, err) err = binary.Write(f, binary.LittleEndian, uint64(0)) - assert.Nil(t, err) + require.NoError(t, err) return f.Name() } @@ -53,16 +55,18 @@ func Test_Routes(t *testing.T) { } createTestModel := func(t *testing.T, name string) { + t.Helper() + fname := createTestFile(t, "ollama-model") r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) modelfile, err := parser.ParseFile(r) - assert.Nil(t, err) + require.NoError(t, err) fn := func(resp api.ProgressResponse) { t.Logf("Status: %s", resp.Status) } - err = CreateModel(context.TODO(), name, "", "", modelfile, fn) - assert.Nil(t, err) + err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn) + require.NoError(t, err) } testCases := []testCase{ @@ -74,9 +78,9 @@ func Test_Routes(t *testing.T) { }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, contentType, "application/json; charset=utf-8") + assert.Equal(t, "application/json; charset=utf-8", contentType) body, err := io.ReadAll(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body)) }, }, @@ -86,17 +90,17 @@ func Test_Routes(t *testing.T) { Path: "/api/tags", Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, contentType, "application/json; charset=utf-8") + assert.Equal(t, "application/json; charset=utf-8", contentType) body, err := io.ReadAll(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) var modelList api.ListResponse err = json.Unmarshal(body, &modelList) - assert.Nil(t, err) + require.NoError(t, err) assert.NotNil(t, modelList.Models) - assert.Equal(t, 0, len(modelList.Models)) + assert.Empty(t, len(modelList.Models)) }, }, { @@ -108,16 +112,18 @@ func Test_Routes(t *testing.T) { }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, contentType, "application/json; charset=utf-8") + assert.Equal(t, "application/json; charset=utf-8", contentType) body, err := io.ReadAll(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + + assert.NotContains(t, string(body), "expires_at") var modelList api.ListResponse err = json.Unmarshal(body, &modelList) - assert.Nil(t, err) + require.NoError(t, err) - assert.Equal(t, 1, len(modelList.Models)) - assert.Equal(t, modelList.Models[0].Name, "test-model:latest") + assert.Len(t, modelList.Models, 1) + assert.Equal(t, "test-model:latest", modelList.Models[0].Name) }, }, { @@ -134,7 +140,7 @@ func Test_Routes(t *testing.T) { Stream: &stream, } jsonData, err := json.Marshal(createReq) - assert.Nil(t, err) + require.NoError(t, err) req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, @@ -142,11 +148,11 @@ func Test_Routes(t *testing.T) { contentType := resp.Header.Get("Content-Type") assert.Equal(t, "application/json", contentType) _, err := io.ReadAll(resp.Body) - assert.Nil(t, err) - assert.Equal(t, resp.StatusCode, 200) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) model, err := GetModel("t-bone") - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "t-bone:latest", model.ShortName) }, }, @@ -161,13 +167,13 @@ func Test_Routes(t *testing.T) { Destination: "beefsteak", } jsonData, err := json.Marshal(copyReq) - assert.Nil(t, err) + require.NoError(t, err) req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, Expected: func(t *testing.T, resp *http.Response) { model, err := GetModel("beefsteak") - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "beefsteak:latest", model.ShortName) }, }, @@ -179,18 +185,18 @@ func Test_Routes(t *testing.T) { createTestModel(t, "show-model") showReq := api.ShowRequest{Model: "show-model"} jsonData, err := json.Marshal(showReq) - assert.Nil(t, err) + require.NoError(t, err) req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, contentType, "application/json; charset=utf-8") + assert.Equal(t, "application/json; charset=utf-8", contentType) body, err := io.ReadAll(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) var showResp api.ShowResponse err = json.Unmarshal(body, &showResp) - assert.Nil(t, err) + require.NoError(t, err) var params []string paramsSplit := strings.Split(showResp.Parameters, "\n") @@ -221,14 +227,14 @@ func Test_Routes(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { u := httpSrv.URL + tc.Path req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) - assert.Nil(t, err) + require.NoError(t, err) if tc.Setup != nil { tc.Setup(t, req) } resp, err := httpSrv.Client().Do(req) - assert.Nil(t, err) + require.NoError(t, err) defer resp.Body.Close() if tc.Expected != nil { diff --git a/server/sched.go b/server/sched.go index 8c72177f..c36486f7 100644 --- a/server/sched.go +++ b/server/sched.go @@ -7,17 +7,17 @@ import ( "log/slog" "reflect" "runtime" + "slices" "sort" "strings" "sync" "time" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" - "github.com/ollama/ollama/envconfig" - "golang.org/x/exp/slices" ) type LlmRequest struct { @@ -66,7 +66,7 @@ func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, opts.NumCtx = 4 } - opts.NumCtx = opts.NumCtx * envconfig.NumParallel + opts.NumCtx *= envconfig.NumParallel req := &LlmRequest{ ctx: c, @@ -370,7 +370,6 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) { r.refMu.Lock() gpuIDs := make([]string, 0, len(r.gpus)) if r.llama != nil { - // TODO this should be broken down by GPU instead of assuming uniform spread estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus)) for _, gpu := range r.gpus { @@ -529,7 +528,6 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} { } }() return finished - } type ByDuration []*runnerRef diff --git a/server/sched_test.go b/server/sched_test.go index 3ee1b989..f7dce6d1 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -12,11 +12,10 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/app/lifecycle" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" - "github.com/ollama/ollama/envconfig" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -53,10 +52,10 @@ func TestLoad(t *testing.T) { } gpus := gpu.GpuInfoList{} s.load(req, ggml, gpus) - require.Len(t, req.successCh, 0) + require.Empty(t, req.successCh) require.Len(t, req.errCh, 1) s.loadedMu.Lock() - require.Len(t, s.loaded, 0) + require.Empty(t, s.loaded) s.loadedMu.Unlock() err := <-req.errCh require.Contains(t, err.Error(), "this model may be incompatible") @@ -113,7 +112,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV t.Helper() f, err := os.CreateTemp(t.TempDir(), modelName) - assert.Nil(t, err) + require.NoError(t, err) defer f.Close() gguf := llm.NewGGUFV3(binary.LittleEndian) @@ -131,7 +130,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV }, []llm.Tensor{ {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, }) - assert.Nil(t, err) + require.NoError(t, err) fname := f.Name() model := &Model{Name: modelName, ModelPath: fname} @@ -190,8 +189,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario1a.req.successCh: require.Equal(t, resp.llama, scenario1a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario1a.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario1a.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -203,8 +202,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario1b.req.successCh: require.Equal(t, resp.llama, scenario1a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario1b.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario1b.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -221,8 +220,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario2a.req.successCh: require.Equal(t, resp.llama, scenario2a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario2a.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario2a.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -237,8 +236,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario3a.req.successCh: require.Equal(t, resp.llama, scenario3a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario3a.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario3a.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -253,8 +252,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario3b.req.successCh: require.Equal(t, resp.llama, scenario3b.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario3b.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario3b.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -269,8 +268,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario3c.req.successCh: require.Equal(t, resp.llama, scenario3c.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario3c.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario3c.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -296,8 +295,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario3d.req.successCh: require.Equal(t, resp.llama, scenario3d.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario3d.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario3d.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -332,7 +331,7 @@ func TestGetRunner(t *testing.T) { slog.Info("scenario1b") successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration) require.Len(t, s.pendingReqCh, 1) - require.Len(t, successCh1b, 0) + require.Empty(t, successCh1b) require.Len(t, errCh1b, 1) err := <-errCh1b require.Contains(t, err.Error(), "server busy") @@ -340,8 +339,8 @@ func TestGetRunner(t *testing.T) { select { case resp := <-successCh1a: require.Equal(t, resp.llama, scenario1a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, errCh1a, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, errCh1a) case <-ctx.Done(): t.Errorf("timeout") } @@ -355,9 +354,9 @@ func TestGetRunner(t *testing.T) { successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration) // Starts in pending channel, then should be quickly processsed to return an error time.Sleep(5 * time.Millisecond) - require.Len(t, successCh1c, 0) + require.Empty(t, successCh1c) s.loadedMu.Lock() - require.Len(t, s.loaded, 0) + require.Empty(t, s.loaded) s.loadedMu.Unlock() require.Len(t, errCh1c, 1) err = <-errCh1c @@ -386,8 +385,8 @@ func TestPrematureExpired(t *testing.T) { select { case resp := <-successCh1a: require.Equal(t, resp.llama, scenario1a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, errCh1a, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, errCh1a) s.loadedMu.Lock() require.Len(t, s.loaded, 1) s.loadedMu.Unlock() @@ -401,9 +400,9 @@ func TestPrematureExpired(t *testing.T) { time.Sleep(20 * time.Millisecond) require.LessOrEqual(t, len(s.finishedReqCh), 1) time.Sleep(10 * time.Millisecond) - require.Len(t, s.finishedReqCh, 0) + require.Empty(t, s.finishedReqCh) s.loadedMu.Lock() - require.Len(t, s.loaded, 0) + require.Empty(t, s.loaded) s.loadedMu.Unlock() // also shouldn't happen in real life @@ -487,7 +486,6 @@ func TestFindRunnerToUnload(t *testing.T) { r2.refCount = 1 resp = s.findRunnerToUnload() require.Equal(t, r1, resp) - } func TestNeedsReload(t *testing.T) { diff --git a/server/upload.go b/server/upload.go index 9b52238a..73ce78ce 100644 --- a/server/upload.go +++ b/server/upload.go @@ -146,7 +146,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) { case requestURL := <-b.nextURL: g.Go(func() error { var err error - for try := 0; try < maxRetries; try++ { + for try := range maxRetries { err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts) switch { case errors.Is(err, context.Canceled): @@ -190,7 +190,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) { headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Length", "0") - for try := 0; try < maxRetries; try++ { + for try := range maxRetries { var resp *http.Response resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts) if errors.Is(err, context.Canceled) { @@ -253,7 +253,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL * } // retry uploading to the redirect URL - for try := 0; try < maxRetries; try++ { + for try := range maxRetries { err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil) switch { case errors.Is(err, context.Canceled): @@ -391,7 +391,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryO return err } - // nolint: contextcheck + //nolint:contextcheck go upload.Run(context.Background(), opts) } diff --git a/templates/alfred.gotmpl b/templates/alfred.gotmpl new file mode 100644 index 00000000..cecb9d2c --- /dev/null +++ b/templates/alfred.gotmpl @@ -0,0 +1 @@ +{{ if .System }}{{ .System }}{{ end }}{{ if .Prompt }}{{ .Prompt }}{{ end }}{{ .Response }} \ No newline at end of file diff --git a/templates/alpaca.gotmpl b/templates/alpaca.gotmpl new file mode 100644 index 00000000..440d0662 --- /dev/null +++ b/templates/alpaca.gotmpl @@ -0,0 +1,7 @@ +{{ if .System }}{{ .System }} + +{{ end }}{{ if .Prompt }}### Instruction: +{{ .Prompt }} + +{{ end }}### Response: +{{ .Response }} \ No newline at end of file diff --git a/templates/chatml.gotmpl b/templates/chatml.gotmpl new file mode 100644 index 00000000..dcf17285 --- /dev/null +++ b/templates/chatml.gotmpl @@ -0,0 +1,6 @@ +{{ if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ .Response }}<|im_end|> \ No newline at end of file diff --git a/templates/chatqa.gotmpl b/templates/chatqa.gotmpl new file mode 100644 index 00000000..1ede6227 --- /dev/null +++ b/templates/chatqa.gotmpl @@ -0,0 +1,5 @@ +{{ if .System }}System: {{ .System }} + +{{ end }}{{ if .Prompt }}User: {{ .Prompt }} + +{{ end }}Assistant: <|begin_of_text|>{{ .Response }} \ No newline at end of file diff --git a/templates/codellama-70b-instruct.gotmpl b/templates/codellama-70b-instruct.gotmpl new file mode 100644 index 00000000..3196bd6f --- /dev/null +++ b/templates/codellama-70b-instruct.gotmpl @@ -0,0 +1,8 @@ +{{ if .System }} Source: system + + {{ .System }} {{ end }} Source: user + + {{ .Prompt }} Source: assistant +Destination: user + + {{ .Response }} \ No newline at end of file diff --git a/templates/falcon-instruct.gotmpl b/templates/falcon-instruct.gotmpl new file mode 100644 index 00000000..2309a1c5 --- /dev/null +++ b/templates/falcon-instruct.gotmpl @@ -0,0 +1,3 @@ +{{ if .System }}{{ .System }} +{{ end }}{{ if .Prompt }}User: {{ .Prompt }} +{{ end }}Assistant: {{ .Response }} \ No newline at end of file diff --git a/templates/gemma-instruct.gotmpl b/templates/gemma-instruct.gotmpl new file mode 100644 index 00000000..91b9883a --- /dev/null +++ b/templates/gemma-instruct.gotmpl @@ -0,0 +1,4 @@ +user +{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} +model +{{ .Response }} \ No newline at end of file diff --git a/templates/granite-instruct.gotmpl b/templates/granite-instruct.gotmpl new file mode 100644 index 00000000..2ede647f --- /dev/null +++ b/templates/granite-instruct.gotmpl @@ -0,0 +1,9 @@ +{{ if .System }} +System: +{{ .System }} + +{{ end }}{{ if .Prompt }}Question: +{{ .Prompt }} + +{{ end }}Answer: +{{ .Response }} \ No newline at end of file diff --git a/templates/index.json b/templates/index.json new file mode 100644 index 00000000..e2d41893 --- /dev/null +++ b/templates/index.json @@ -0,0 +1,138 @@ +[ + { + "template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}", + "name": "chatml" + }, + { + "template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + "name": "zephyr" + }, + { + "template": "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}", + "name": "chatml" + }, + { + "template": "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", + "name": "openchat" + }, + { + "template": "{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + "name": "zephyr" + }, + { + "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + "name": "mistral-instruct" + }, + { + "template": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'### Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response\n'}}", + "name": "starcoder2-instruct" + }, + { + "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}", + "name": "llama2-chat" + }, + { + "template": "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}", + "name": "codellama-70b-instruct" + }, + { + "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + "name": "mistral-instruct" + }, + { + "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'system' %}\n{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|im_start|>assistant' }}\n{% endif %}\n{% endfor %}", + "name": "chatml" + }, + { + "template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif 'system' not in messages[0]['role'] %}{% set loop_messages = messages %}{% set system_message = 'You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks \u2014 remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER\\'S QUERY.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% if system_message != false %}{{ '<|im_start|>system\n' + system_message | trim + '<|im_end|>\n'}}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% else %}{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% endif %}{% if (add_generation_prompt == true and loop.last) %}{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}{% endif %}{% endfor %}", + "name": "chatml" + }, + { + "template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", + "name": "alpaca" + }, + { + "template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", + "name": "chatqa" + }, + { + "template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", + "name": "gemma-instruct" + }, + { + "template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "name": "llama3-instruct" + }, + { + "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'Question:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'system' %}\n{{ 'System:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Answer:\n' + message['content'] + '\n\n' }}{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Answer:\n' }}{% endif %}{% endfor %}", + "name": "granite-instruct" + }, + { + "template": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'@@ Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'@@ Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'@@ Response\n'}}", + "name": "magicoder" + }, + { + "template": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '' + message['content'].strip() + '' }}{% elif message['role'] == 'system' %}{{ '' + message['content'].strip() + '' }}{% elif message['role'] == 'assistant' %}{{ '' + message['content'] + '' }}{% else %}{{ raise_exception('Only system, user and assistant roles are supported.') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '' }}{% endif %}{% endfor %}", + "name": "alfred" + }, + { + "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + "name": "llama2-chat" + }, + { + "template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "name": "phi-3" + }, + { + "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "name": "phi-3" + }, + { + "template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", + "name": "phi-3" + }, + { + "template": "{{ bos_token }}{%- if messages[0]['role'] == 'system' -%}{% set loop_messages = messages[1:] %}{%- else -%}{% set loop_messages = messages %}{% endif %}System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context.\n\n{% for message in loop_messages %}{%- if message['role'] == 'user' -%}User: {{ message['content'].strip() + '\n\n' }}{%- else -%}Assistant: {{ message['content'].strip() + '\n\n' }}{%- endif %}{% if loop.last and message['role'] == 'user' %}Assistant:{% endif %}{% endfor %}", + "name": "chatqa" + }, + { + "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'User: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'System: ' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'Falcon:\n' + message['content']}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Falcon:' }}\n{% endif %}\n{% endfor %}", + "name": "falcon-instruct" + }, + { + "template": "{% for message in messages %}{% if not loop.first %}{{ '\n' }}{% endif %}{% if message['role'] == 'system' %}{{ 'System: ' }}{% elif message['role'] == 'user' %}{{ 'User: ' }}{% elif message['role'] == 'assistant' %}{{ 'Falcon: ' }}{% endif %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '\n' + 'Falcon:' }}{% endif %}", + "name": "falcon-instruct" + }, + { + "template": "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}", + "name": "solar-instruct" + } +] diff --git a/templates/llama2-chat.gotmpl b/templates/llama2-chat.gotmpl new file mode 100644 index 00000000..a739f690 --- /dev/null +++ b/templates/llama2-chat.gotmpl @@ -0,0 +1,3 @@ +[INST] <>{{ .System }}<> + +{{ .Prompt }} [/INST] {{ .Response }} \ No newline at end of file diff --git a/templates/llama3-instruct.gotmpl b/templates/llama3-instruct.gotmpl new file mode 100644 index 00000000..36d0218b --- /dev/null +++ b/templates/llama3-instruct.gotmpl @@ -0,0 +1,7 @@ +{{ if .System }}<|start_header_id|>system<|end_header_id|> + +{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|> + +{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|> + +{{ .Response }}<|eot_id|> \ No newline at end of file diff --git a/templates/magicoder.gotmpl b/templates/magicoder.gotmpl new file mode 100644 index 00000000..306972ec --- /dev/null +++ b/templates/magicoder.gotmpl @@ -0,0 +1,7 @@ +{{ if .System }}{{ .System }} + +{{ end }}{{ if .Prompt }}@@ Instruction +{{ .Prompt }} + +{{ end }}@@ Response +{{ .Response }} \ No newline at end of file diff --git a/templates/mistral-instruct.gotmpl b/templates/mistral-instruct.gotmpl new file mode 100644 index 00000000..dcf17285 --- /dev/null +++ b/templates/mistral-instruct.gotmpl @@ -0,0 +1,6 @@ +{{ if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ .Response }}<|im_end|> \ No newline at end of file diff --git a/templates/openchat.gotmpl b/templates/openchat.gotmpl new file mode 100644 index 00000000..d2ca3868 --- /dev/null +++ b/templates/openchat.gotmpl @@ -0,0 +1 @@ +{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|> \ No newline at end of file diff --git a/templates/phi-3.gotmpl b/templates/phi-3.gotmpl new file mode 100644 index 00000000..bf26dcee --- /dev/null +++ b/templates/phi-3.gotmpl @@ -0,0 +1,6 @@ +{{ if .System }}<|system|> +{{ .System }}<|end|> +{{ end }}{{ if .Prompt }}<|user|> +{{ .Prompt }}<|end|> +{{ end }}<|assistant|> +{{ .Response }}<|end|> \ No newline at end of file diff --git a/templates/solar-instruct.gotmpl b/templates/solar-instruct.gotmpl new file mode 100644 index 00000000..c275a26a --- /dev/null +++ b/templates/solar-instruct.gotmpl @@ -0,0 +1,8 @@ +{{ if .System }}### System: +{{ .System }} + +{{ end }}{{ if .Prompt }}### User: +{{ .Prompt }} + +{{ end }}### Assistant: +{{ .Response }} \ No newline at end of file diff --git a/templates/starcoder2-instruct.gotmpl b/templates/starcoder2-instruct.gotmpl new file mode 100644 index 00000000..33357e54 --- /dev/null +++ b/templates/starcoder2-instruct.gotmpl @@ -0,0 +1,9 @@ +{{ if .System }}{{ .System }} + +{{ end }}{{ if .Prompt }}### Instruction +{{ .Prompt }} + + +{{ end }}### Response +{{ .Response }}<|endoftext|> + diff --git a/templates/template.go b/templates/template.go new file mode 100644 index 00000000..87962695 --- /dev/null +++ b/templates/template.go @@ -0,0 +1,69 @@ +package templates + +import ( + "bytes" + "embed" + "encoding/json" + "errors" + "io" + "math" + "sync" + + "github.com/agnivade/levenshtein" +) + +//go:embed index.json +var indexBytes []byte + +//go:embed *.gotmpl +var templatesFS embed.FS + +var templatesOnce = sync.OnceValues(func() ([]*Template, error) { + var templates []*Template + if err := json.Unmarshal(indexBytes, &templates); err != nil { + return nil, err + } + + for _, t := range templates { + bts, err := templatesFS.ReadFile(t.Name + ".gotmpl") + if err != nil { + return nil, err + } + + t.Bytes = bts + } + + return templates, nil +}) + +type Template struct { + Name string `json:"name"` + Template string `json:"template"` + Bytes []byte +} + +func (t Template) Reader() io.Reader { + return bytes.NewReader(t.Bytes) +} + +func NamedTemplate(s string) (*Template, error) { + templates, err := templatesOnce() + if err != nil { + return nil, err + } + + var template *Template + score := math.MaxInt + for _, t := range templates { + if s := levenshtein.ComputeDistance(s, t.Template); s < score { + score = s + template = t + } + } + + if score < 100 { + return template, nil + } + + return nil, errors.New("no matching template found") +} diff --git a/templates/template_test.go b/templates/template_test.go new file mode 100644 index 00000000..61bc7837 --- /dev/null +++ b/templates/template_test.go @@ -0,0 +1,59 @@ +package templates + +import ( + "bufio" + "bytes" + "encoding/json" + "io" + "os" + "path/filepath" + "testing" + "text/template" + + "github.com/ollama/ollama/llm" +) + +func TestKVChatTemplate(t *testing.T) { + f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + var ss map[string]string + if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil { + t.Fatal(err) + } + + for k, v := range ss { + t.Run(k, func(t *testing.T) { + kv := llm.KV{"tokenizer.chat_template": v} + s := kv.ChatTemplate() + r, err := NamedTemplate(s) + if err != nil { + t.Fatal(err) + } + + if r.Name != k { + t.Errorf("expected %q, got %q", k, r.Name) + } + + var b bytes.Buffer + if _, err := io.Copy(&b, r.Reader()); err != nil { + t.Fatal(err) + } + + tmpl, err := template.New(s).Parse(b.String()) + if err != nil { + t.Fatal(err) + } + + if tmpl.Tree.Root.String() == "" { + t.Errorf("empty %s template", k) + } + }) + } + } +} diff --git a/templates/testdata/templates.jsonl b/templates/testdata/templates.jsonl new file mode 100644 index 00000000..41a8d0ef --- /dev/null +++ b/templates/testdata/templates.jsonl @@ -0,0 +1,35 @@ +{"chatml": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"} +{"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"zephyr": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"} +{"chatml": "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}"} +{"openchat": "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}"} +{"chatml": "{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"chatml": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"chatml": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"chatml": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"zephyr": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"} +{"mistral-instruct": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"} +{"starcoder2-instruct": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'### Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response\n'}}"} +{"llama2-chat": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}"} +{"codellama-70b-instruct": "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}"} +{"mistral-instruct": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"} +{"chatml": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'system' %}\n{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|im_start|>assistant' }}\n{% endif %}\n{% endfor %}"} +{"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif 'system' not in messages[0]['role'] %}{% set loop_messages = messages %}{% set system_message = 'You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks \u2014 remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER\\'S QUERY.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% if system_message != false %}{{ '<|im_start|>system\n' + system_message | trim + '<|im_end|>\n'}}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% else %}{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% endif %}{% if (add_generation_prompt == true and loop.last) %}{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}{% endif %}{% endfor %}"} +{"alpaca": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"} +{"chatqa": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"} +{"gemma-instruct": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}"} +{"llama3-instruct": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"} +{"granite-instruct": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'Question:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'system' %}\n{{ 'System:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Answer:\n' + message['content'] + '\n\n' }}{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Answer:\n' }}{% endif %}{% endfor %}"} +{"magicoder": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'@@ Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'@@ Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'@@ Response\n'}}"} +{"alfred": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '' + message['content'].strip() + '' }}{% elif message['role'] == 'system' %}{{ '' + message['content'].strip() + '' }}{% elif message['role'] == 'assistant' %}{{ '' + message['content'] + '' }}{% else %}{{ raise_exception('Only system, user and assistant roles are supported.') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '' }}{% endif %}{% endfor %}"} +{"llama2-chat": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"} +{"phi-3": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"} +{"phi-3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"} +{"phi-3": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}"} +{"chatqa": "{{ bos_token }}{%- if messages[0]['role'] == 'system' -%}{% set loop_messages = messages[1:] %}{%- else -%}{% set loop_messages = messages %}{% endif %}System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context.\n\n{% for message in loop_messages %}{%- if message['role'] == 'user' -%}User: {{ message['content'].strip() + '\n\n' }}{%- else -%}Assistant: {{ message['content'].strip() + '\n\n' }}{%- endif %}{% if loop.last and message['role'] == 'user' %}Assistant:{% endif %}{% endfor %}"} +{"falcon-instruct": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'User: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'System: ' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'Falcon:\n' + message['content']}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Falcon:' }}\n{% endif %}\n{% endfor %}"} +{"falcon-instruct": "{% for message in messages %}{% if not loop.first %}{{ '\n' }}{% endif %}{% if message['role'] == 'system' %}{{ 'System: ' }}{% elif message['role'] == 'user' %}{{ 'User: ' }}{% elif message['role'] == 'assistant' %}{{ 'Falcon: ' }}{% endif %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '\n' + 'Falcon:' }}{% endif %}"} +{"solar-instruct": "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}"} +{"chatml": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} diff --git a/templates/vicuna.gotmpl b/templates/vicuna.gotmpl new file mode 100644 index 00000000..174c1a35 --- /dev/null +++ b/templates/vicuna.gotmpl @@ -0,0 +1,3 @@ +{{ if .System }}{{ .System }} +{{ end }}{{ if .Prompt }}USER: {{ .Prompt }} +{{ end }}ASSISTANT: {{ .Response }} \ No newline at end of file diff --git a/templates/zephyr.gotmpl b/templates/zephyr.gotmpl new file mode 100644 index 00000000..aac0c7a1 --- /dev/null +++ b/templates/zephyr.gotmpl @@ -0,0 +1,6 @@ +{{ if .System }}<|system|> +{{ .System }} +{{ end }}{{ if .Prompt }}<|user|> +{{ .Prompt }} +{{ end }}<|assistant|> +{{ .Response }} \ No newline at end of file diff --git a/types/model/name_test.go b/types/model/name_test.go index 26d70ef3..66ce4c33 100644 --- a/types/model/name_test.go +++ b/types/model/name_test.go @@ -268,7 +268,6 @@ func TestNameIsValidPart(t *testing.T) { } }) } - } func TestFilepathAllocs(t *testing.T) { @@ -325,7 +324,7 @@ func TestParseNameFromFilepath(t *testing.T) { filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"}, filepath.Join("namespace", "model", "tag"): {}, filepath.Join("model", "tag"): {}, - filepath.Join("model"): {}, + "model": {}, filepath.Join("..", "..", "model", "tag"): {}, filepath.Join("", "namespace", ".", "tag"): {}, filepath.Join(".", ".", ".", "."): {}, @@ -382,14 +381,13 @@ func FuzzName(f *testing.F) { t.Errorf("String() = %q; want %q", n.String(), s) } } - }) } func TestIsValidNamespace(t *testing.T) { cases := []struct { - username string - expected bool + username string + expected bool }{ {"", false}, {"a", true},