From 60323e08057d36b617f11d3c4958d342a44d0342 Mon Sep 17 00:00:00 2001 From: Shubham <25881429+shoebham@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:50:48 +0530 Subject: [PATCH 01/26] add embed model command and fix question invoke (#4766) * add embed model command and fix question invoke * Update docs/tutorials/langchainpy.md Co-authored-by: Kim Hallberg * Update docs/tutorials/langchainpy.md --------- Co-authored-by: Kim Hallberg Co-authored-by: Jeffrey Morgan --- docs/tutorials/langchainpy.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/langchainpy.md b/docs/tutorials/langchainpy.md index 9a1bca0d..06543a07 100644 --- a/docs/tutorials/langchainpy.md +++ b/docs/tutorials/langchainpy.md @@ -45,7 +45,7 @@ all_splits = text_splitter.split_documents(data) ``` It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb` - +We also need to pull the embedding model: `ollama pull nomic-embed-text` ```python from langchain.embeddings import OllamaEmbeddings from langchain.vectorstores import Chroma @@ -68,7 +68,8 @@ The next thing is to send the question and the relevant parts of the docs to the ```python from langchain.chains import RetrievalQA qachain=RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever()) -qachain.invoke({"query": question}) +res = qachain.invoke({"query": question}) +print(res['result']) ``` The answer received from this chain was: From 04f3c12bb716ca763da61cb25884c8859ff81240 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 21 May 2024 21:30:52 -0700 Subject: [PATCH 02/26] replace x/exp/slices with slices --- cmd/cmd.go | 2 +- cmd/interactive.go | 2 +- llm/payload.go | 2 +- server/images.go | 3 +-- server/routes.go | 2 +- server/sched.go | 4 ++-- 6 files changed, 7 insertions(+), 8 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index b285f83c..5a451889 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -20,6 +20,7 @@ import ( "path/filepath" "regexp" "runtime" + "slices" "strings" "syscall" "time" @@ -29,7 +30,6 @@ import ( "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" - "golang.org/x/exp/slices" "golang.org/x/term" "github.com/ollama/ollama/api" diff --git a/cmd/interactive.go b/cmd/interactive.go index c055df0e..80a91547 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -8,11 +8,11 @@ import ( "os" "path/filepath" "regexp" + "slices" "sort" "strings" "github.com/spf13/cobra" - "golang.org/x/exp/slices" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" diff --git a/llm/payload.go b/llm/payload.go index abe3d263..a025ee34 100644 --- a/llm/payload.go +++ b/llm/payload.go @@ -10,9 +10,9 @@ import ( "os" "path/filepath" "runtime" + "slices" "strings" - "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "github.com/ollama/ollama/gpu" diff --git a/server/images.go b/server/images.go index 9254671c..61740126 100644 --- a/server/images.go +++ b/server/images.go @@ -18,11 +18,10 @@ import ( "os" "path/filepath" "runtime" + "slices" "strconv" "strings" - "golang.org/x/exp/slices" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/auth" "github.com/ollama/ollama/format" diff --git a/server/routes.go b/server/routes.go index 7a6dfd1f..7fcd2f2f 100644 --- a/server/routes.go +++ b/server/routes.go
@@ -16,6 +16,7 @@ import ( "os" "os/signal" "path/filepath" + "slices" "strconv" "strings" "syscall" @@ -23,7 +24,6 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" - "golang.org/x/exp/slices" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" diff --git a/server/sched.go b/server/sched.go index 8c72177f..46fe2f60 100644 --- a/server/sched.go +++ b/server/sched.go @@ -7,17 +7,17 @@ import ( "log/slog" "reflect" "runtime" + "slices" "sort" "strings" "sync" "time" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" - "github.com/ollama/ollama/envconfig" - "golang.org/x/exp/slices" ) type LlmRequest struct { From 55f6eba049ec588baf895dc82afcd809b9acedd3 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 21 May 2024 21:32:43 -0700 Subject: [PATCH 03/26] gofmt --- convert/gemma.go | 1 - 1 file changed, 1 deletion(-) diff --git a/convert/gemma.go b/convert/gemma.go index 9dc406e0..d01ffedf 100644 --- a/convert/gemma.go +++ b/convert/gemma.go @@ -35,7 +35,6 @@ func addOnes(data []float32, vectorSize int) ([]float32, error) { f32s = append(f32s, t...) } - return f32s, nil } From 8ffb51749f7d5f37bb123e50f3c08b4cb50dc693 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 21 May 2024 21:52:20 -0700 Subject: [PATCH 04/26] nolintlint --- cmd/cmd.go | 1 - readline/readline.go | 2 +- server/download.go | 2 +- server/images.go | 2 +- server/upload.go | 2 +- 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 5a451889..6d395805 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1030,7 +1030,6 @@ func initializeKeypair() error { return nil } -//nolint:unused func waitForServer(ctx context.Context, client *api.Client) error { // wait for the server to start timeout := time.After(5 * time.Second) diff --git a/readline/readline.go b/readline/readline.go index ee461ae4..5215d617 100644 --- a/readline/readline.go +++ b/readline/readline.go @@ -81,7 +81,7 @@ func (i *Instance) Readline() (string, error) { defer func() { fd := int(syscall.Stdin) - // nolint: errcheck + //nolint:errcheck UnsetRawMode(fd, i.Terminal.termios) i.Terminal.rawmode = false }() diff --git a/server/download.go b/server/download.go index 5a735abf..937b6754 100644 --- a/server/download.go +++ b/server/download.go @@ -372,7 +372,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error { return err } - // nolint: contextcheck + //nolint:contextcheck go download.Run(context.Background(), requestURL, opts.regOpts) } diff --git a/server/images.go b/server/images.go index 61740126..b8497eaa 100644 --- a/server/images.go +++ b/server/images.go @@ -661,7 +661,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}) // save (i.e. 
delete from the deleteMap) any files used in other manifests manifest, _, err := GetManifest(fmp) if err != nil { - // nolint: nilerr + //nolint:nilerr return nil } diff --git a/server/upload.go b/server/upload.go index 9b52238a..aa775518 100644 --- a/server/upload.go +++ b/server/upload.go @@ -391,7 +391,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryO return err } - // nolint: contextcheck + //nolint:contextcheck go upload.Run(context.Background(), opts) } From dad7a987ae93a41751a069394dd6c53e92fec138 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 21 May 2024 21:53:44 -0700 Subject: [PATCH 05/26] nosprintfhostport --- envconfig/config.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index d6699451..875c9039 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -3,6 +3,7 @@ package envconfig import ( "fmt" "log/slog" + "net" "os" "path/filepath" "runtime" @@ -184,8 +185,8 @@ func LoadConfig() { AllowOrigins = append(AllowOrigins, fmt.Sprintf("http://%s", allowOrigin), fmt.Sprintf("https://%s", allowOrigin), - fmt.Sprintf("http://%s:*", allowOrigin), - fmt.Sprintf("https://%s:*", allowOrigin), + fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")), + fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")), ) } From c895a7d13f74c66aee4c68aed75aaeddb7fbcf18 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 21 May 2024 22:07:57 -0700 Subject: [PATCH 06/26] some gocritic --- .golangci.yaml | 2 ++ api/types.go | 2 +- convert/llama.go | 7 ++++--- convert/safetensors.go | 2 +- convert/tokenizer.go | 7 ++----- convert/torch.go | 2 +- envconfig/config.go | 2 +- llm/server.go | 13 +++++++------ server/sched.go | 2 +- types/model/name_test.go | 2 +- 10 files changed, 21 insertions(+), 20 deletions(-) diff --git a/.golangci.yaml b/.golangci.yaml index 7dec49de..df966a16 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -14,4 +14,6 @@ linters: # - goimports - misspell - nilerr + - nolintlint + - nosprintfhostport - unused diff --git a/api/types.go b/api/types.go index 4195a7c5..230f58e8 100644 --- a/api/types.go +++ b/api/types.go @@ -306,7 +306,7 @@ type GenerateResponse struct { // Model is the model name that generated the response. Model string `json:"model"` - //CreatedAt is the timestamp of the response. + // CreatedAt is the timestamp of the response. CreatedAt time.Time `json:"created_at"` // Response is the textual response itself. 
diff --git a/convert/llama.go b/convert/llama.go index 7853c4cf..b4211b02 100644 --- a/convert/llama.go +++ b/convert/llama.go @@ -119,11 +119,12 @@ func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([ } var heads int - if strings.HasSuffix(name, "attn_q.weight") { + switch { + case strings.HasSuffix(name, "attn_q.weight"): heads = params.AttentionHeads - } else if strings.HasSuffix(name, "attn_k.weight") { + case strings.HasSuffix(name, "attn_k.weight"): heads = cmp.Or(params.KeyValHeads, params.AttentionHeads) - } else { + default: return nil, fmt.Errorf("unknown tensor name: %s", name) } diff --git a/convert/safetensors.go b/convert/safetensors.go index 69270b87..f45687f1 100644 --- a/convert/safetensors.go +++ b/convert/safetensors.go @@ -120,7 +120,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) Name: name, Kind: kind, Offset: offset, - Shape: shape[:], + Shape: shape, } t.WriterTo = safetensorWriterTo{ diff --git a/convert/tokenizer.go b/convert/tokenizer.go index efeb5491..fd6df5f5 100644 --- a/convert/tokenizer.go +++ b/convert/tokenizer.go @@ -85,11 +85,8 @@ func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, e sha256sum := sha256.New() for _, pt := range t.PreTokenizer.PreTokenizers { - switch pt.Type { - case "Split": - if pt.Pattern.Regex != "" { - sha256sum.Write([]byte(pt.Pattern.Regex)) - } + if pt.Type == "Split" && pt.Pattern.Regex != "" { + sha256sum.Write([]byte(pt.Pattern.Regex)) } } diff --git a/convert/torch.go b/convert/torch.go index b7ae0f76..eef41a48 100644 --- a/convert/torch.go +++ b/convert/torch.go @@ -88,7 +88,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, Name: ggufName, Kind: kind, Offset: offset, // calculate the offset - Shape: shape[:], + Shape: shape, } tensor.WriterTo = torchWriterTo{ diff --git a/envconfig/config.go b/envconfig/config.go index 875c9039..77e3e789 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -127,7 +127,7 @@ func LoadConfig() { var paths []string for _, root := range []string{filepath.Dir(appExe), cwd} { paths = append(paths, - filepath.Join(root), + root, filepath.Join(root, "windows-"+runtime.GOARCH), filepath.Join(root, "dist", "windows-"+runtime.GOARCH), ) diff --git a/llm/server.go b/llm/server.go index 3af8a329..f4027865 100644 --- a/llm/server.go +++ b/llm/server.go @@ -104,21 +104,22 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr var layers int layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts) - if gpus[0].Library == "metal" && estimatedVRAM > systemMemory { + switch { + case gpus[0].Library == "metal" && estimatedVRAM > systemMemory: // disable partial offloading when model is greater than total system memory as this // can lead to locking up the system opts.NumGPU = 0 - } else if gpus[0].Library != "metal" && layers == 0 { + case gpus[0].Library != "metal" && layers == 0: // Don't bother loading into the GPU if no layers can fit cpuRunner = serverForCpu() gpuCount = 0 - } else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" { + case opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu": opts.NumGPU = layers } } // Loop through potential servers - finalErr := fmt.Errorf("no suitable llama servers found") + finalErr := errors.New("no suitable llama servers found") if len(adapters) > 1 { return nil, errors.New("ollama supports only one lora adapter, but multiple were provided") @@ -284,7 
+285,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr server := filepath.Join(dir, "ollama_llama_server") if runtime.GOOS == "windows" { - server = server + ".exe" + server += ".exe" } // Detect tmp cleaners wiping out the file @@ -459,7 +460,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { resp, err := http.DefaultClient.Do(req) if err != nil { if errors.Is(err, context.DeadlineExceeded) { - return ServerStatusNotResponding, fmt.Errorf("server not responding") + return ServerStatusNotResponding, errors.New("server not responding") } return ServerStatusError, fmt.Errorf("health resp: %w", err) } diff --git a/server/sched.go b/server/sched.go index 46fe2f60..3694b4d0 100644 --- a/server/sched.go +++ b/server/sched.go @@ -66,7 +66,7 @@ func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, opts.NumCtx = 4 } - opts.NumCtx = opts.NumCtx * envconfig.NumParallel + opts.NumCtx *= envconfig.NumParallel req := &LlmRequest{ ctx: c, diff --git a/types/model/name_test.go b/types/model/name_test.go index 26d70ef3..c88fffdb 100644 --- a/types/model/name_test.go +++ b/types/model/name_test.go @@ -325,7 +325,7 @@ func TestParseNameFromFilepath(t *testing.T) { filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"}, filepath.Join("namespace", "model", "tag"): {}, filepath.Join("model", "tag"): {}, - filepath.Join("model"): {}, + "model": {}, filepath.Join("..", "..", "model", "tag"): {}, filepath.Join("", "namespace", ".", "tag"): {}, filepath.Join(".", ".", ".", "."): {}, From e40145a39df0fc8bd6e98ca382806fb02daf8ae1 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 21 May 2024 22:21:04 -0700 Subject: [PATCH 07/26] lint --- .golangci.yaml | 6 ++++ api/types_test.go | 8 +++--- app/lifecycle/paths.go | 1 - app/lifecycle/server.go | 3 +- app/lifecycle/updater.go | 6 ++-- app/store/store.go | 1 - cmd/cmd.go | 2 -- cmd/interactive_test.go | 9 +++--- convert/convert.go | 2 +- convert/torch.go | 1 - format/format_test.go | 1 - gpu/assets.go | 2 +- gpu/gpu_test.go | 5 ++-- llm/gguf.go | 8 +++--- llm/memory.go | 2 +- llm/server.go | 5 ++-- openai/openai.go | 1 - parser/parser_test.go | 31 ++++++++++---------- progress/progress.go | 4 +-- readline/buffer.go | 8 +++--- readline/history.go | 4 +-- readline/readline.go | 12 ++++---- server/images.go | 2 +- server/model.go | 1 - server/modelpath_test.go | 5 ++-- server/routes.go | 4 +-- server/routes_test.go | 61 ++++++++++++++++++++-------------------- server/sched.go | 2 -- server/sched_test.go | 58 ++++++++++++++++++-------------------- server/upload.go | 6 ++-- types/model/name_test.go | 2 -- 31 files changed, 127 insertions(+), 136 deletions(-) diff --git a/.golangci.yaml b/.golangci.yaml index df966a16..9fe1cca8 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -12,8 +12,14 @@ linters: # FIXME: for some reason this errors on windows # - gofmt # - goimports + - intrange - misspell - nilerr - nolintlint - nosprintfhostport + - testifylint + - unconvert - unused + - usestdlibvars + - wastedassign + - whitespace diff --git a/api/types_test.go b/api/types_test.go index cfe1331f..211385c7 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -72,13 +72,13 @@ func TestDurationMarshalUnmarshal(t *testing.T) { }, { "positive duration", - time.Duration(42 * time.Second), - time.Duration(42 * time.Second), + 42 * time.Second, + 42 * time.Second, }, { "another positive duration", - 
time.Duration(42 * time.Minute), - time.Duration(42 * time.Minute), + 42 * time.Minute, + 42 * time.Minute, }, { "zero duration", diff --git a/app/lifecycle/paths.go b/app/lifecycle/paths.go index e4f2dbd9..fe07bce1 100644 --- a/app/lifecycle/paths.go +++ b/app/lifecycle/paths.go @@ -69,7 +69,6 @@ func init() { slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err)) } } - } else if runtime.GOOS == "darwin" { // TODO AppName += ".app" diff --git a/app/lifecycle/server.go b/app/lifecycle/server.go index 3c11edb8..0152ccd1 100644 --- a/app/lifecycle/server.go +++ b/app/lifecycle/server.go @@ -15,7 +15,7 @@ import ( ) func getCLIFullPath(command string) string { - cmdPath := "" + var cmdPath string appExe, err := os.Executable() if err == nil { cmdPath = filepath.Join(filepath.Dir(appExe), command) @@ -65,7 +65,6 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) { if err != nil { if !errors.Is(err, os.ErrNotExist) { return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err) - } if err := os.MkdirAll(logDir, 0o755); err != nil { diff --git a/app/lifecycle/updater.go b/app/lifecycle/updater.go index 243bbf22..b6d95330 100644 --- a/app/lifecycle/updater.go +++ b/app/lifecycle/updater.go @@ -78,7 +78,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) { } defer resp.Body.Close() - if resp.StatusCode == 204 { + if resp.StatusCode == http.StatusNoContent { slog.Debug("check update response 204 (current version is up to date)") return false, updateResp } @@ -87,7 +87,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) { slog.Warn(fmt.Sprintf("failed to read body response: %s", err)) } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body))) return false, updateResp } @@ -114,7 +114,7 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error { if err != nil { return fmt.Errorf("error checking update: %w", err) } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode) } resp.Body.Close() diff --git a/app/store/store.go b/app/store/store.go index 13a75a60..b743e8a8 100644 --- a/app/store/store.go +++ b/app/store/store.go @@ -29,7 +29,6 @@ func GetID() string { initStore() } return store.ID - } func GetFirstTimeRun() bool { diff --git a/cmd/cmd.go b/cmd/cmd.go index 6d395805..e4cd6d9b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -746,7 +746,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState) if wordWrap && termWidth >= 10 { for _, ch := range content { if state.lineLength+1 > termWidth-5 { - if runewidth.StringWidth(state.wordBuffer) > termWidth-10 { fmt.Printf("%s%c", state.wordBuffer, ch) state.wordBuffer = "" @@ -1044,7 +1043,6 @@ func waitForServer(ctx context.Context, client *api.Client) error { } } } - } func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { diff --git a/cmd/interactive_test.go b/cmd/interactive_test.go index 8eedf729..d9af01eb 100644 --- a/cmd/interactive_test.go +++ b/cmd/interactive_test.go @@ -6,6 +6,7 @@ import ( "text/template" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ollama/ollama/api" ) @@ -85,11 +86,11 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark.""" ` tmpl, err := template.New("").Parse(expectedModelfile) - assert.Nil(t, err) + 
require.NoError(t, err) var buf bytes.Buffer err = tmpl.Execute(&buf, opts) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, buf.String(), mf) opts.ParentModel = "horseshark" @@ -107,10 +108,10 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark.""" ` tmpl, err = template.New("").Parse(expectedModelfile) - assert.Nil(t, err) + require.NoError(t, err) var parentBuf bytes.Buffer err = tmpl.Execute(&parentBuf, opts) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, parentBuf.String(), mf) } diff --git a/convert/convert.go b/convert/convert.go index e71a0ff3..103de457 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -189,7 +189,7 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) { if params.VocabSize > len(v.Tokens) { missingTokens := params.VocabSize - len(v.Tokens) slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens)) - for cnt := 0; cnt < missingTokens; cnt++ { + for cnt := range missingTokens { v.Tokens = append(v.Tokens, fmt.Sprintf("", cnt+1)) v.Scores = append(v.Scores, -1) v.Types = append(v.Types, tokenTypeUserDefined) diff --git a/convert/torch.go b/convert/torch.go index eef41a48..55414adc 100644 --- a/convert/torch.go +++ b/convert/torch.go @@ -104,7 +104,6 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, } return tensors, nil - } func getAltParams(dirpath string) (*Params, error) { diff --git a/format/format_test.go b/format/format_test.go index 1d73c80b..bff32780 100644 --- a/format/format_test.go +++ b/format/format_test.go @@ -5,7 +5,6 @@ import ( ) func TestHumanNumber(t *testing.T) { - type testCase struct { input uint64 expected string diff --git a/gpu/assets.go b/gpu/assets.go index e3fbe47c..f2adcf3e 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -80,7 +80,7 @@ func cleanupTmpDirs() { if err == nil { pid, err := strconv.Atoi(string(raw)) if err == nil { - if proc, err := os.FindProcess(int(pid)); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { + if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { // Another running ollama, ignore this tmpdir continue } diff --git a/gpu/gpu_test.go b/gpu/gpu_test.go index a28cbe8c..46d3201e 100644 --- a/gpu/gpu_test.go +++ b/gpu/gpu_test.go @@ -5,11 +5,12 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBasicGetGPUInfo(t *testing.T) { info := GetGPUInfo() - assert.Greater(t, len(info), 0) + assert.NotEmpty(t, len(info)) assert.Contains(t, "cuda rocm cpu metal", info[0].Library) if info[0].Library != "cpu" { assert.Greater(t, info[0].TotalMemory, uint64(0)) @@ -19,7 +20,7 @@ func TestBasicGetGPUInfo(t *testing.T) { func TestCPUMemInfo(t *testing.T) { info, err := GetCPUMem() - assert.NoError(t, err) + require.NoError(t, err) switch runtime.GOOS { case "darwin": t.Skip("CPU memory not populated on darwin") diff --git a/llm/gguf.go b/llm/gguf.go index 0ba48f76..ca7e340d 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -592,8 +592,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error { return err } - dims := 0 - for cnt := 0; cnt < len(tensor.Shape); cnt++ { + var dims int + for cnt := range len(tensor.Shape) { if tensor.Shape[cnt] > 0 { dims++ } @@ -603,8 +603,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error { return err } - for i := 0; i < dims; i++ { - if err := binary.Write(ws, llm.ByteOrder, 
uint64(tensor.Shape[dims-1-i])); err != nil { + for i := range dims { + if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil { return err } } diff --git a/llm/memory.go b/llm/memory.go index ff64baf1..8b5d8541 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -103,7 +103,7 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts } var layerCount int - for i := 0; i < int(ggml.KV().BlockCount()); i++ { + for i := range int(ggml.KV().BlockCount()) { if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok { memoryLayer := blk.size() diff --git a/llm/server.go b/llm/server.go index f4027865..6cb01fa0 100644 --- a/llm/server.go +++ b/llm/server.go @@ -85,7 +85,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr var systemMemory uint64 gpuCount := len(gpus) if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 { - // TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner cpuRunner = serverForCpu() @@ -233,7 +232,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr params = append(params, "--parallel", fmt.Sprintf("%d", numParallel)) - for i := 0; i < len(servers); i++ { + for i := range len(servers) { dir := availableServers[servers[i]] if dir == "" { // Shouldn't happen @@ -316,7 +315,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr s.cmd.Stdout = os.Stdout s.cmd.Stderr = s.status - visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv() + visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv() pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) // Update or add the path and visible devices variable with our adjusted version diff --git a/openai/openai.go b/openai/openai.go index 7ce29e9f..310051a5 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -245,7 +245,6 @@ func (w *writer) writeResponse(data []byte) (int, error) { d, err := json.Marshal(toChunk(w.id, chatResponse)) if err != nil { return 0, err - } w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") diff --git a/parser/parser_test.go b/parser/parser_test.go index 21223cb1..55660590 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -10,6 +10,7 @@ import ( "unicode/utf16" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParseFileFile(t *testing.T) { @@ -25,7 +26,7 @@ TEMPLATE template1 reader := strings.NewReader(input) modelfile, err := ParseFile(reader) - assert.NoError(t, err) + require.NoError(t, err) expectedCommands := []Command{ {Name: "model", Args: "model1"}, @@ -88,7 +89,7 @@ func TestParseFileFrom(t *testing.T) { for _, c := range cases { t.Run("", func(t *testing.T) { modelfile, err := ParseFile(strings.NewReader(c.input)) - assert.ErrorIs(t, err, c.err) + require.ErrorIs(t, err, c.err) if modelfile != nil { assert.Equal(t, c.expected, modelfile.Commands) } @@ -105,7 +106,7 @@ PARAMETER param1 reader := strings.NewReader(input) _, err := ParseFile(reader) - assert.ErrorIs(t, err, io.ErrUnexpectedEOF) + require.ErrorIs(t, err, io.ErrUnexpectedEOF) } func TestParseFileBadCommand(t *testing.T) { @@ -114,8 +115,7 @@ FROM foo BADCOMMAND param1 value1 ` _, err := ParseFile(strings.NewReader(input)) - assert.ErrorIs(t, err, errInvalidCommand) - + require.ErrorIs(t, err, errInvalidCommand) } func TestParseFileMessages(t *testing.T) { @@ -201,7 +201,7 @@ MESSAGE system`, for _, c := 
range cases { t.Run("", func(t *testing.T) { modelfile, err := ParseFile(strings.NewReader(c.input)) - assert.ErrorIs(t, err, c.err) + require.ErrorIs(t, err, c.err) if modelfile != nil { assert.Equal(t, c.expected, modelfile.Commands) } @@ -355,7 +355,7 @@ TEMPLATE """ for _, c := range cases { t.Run("", func(t *testing.T) { modelfile, err := ParseFile(strings.NewReader(c.multiline)) - assert.ErrorIs(t, err, c.err) + require.ErrorIs(t, err, c.err) if modelfile != nil { assert.Equal(t, c.expected, modelfile.Commands) } @@ -413,7 +413,7 @@ func TestParseFileParameters(t *testing.T) { fmt.Fprintln(&b, "FROM foo") fmt.Fprintln(&b, "PARAMETER", k) modelfile, err := ParseFile(&b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, []Command{ {Name: "model", Args: "foo"}, @@ -442,7 +442,7 @@ FROM foo for _, c := range cases { t.Run("", func(t *testing.T) { modelfile, err := ParseFile(strings.NewReader(c.input)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, c.expected, modelfile.Commands) }) } @@ -501,15 +501,14 @@ SYSTEM "" for _, c := range cases { t.Run("", func(t *testing.T) { modelfile, err := ParseFile(strings.NewReader(c)) - assert.NoError(t, err) + require.NoError(t, err) modelfile2, err := ParseFile(strings.NewReader(modelfile.String())) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, modelfile, modelfile2) }) } - } func TestParseFileUTF16ParseFile(t *testing.T) { @@ -522,10 +521,10 @@ SYSTEM You are a utf16 file. utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...)) buf := new(bytes.Buffer) err := binary.Write(buf, binary.LittleEndian, utf16File) - assert.NoError(t, err) + require.NoError(t, err) actual, err := ParseFile(buf) - assert.NoError(t, err) + require.NoError(t, err) expected := []Command{ {Name: "model", Args: "bob"}, @@ -539,9 +538,9 @@ SYSTEM You are a utf16 file. 
// simulate a utf16 be file buf = new(bytes.Buffer) err = binary.Write(buf, binary.BigEndian, utf16File) - assert.NoError(t, err) + require.NoError(t, err) actual, err = ParseFile(buf) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, expected, actual.Commands) } diff --git a/progress/progress.go b/progress/progress.go index 556ba00f..102830a8 100644 --- a/progress/progress.go +++ b/progress/progress.go @@ -59,7 +59,7 @@ func (p *Progress) StopAndClear() bool { stopped := p.stop() if stopped { // clear all progress lines - for i := 0; i < p.pos; i++ { + for i := range p.pos { if i > 0 { fmt.Fprint(p.w, "\033[A") } @@ -85,7 +85,7 @@ func (p *Progress) render() { defer fmt.Fprint(p.w, "\033[?25h") // clear already rendered progress lines - for i := 0; i < p.pos; i++ { + for i := range p.pos { if i > 0 { fmt.Fprint(p.w, "\033[A") } diff --git a/readline/buffer.go b/readline/buffer.go index 2c3bfec9..c5ac3f26 100644 --- a/readline/buffer.go +++ b/readline/buffer.go @@ -154,7 +154,7 @@ func (b *Buffer) MoveToStart() { if b.Pos > 0 { currLine := b.DisplayPos / b.LineWidth if currLine > 0 { - for cnt := 0; cnt < currLine; cnt++ { + for range currLine { fmt.Print(CursorUp) } } @@ -169,7 +169,7 @@ func (b *Buffer) MoveToEnd() { currLine := b.DisplayPos / b.LineWidth totalLines := b.DisplaySize() / b.LineWidth if currLine < totalLines { - for cnt := 0; cnt < totalLines-currLine; cnt++ { + for range totalLines - currLine { fmt.Print(CursorDown) } remainder := b.DisplaySize() % b.LineWidth @@ -451,7 +451,7 @@ func (b *Buffer) DeleteBefore() { func (b *Buffer) DeleteRemaining() { if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() { charsToDel := b.Buf.Size() - b.Pos - for cnt := 0; cnt < charsToDel; cnt++ { + for range charsToDel { b.Delete() } } @@ -495,7 +495,7 @@ func (b *Buffer) ClearScreen() { if currPos > 0 { targetLine := currPos / b.LineWidth if targetLine > 0 { - for cnt := 0; cnt < targetLine; cnt++ { + for range targetLine { fmt.Print(CursorDown) } } diff --git a/readline/history.go b/readline/history.go index 670a0722..9c6b930b 100644 --- a/readline/history.go +++ b/readline/history.go @@ -91,7 +91,7 @@ func (h *History) Add(l []rune) { func (h *History) Compact() { s := h.Buf.Size() if s > h.Limit { - for cnt := 0; cnt < s-h.Limit; cnt++ { + for range s - h.Limit { h.Buf.Remove(0) } } @@ -139,7 +139,7 @@ func (h *History) Save() error { defer f.Close() buf := bufio.NewWriter(f) - for cnt := 0; cnt < h.Size(); cnt++ { + for cnt := range h.Size() { v, _ := h.Buf.Get(cnt) line, _ := v.([]rune) if _, err := buf.WriteString(string(line) + "\n"); err != nil { diff --git a/readline/readline.go b/readline/readline.go index 5215d617..d8c3676e 100644 --- a/readline/readline.go +++ b/readline/readline.go @@ -63,7 +63,7 @@ func New(prompt Prompt) (*Instance, error) { func (i *Instance) Readline() (string, error) { if !i.Terminal.rawmode { - fd := int(syscall.Stdin) + fd := syscall.Stdin termios, err := SetRawMode(fd) if err != nil { return "", err @@ -80,7 +80,7 @@ func (i *Instance) Readline() (string, error) { fmt.Print(prompt) defer func() { - fd := int(syscall.Stdin) + fd := syscall.Stdin //nolint:errcheck UnsetRawMode(fd, i.Terminal.termios) i.Terminal.rawmode = false @@ -136,7 +136,7 @@ func (i *Instance) Readline() (string, error) { buf.MoveRight() case CharBracketedPaste: var code string - for cnt := 0; cnt < 3; cnt++ { + for range 3 { r, err = i.Terminal.Read() if err != nil { return "", io.EOF @@ -198,7 +198,7 @@ func (i *Instance) Readline() (string, error) { buf.Remove() 
case CharTab: // todo: convert back to real tabs - for cnt := 0; cnt < 8; cnt++ { + for range 8 { buf.Add(' ') } case CharDelete: @@ -216,7 +216,7 @@ func (i *Instance) Readline() (string, error) { case CharCtrlW: buf.DeleteWord() case CharCtrlZ: - fd := int(syscall.Stdin) + fd := syscall.Stdin return handleCharCtrlZ(fd, i.Terminal.termios) case CharEnter, CharCtrlJ: output := buf.String() @@ -248,7 +248,7 @@ func (i *Instance) HistoryDisable() { } func NewTerminal() (*Terminal, error) { - fd := int(syscall.Stdin) + fd := syscall.Stdin termios, err := SetRawMode(fd) if err != nil { return nil, err diff --git a/server/images.go b/server/images.go index b8497eaa..6f23f0c1 100644 --- a/server/images.go +++ b/server/images.go @@ -987,7 +987,7 @@ func getTokenSubject(token string) string { func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) { anonymous := true // access will default to anonymous if no user is found associated with the public key - for i := 0; i < 2; i++ { + for range 2 { resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) if err != nil { if !errors.Is(err, context.Canceled) { diff --git a/server/model.go b/server/model.go index fcf406f6..4f76284d 100644 --- a/server/model.go +++ b/server/model.go @@ -72,7 +72,6 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe default: layers = append(layers, &layerWithGGML{layer, nil}) } - } return layers, nil diff --git a/server/modelpath_test.go b/server/modelpath_test.go index 30741d87..849e0fa7 100644 --- a/server/modelpath_test.go +++ b/server/modelpath_test.go @@ -6,12 +6,13 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetBlobsPath(t *testing.T) { // GetBlobsPath expects an actual directory to exist dir, err := os.MkdirTemp("", "ollama-test") - assert.Nil(t, err) + require.NoError(t, err) defer os.RemoveAll(dir) tests := []struct { @@ -63,7 +64,7 @@ func TestGetBlobsPath(t *testing.T) { got, err := GetBlobsPath(tc.digest) - assert.ErrorIs(t, tc.err, err, tc.name) + require.ErrorIs(t, tc.err, err, tc.name) assert.Equal(t, tc.expected, got, tc.name) }) } diff --git a/server/routes.go b/server/routes.go index 7fcd2f2f..cc718156 100644 --- a/server/routes.go +++ b/server/routes.go @@ -77,7 +77,6 @@ func isSupportedImageType(image []byte) bool { } func (s *Server) GenerateHandler(c *gin.Context) { - checkpointStart := time.Now() var req api.GenerateRequest err := c.ShouldBindJSON(&req) @@ -942,7 +941,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc { } if allowedHost(host) { - if c.Request.Method == "OPTIONS" { + if c.Request.Method == http.MethodOptions { c.AbortWithStatus(http.StatusNoContent) return } @@ -1306,7 +1305,6 @@ func (s *Server) ChatHandler(c *gin.Context) { defer close(ch) fn := func(r llm.CompletionResponse) { - resp := api.ChatResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), diff --git a/server/routes_test.go b/server/routes_test.go index 74933b1f..1fc258e0 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -15,6 +15,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ollama/ollama/api" "github.com/ollama/ollama/parser" @@ -25,20 +26,20 @@ func createTestFile(t *testing.T, name string) string { t.Helper() f, err := os.CreateTemp(t.TempDir(), name) - assert.Nil(t, err) + assert.NoError(t, 
err) defer f.Close() err = binary.Write(f, binary.LittleEndian, []byte("GGUF")) - assert.Nil(t, err) + assert.NoError(t, err) err = binary.Write(f, binary.LittleEndian, uint32(3)) - assert.Nil(t, err) + assert.NoError(t, err) err = binary.Write(f, binary.LittleEndian, uint64(0)) - assert.Nil(t, err) + assert.NoError(t, err) err = binary.Write(f, binary.LittleEndian, uint64(0)) - assert.Nil(t, err) + assert.NoError(t, err) return f.Name() } @@ -57,12 +58,12 @@ func Test_Routes(t *testing.T) { r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) modelfile, err := parser.ParseFile(r) - assert.Nil(t, err) + require.NoError(t, err) fn := func(resp api.ProgressResponse) { t.Logf("Status: %s", resp.Status) } err = CreateModel(context.TODO(), name, "", "", modelfile, fn) - assert.Nil(t, err) + require.NoError(t, err) } testCases := []testCase{ @@ -74,9 +75,9 @@ func Test_Routes(t *testing.T) { }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, contentType, "application/json; charset=utf-8") + assert.Equal(t, "application/json; charset=utf-8", contentType) body, err := io.ReadAll(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body)) }, }, @@ -86,17 +87,17 @@ func Test_Routes(t *testing.T) { Path: "/api/tags", Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, contentType, "application/json; charset=utf-8") + assert.Equal(t, "application/json; charset=utf-8", contentType) body, err := io.ReadAll(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) var modelList api.ListResponse err = json.Unmarshal(body, &modelList) - assert.Nil(t, err) + require.NoError(t, err) assert.NotNil(t, modelList.Models) - assert.Equal(t, 0, len(modelList.Models)) + assert.Empty(t, len(modelList.Models)) }, }, { @@ -108,16 +109,16 @@ func Test_Routes(t *testing.T) { }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, contentType, "application/json; charset=utf-8") + assert.Equal(t, "application/json; charset=utf-8", contentType) body, err := io.ReadAll(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) var modelList api.ListResponse err = json.Unmarshal(body, &modelList) - assert.Nil(t, err) + require.NoError(t, err) - assert.Equal(t, 1, len(modelList.Models)) - assert.Equal(t, modelList.Models[0].Name, "test-model:latest") + assert.Len(t, modelList.Models, 1) + assert.Equal(t, "test-model:latest", modelList.Models[0].Name) }, }, { @@ -134,7 +135,7 @@ func Test_Routes(t *testing.T) { Stream: &stream, } jsonData, err := json.Marshal(createReq) - assert.Nil(t, err) + require.NoError(t, err) req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, @@ -142,11 +143,11 @@ func Test_Routes(t *testing.T) { contentType := resp.Header.Get("Content-Type") assert.Equal(t, "application/json", contentType) _, err := io.ReadAll(resp.Body) - assert.Nil(t, err) - assert.Equal(t, resp.StatusCode, 200) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) model, err := GetModel("t-bone") - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "t-bone:latest", model.ShortName) }, }, @@ -161,13 +162,13 @@ func Test_Routes(t *testing.T) { Destination: "beefsteak", } jsonData, err := json.Marshal(copyReq) - assert.Nil(t, err) + require.NoError(t, err) 
req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, Expected: func(t *testing.T, resp *http.Response) { model, err := GetModel("beefsteak") - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "beefsteak:latest", model.ShortName) }, }, @@ -179,18 +180,18 @@ func Test_Routes(t *testing.T) { createTestModel(t, "show-model") showReq := api.ShowRequest{Model: "show-model"} jsonData, err := json.Marshal(showReq) - assert.Nil(t, err) + require.NoError(t, err) req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, contentType, "application/json; charset=utf-8") + assert.Equal(t, "application/json; charset=utf-8", contentType) body, err := io.ReadAll(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) var showResp api.ShowResponse err = json.Unmarshal(body, &showResp) - assert.Nil(t, err) + require.NoError(t, err) var params []string paramsSplit := strings.Split(showResp.Parameters, "\n") @@ -221,14 +222,14 @@ func Test_Routes(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { u := httpSrv.URL + tc.Path req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) - assert.Nil(t, err) + require.NoError(t, err) if tc.Setup != nil { tc.Setup(t, req) } resp, err := httpSrv.Client().Do(req) - assert.Nil(t, err) + require.NoError(t, err) defer resp.Body.Close() if tc.Expected != nil { diff --git a/server/sched.go b/server/sched.go index 3694b4d0..c36486f7 100644 --- a/server/sched.go +++ b/server/sched.go @@ -370,7 +370,6 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) { r.refMu.Lock() gpuIDs := make([]string, 0, len(r.gpus)) if r.llama != nil { - // TODO this should be broken down by GPU instead of assuming uniform spread estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus)) for _, gpu := range r.gpus { @@ -529,7 +528,6 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} { } }() return finished - } type ByDuration []*runnerRef diff --git a/server/sched_test.go b/server/sched_test.go index 3ee1b989..f7dce6d1 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -12,11 +12,10 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/app/lifecycle" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" - "github.com/ollama/ollama/envconfig" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -53,10 +52,10 @@ func TestLoad(t *testing.T) { } gpus := gpu.GpuInfoList{} s.load(req, ggml, gpus) - require.Len(t, req.successCh, 0) + require.Empty(t, req.successCh) require.Len(t, req.errCh, 1) s.loadedMu.Lock() - require.Len(t, s.loaded, 0) + require.Empty(t, s.loaded) s.loadedMu.Unlock() err := <-req.errCh require.Contains(t, err.Error(), "this model may be incompatible") @@ -113,7 +112,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV t.Helper() f, err := os.CreateTemp(t.TempDir(), modelName) - assert.Nil(t, err) + require.NoError(t, err) defer f.Close() gguf := llm.NewGGUFV3(binary.LittleEndian) @@ -131,7 +130,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV }, []llm.Tensor{ {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, }) - assert.Nil(t, err) + require.NoError(t, err) fname := f.Name() model := &Model{Name: modelName, ModelPath: fname} @@ -190,8 
+189,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario1a.req.successCh: require.Equal(t, resp.llama, scenario1a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario1a.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario1a.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -203,8 +202,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario1b.req.successCh: require.Equal(t, resp.llama, scenario1a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario1b.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario1b.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -221,8 +220,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario2a.req.successCh: require.Equal(t, resp.llama, scenario2a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario2a.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario2a.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -237,8 +236,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario3a.req.successCh: require.Equal(t, resp.llama, scenario3a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario3a.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario3a.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -253,8 +252,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario3b.req.successCh: require.Equal(t, resp.llama, scenario3b.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario3b.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario3b.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -269,8 +268,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario3c.req.successCh: require.Equal(t, resp.llama, scenario3c.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario3c.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario3c.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -296,8 +295,8 @@ func TestRequests(t *testing.T) { select { case resp := <-scenario3d.req.successCh: require.Equal(t, resp.llama, scenario3d.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, scenario3d.req.errCh, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario3d.req.errCh) case <-ctx.Done(): t.Errorf("timeout") } @@ -332,7 +331,7 @@ func TestGetRunner(t *testing.T) { slog.Info("scenario1b") successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration) require.Len(t, s.pendingReqCh, 1) - require.Len(t, successCh1b, 0) + require.Empty(t, successCh1b) require.Len(t, errCh1b, 1) err := <-errCh1b require.Contains(t, err.Error(), "server busy") @@ -340,8 +339,8 @@ func TestGetRunner(t *testing.T) { select { case resp := <-successCh1a: require.Equal(t, resp.llama, scenario1a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, errCh1a, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, errCh1a) case <-ctx.Done(): t.Errorf("timeout") } @@ -355,9 +354,9 @@ func TestGetRunner(t *testing.T) { successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration) // Starts in pending channel, then should be quickly processsed to return an error time.Sleep(5 * time.Millisecond) - require.Len(t, successCh1c, 0) + require.Empty(t, successCh1c) s.loadedMu.Lock() - require.Len(t, s.loaded, 0) + require.Empty(t, s.loaded) 
s.loadedMu.Unlock() require.Len(t, errCh1c, 1) err = <-errCh1c @@ -386,8 +385,8 @@ func TestPrematureExpired(t *testing.T) { select { case resp := <-successCh1a: require.Equal(t, resp.llama, scenario1a.srv) - require.Len(t, s.pendingReqCh, 0) - require.Len(t, errCh1a, 0) + require.Empty(t, s.pendingReqCh) + require.Empty(t, errCh1a) s.loadedMu.Lock() require.Len(t, s.loaded, 1) s.loadedMu.Unlock() @@ -401,9 +400,9 @@ func TestPrematureExpired(t *testing.T) { time.Sleep(20 * time.Millisecond) require.LessOrEqual(t, len(s.finishedReqCh), 1) time.Sleep(10 * time.Millisecond) - require.Len(t, s.finishedReqCh, 0) + require.Empty(t, s.finishedReqCh) s.loadedMu.Lock() - require.Len(t, s.loaded, 0) + require.Empty(t, s.loaded) s.loadedMu.Unlock() // also shouldn't happen in real life @@ -487,7 +486,6 @@ func TestFindRunnerToUnload(t *testing.T) { r2.refCount = 1 resp = s.findRunnerToUnload() require.Equal(t, r1, resp) - } func TestNeedsReload(t *testing.T) { diff --git a/server/upload.go b/server/upload.go index aa775518..73ce78ce 100644 --- a/server/upload.go +++ b/server/upload.go @@ -146,7 +146,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) { case requestURL := <-b.nextURL: g.Go(func() error { var err error - for try := 0; try < maxRetries; try++ { + for try := range maxRetries { err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts) switch { case errors.Is(err, context.Canceled): @@ -190,7 +190,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) { headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Length", "0") - for try := 0; try < maxRetries; try++ { + for try := range maxRetries { var resp *http.Response resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts) if errors.Is(err, context.Canceled) { @@ -253,7 +253,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL * } // retry uploading to the redirect URL - for try := 0; try < maxRetries; try++ { + for try := range maxRetries { err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil) switch { case errors.Is(err, context.Canceled): diff --git a/types/model/name_test.go b/types/model/name_test.go index c88fffdb..117d6333 100644 --- a/types/model/name_test.go +++ b/types/model/name_test.go @@ -268,7 +268,6 @@ func TestNameIsValidPart(t *testing.T) { } }) } - } func TestFilepathAllocs(t *testing.T) { @@ -382,7 +381,6 @@ func FuzzName(f *testing.F) { t.Errorf("String() = %q; want %q", n.String(), s) } } - }) } From 201d853fdf4d08bdbf0c7e58087ba5d574feace1 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 22 May 2024 08:52:00 -0700 Subject: [PATCH 08/26] nolintlint --- cmd/cmd.go | 16 ---------------- cmd/start.go | 27 +++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 16 deletions(-) create mode 100644 cmd/start.go diff --git a/cmd/cmd.go b/cmd/cmd.go index e4cd6d9b..b5747543 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1029,22 +1029,6 @@ func initializeKeypair() error { return nil } -func waitForServer(ctx context.Context, client *api.Client) error { - // wait for the server to start - timeout := time.After(5 * time.Second) - tick := time.Tick(500 * time.Millisecond) - for { - select { - case <-timeout: - return errors.New("timed out waiting for server to start") - case <-tick: - if err := client.Heartbeat(ctx); err == nil { - return nil // server has started - } - } - } -} - func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { client, err := 
api.ClientFromEnvironment() if err != nil { diff --git a/cmd/start.go b/cmd/start.go new file mode 100644 index 00000000..0c4eed08 --- /dev/null +++ b/cmd/start.go @@ -0,0 +1,27 @@ +//go:build darwin || windows + +package cmd + +import ( + "context" + "errors" + "time" + + "github.com/ollama/ollama/api" +) + +func waitForServer(ctx context.Context, client *api.Client) error { + // wait for the server to start + timeout := time.After(5 * time.Second) + tick := time.Tick(500 * time.Millisecond) + for { + select { + case <-timeout: + return errors.New("timed out waiting for server to start") + case <-tick: + if err := client.Heartbeat(ctx); err == nil { + return nil // server has started + } + } + } +} From f38353d6b91cf9c4f9c086eddfa6d337172d281f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 22 May 2024 09:00:38 -0700 Subject: [PATCH 09/26] stdin.fd --- readline/readline.go | 9 ++++----- readline/readline_unix.go | 2 +- readline/readline_windows.go | 2 +- readline/term.go | 6 +++--- readline/term_bsd.go | 8 ++++---- readline/term_linux.go | 8 ++++---- readline/term_windows.go | 6 +++--- 7 files changed, 20 insertions(+), 21 deletions(-) diff --git a/readline/readline.go b/readline/readline.go index d8c3676e..e90a5e01 100644 --- a/readline/readline.go +++ b/readline/readline.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "os" - "syscall" ) type Prompt struct { @@ -63,7 +62,7 @@ func New(prompt Prompt) (*Instance, error) { func (i *Instance) Readline() (string, error) { if !i.Terminal.rawmode { - fd := syscall.Stdin + fd := os.Stdin.Fd() termios, err := SetRawMode(fd) if err != nil { return "", err @@ -80,7 +79,7 @@ func (i *Instance) Readline() (string, error) { fmt.Print(prompt) defer func() { - fd := syscall.Stdin + fd := os.Stdin.Fd() //nolint:errcheck UnsetRawMode(fd, i.Terminal.termios) i.Terminal.rawmode = false @@ -216,7 +215,7 @@ func (i *Instance) Readline() (string, error) { case CharCtrlW: buf.DeleteWord() case CharCtrlZ: - fd := syscall.Stdin + fd := os.Stdin.Fd() return handleCharCtrlZ(fd, i.Terminal.termios) case CharEnter, CharCtrlJ: output := buf.String() @@ -248,7 +247,7 @@ func (i *Instance) HistoryDisable() { } func NewTerminal() (*Terminal, error) { - fd := syscall.Stdin + fd := os.Stdin.Fd() termios, err := SetRawMode(fd) if err != nil { return nil, err diff --git a/readline/readline_unix.go b/readline/readline_unix.go index 76cff8c8..d48b9176 100644 --- a/readline/readline_unix.go +++ b/readline/readline_unix.go @@ -6,7 +6,7 @@ import ( "syscall" ) -func handleCharCtrlZ(fd int, termios any) (string, error) { +func handleCharCtrlZ(fd uintptr, termios any) (string, error) { t := termios.(*Termios) if err := UnsetRawMode(fd, t); err != nil { return "", err diff --git a/readline/readline_windows.go b/readline/readline_windows.go index b4e96b25..a131d0ef 100644 --- a/readline/readline_windows.go +++ b/readline/readline_windows.go @@ -1,6 +1,6 @@ package readline -func handleCharCtrlZ(fd int, state any) (string, error) { +func handleCharCtrlZ(fd uintptr, state any) (string, error) { // not supported return "", nil } diff --git a/readline/term.go b/readline/term.go index 9d747162..5584cd25 100644 --- a/readline/term.go +++ b/readline/term.go @@ -8,7 +8,7 @@ import ( type Termios syscall.Termios -func SetRawMode(fd int) (*Termios, error) { +func SetRawMode(fd uintptr) (*Termios, error) { termios, err := getTermios(fd) if err != nil { return nil, err @@ -25,13 +25,13 @@ func SetRawMode(fd int) (*Termios, error) { return termios, setTermios(fd, &newTermios) } -func 
UnsetRawMode(fd int, termios any) error { +func UnsetRawMode(fd uintptr, termios any) error { t := termios.(*Termios) return setTermios(fd, t) } // IsTerminal returns true if the given file descriptor is a terminal. -func IsTerminal(fd int) bool { +func IsTerminal(fd uintptr) bool { _, err := getTermios(fd) return err == nil } diff --git a/readline/term_bsd.go b/readline/term_bsd.go index 04b912f0..80bee6b3 100644 --- a/readline/term_bsd.go +++ b/readline/term_bsd.go @@ -7,17 +7,17 @@ import ( "unsafe" ) -func getTermios(fd int) (*Termios, error) { +func getTermios(fd uintptr) (*Termios, error) { termios := new(Termios) - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) if err != 0 { return nil, err } return termios, nil } -func setTermios(fd int, termios *Termios) error { - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) +func setTermios(fd uintptr, termios *Termios) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) if err != 0 { return err } diff --git a/readline/term_linux.go b/readline/term_linux.go index 2d6211dd..e9ed0745 100644 --- a/readline/term_linux.go +++ b/readline/term_linux.go @@ -10,17 +10,17 @@ import ( const tcgets = 0x5401 const tcsets = 0x5402 -func getTermios(fd int) (*Termios, error) { +func getTermios(fd uintptr) (*Termios, error) { termios := new(Termios) - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) if err != 0 { return nil, err } return termios, nil } -func setTermios(fd int, termios *Termios) error { - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) +func setTermios(fd uintptr, termios *Termios) error { + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) if err != 0 { return err } diff --git a/readline/term_windows.go b/readline/term_windows.go index cfdfd672..3b35149b 100644 --- a/readline/term_windows.go +++ b/readline/term_windows.go @@ -9,13 +9,13 @@ type State struct { } // IsTerminal checks if the given file descriptor is associated with a terminal -func IsTerminal(fd int) bool { +func IsTerminal(fd uintptr) bool { var st uint32 err := windows.GetConsoleMode(windows.Handle(fd), &st) return err == nil } -func SetRawMode(fd int) (*State, error) { +func SetRawMode(fd uintptr) (*State, error) { var st uint32 if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil { return nil, err @@ -32,7 +32,7 @@ func SetRawMode(fd int) (*State, error) { return &State{st}, nil } -func UnsetRawMode(fd int, state any) error { +func UnsetRawMode(fd uintptr, state any) error { s := state.(*State) return windows.SetConsoleMode(windows.Handle(fd), s.mode) } From bf7edb0d5d1a52bab51d83d1558762ad9eb3dc81 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 22 May 2024 09:08:01 -0700 Subject: [PATCH 10/26] lint linux --- gpu/cuda_common.go | 1 - gpu/gpu.go | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/gpu/cuda_common.go b/gpu/cuda_common.go index 03c1a25b..c90a644c 100644 --- a/gpu/cuda_common.go +++ b/gpu/cuda_common.go 
@@ -18,5 +18,4 @@ func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { ids = append(ids, info.ID) } return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",") - } diff --git a/gpu/gpu.go b/gpu/gpu.go index 03e16702..73ef1358 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -187,7 +187,7 @@ func GetGPUInfo() GpuInfoList { resp := []GpuInfo{} // NVIDIA first - for i := 0; i < gpuHandles.deviceCount; i++ { + for i := range gpuHandles.deviceCount { // TODO once we support CPU compilation variants of GPU libraries refine this... if cpuVariant == "" && runtime.GOARCH == "amd64" { continue @@ -221,8 +221,8 @@ func GetGPUInfo() GpuInfoList { gpuInfo.MinimumMemory = cudaMinimumMemory gpuInfo.DependencyPath = depPath gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.DriverMajor = int(driverMajor) - gpuInfo.DriverMinor = int(driverMinor) + gpuInfo.DriverMajor = driverMajor + gpuInfo.DriverMinor = driverMinor // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... resp = append(resp, gpuInfo) From e919f6811f7933b120f783e5003727c91dae467f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 22 May 2024 09:26:45 -0700 Subject: [PATCH 11/26] lint windows --- app/lifecycle/server_windows.go | 6 ++++-- app/tray/wintray/eventloop.go | 5 ++--- app/tray/wintray/tray.go | 10 +++------- gpu/amd_windows.go | 2 +- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/app/lifecycle/server_windows.go b/app/lifecycle/server_windows.go index cd4244ff..5f9fe124 100644 --- a/app/lifecycle/server_windows.go +++ b/app/lifecycle/server_windows.go @@ -24,7 +24,8 @@ func terminate(cmd *exec.Cmd) error { if err != nil { return err } - defer dll.Release() // nolint: errcheck + //nolint:errcheck + defer dll.Release() pid := cmd.Process.Pid @@ -73,7 +74,8 @@ func isProcessExited(pid int) (bool, error) { if err != nil { return false, fmt.Errorf("failed to open process: %v", err) } - defer windows.CloseHandle(hProcess) // nolint: errcheck + //nolint:errcheck + defer windows.CloseHandle(hProcess) var exitCode uint32 err = windows.GetExitCodeProcess(hProcess, &exitCode) diff --git a/app/tray/wintray/eventloop.go b/app/tray/wintray/eventloop.go index a0af9787..0f944894 100644 --- a/app/tray/wintray/eventloop.go +++ b/app/tray/wintray/eventloop.go @@ -47,7 +47,6 @@ func nativeLoop() { default: pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck - } } } @@ -160,8 +159,8 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui lResult, _, _ = pDefWindowProc.Call( uintptr(hWnd), uintptr(message), - uintptr(wParam), - uintptr(lParam), + wParam, + lParam, ) } return diff --git a/app/tray/wintray/tray.go b/app/tray/wintray/tray.go index 69d4487d..027ec5a5 100644 --- a/app/tray/wintray/tray.go +++ b/app/tray/wintray/tray.go @@ -186,7 +186,7 @@ func (t *winTray) initInstance() error { t.muNID.Lock() defer t.muNID.Unlock() t.nid = ¬ifyIconData{ - Wnd: windows.Handle(t.window), + Wnd: t.window, ID: 100, Flags: NIF_MESSAGE, CallbackMessage: t.wmSystrayMessage, @@ -197,7 +197,6 @@ func (t *winTray) initInstance() error { } func (t *winTray) createMenu() error { - menuHandle, _, err := pCreatePopupMenu.Call() if menuHandle == 0 { return err @@ -246,7 +245,7 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title mi := menuItemInfo{ Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE, Type: MFT_STRING, - ID: uint32(menuItemId), + ID: 
menuItemId, TypeData: titlePtr, Cch: uint32(len(title)), } @@ -302,11 +301,10 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title } func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error { - mi := menuItemInfo{ Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE, Type: MFT_SEPARATOR, - ID: uint32(menuItemId), + ID: menuItemId, } mi.Size = uint32(unsafe.Sizeof(mi)) @@ -426,7 +424,6 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) { // Loads an image from file and shows it in tray. // Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx func (t *winTray) setIcon(src string) error { - h, err := t.loadIconFrom(src) if err != nil { return err @@ -444,7 +441,6 @@ func (t *winTray) setIcon(src string) error { // Loads an image from file to be shown in tray or menu item. // LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx func (t *winTray) loadIconFrom(src string) (windows.Handle, error) { - // Save and reuse handles of loaded images t.muLoadedImages.RLock() h, ok := t.loadedImages[src] diff --git a/gpu/amd_windows.go b/gpu/amd_windows.go index aae6c5b7..987b543b 100644 --- a/gpu/amd_windows.go +++ b/gpu/amd_windows.go @@ -65,7 +65,7 @@ func AMDGetGPUInfo() []GpuInfo { slog.Debug("detected hip devices", "count", count) // TODO how to determine the underlying device ID when visible devices is causing this to subset? - for i := 0; i < count; i++ { + for i := range count { err = hl.HipSetDevice(i) if err != nil { slog.Warn("set device", "id", i, "error", err) From 42660466f8721a9f4bcbf906a1e6ed7f4c401c0a Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 23 May 2024 11:04:46 -0700 Subject: [PATCH 12/26] no usestdlibvars --- .golangci.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.golangci.yaml b/.golangci.yaml index 9fe1cca8..d8eda1c0 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -20,6 +20,7 @@ linters: - testifylint - unconvert - unused - - usestdlibvars + # TODO: bmizerany says no :( + # - usestdlibvars - wastedassign - whitespace From 8ce4032e727764f641260f7943c78b884569b719 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 29 May 2024 18:22:03 -0700 Subject: [PATCH 13/26] more lint --- readline/buffer.go | 16 +++------------- server/routes_test.go | 10 +++++----- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/readline/buffer.go b/readline/buffer.go index c5ac3f26..b7cf9b13 100644 --- a/readline/buffer.go +++ b/readline/buffer.go @@ -52,7 +52,6 @@ func (b *Buffer) GetLineSpacing(line int) bool { } return hasSpace.(bool) - } func (b *Buffer) MoveLeft() { @@ -117,15 +116,12 @@ func (b *Buffer) MoveRight() { if b.DisplayPos%b.LineWidth == 0 { fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt()))) - } else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace { fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength)) b.DisplayPos += 1 - } else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace { fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt()))) b.DisplayPos += 1 - } else { fmt.Print(cursorRightN(rLength)) } @@ -185,7 +181,7 @@ func (b *Buffer) MoveToEnd() { func (b *Buffer) DisplaySize() int { sum := 0 - for i := 0; i < b.Buf.Size(); i++ { + for i := range b.Buf.Size() { if e, ok := b.Buf.Get(i); ok { if r, ok := e.(rune); ok { sum += runewidth.RuneWidth(r) @@ -197,7 +193,6 @@ func (b *Buffer) 
DisplaySize() int { } func (b *Buffer) Add(r rune) { - if b.Pos == b.Buf.Size() { b.AddChar(r, false) } else { @@ -210,7 +205,6 @@ func (b *Buffer) AddChar(r rune, insert bool) { b.DisplayPos += rLength if b.Pos > 0 { - if b.DisplayPos%b.LineWidth == 0 { fmt.Printf("%c", r) fmt.Printf("\n%s", b.Prompt.AltPrompt) @@ -235,7 +229,6 @@ func (b *Buffer) AddChar(r rune, insert bool) { } else { b.LineHasSpace.Add(true) } - } else { fmt.Printf("%c", r) } @@ -356,7 +349,6 @@ func (b *Buffer) drawRemaining() { func (b *Buffer) Remove() { if b.Buf.Size() > 0 && b.Pos > 0 { - if e, ok := b.Buf.Get(b.Pos - 1); ok { if r, ok := e.(rune); ok { rLength := runewidth.RuneWidth(r) @@ -382,7 +374,6 @@ func (b *Buffer) Remove() { } else { fmt.Print(" " + CursorLeft) } - } else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace { fmt.Printf(CursorBOL + ClearToEOL) fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width)) @@ -391,10 +382,9 @@ func (b *Buffer) Remove() { b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1) } b.DisplayPos -= 1 - } else { fmt.Print(cursorLeftN(rLength)) - for i := 0; i < rLength; i++ { + for range rLength { fmt.Print(" ") } fmt.Print(cursorLeftN(rLength)) @@ -525,7 +515,7 @@ func (b *Buffer) Replace(r []rune) { fmt.Printf(CursorBOL + ClearToEOL) - for i := 0; i < lineNums; i++ { + for range lineNums { fmt.Print(CursorUp + CursorBOL + ClearToEOL) } diff --git a/server/routes_test.go b/server/routes_test.go index 1fc258e0..79354017 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -26,20 +26,20 @@ func createTestFile(t *testing.T, name string) string { t.Helper() f, err := os.CreateTemp(t.TempDir(), name) - assert.NoError(t, err) + require.NoError(t, err) defer f.Close() err = binary.Write(f, binary.LittleEndian, []byte("GGUF")) - assert.NoError(t, err) + require.NoError(t, err) err = binary.Write(f, binary.LittleEndian, uint32(3)) - assert.NoError(t, err) + require.NoError(t, err) err = binary.Write(f, binary.LittleEndian, uint64(0)) - assert.NoError(t, err) + require.NoError(t, err) err = binary.Write(f, binary.LittleEndian, uint64(0)) - assert.NoError(t, err) + require.NoError(t, err) return f.Name() } From ad40b92b6ad3ee5112489bc839ea489fa4a3d23e Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 4 Jun 2024 11:35:30 -0700 Subject: [PATCH 14/26] disable intrange --- .golangci.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.golangci.yaml b/.golangci.yaml index d8eda1c0..a4bd7f13 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -12,7 +12,8 @@ linters: # FIXME: for some reason this errors on windows # - gofmt # - goimports - - intrange + # TODO: disable for now + # - intrange - misspell - nilerr - nolintlint From ed56428dd7b94c9dcf9fd50cd724818e4ffdf950 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 4 Jun 2024 11:51:39 -0700 Subject: [PATCH 15/26] warn on intrange, usestdlibvars --- .golangci.yaml | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/.golangci.yaml b/.golangci.yaml index a4bd7f13..56a40df1 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -9,11 +9,9 @@ linters: - contextcheck - exportloopref - gocheckcompilerdirectives - # FIXME: for some reason this errors on windows - # - gofmt - # - goimports - # TODO: disable for now - # - intrange + - gofmt + - goimports + - intrange - misspell - nilerr - nolintlint @@ -21,7 +19,15 @@ linters: - testifylint - unconvert - unused - # TODO: bmizerany says no :( - # - usestdlibvars - wastedassign - whitespace + - usestdlibvars +severity: + 
default-severity: error + rules: + - linters: + - gofmt + - goimports + - intrange + - usestdlibvars + severity: info From 6297f8560692573da77fd34e198573d5f7f10823 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 4 Jun 2024 11:53:23 -0700 Subject: [PATCH 16/26] gofmt, goimports --- .github/workflows/test.yaml | 4 ++-- .golangci.yaml | 5 +++-- llm/memory.go | 2 +- server/images.go | 2 +- types/model/name_test.go | 4 ++-- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 049a97ed..dbb6c2fd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -269,9 +269,9 @@ jobs: mkdir -p llm/build/darwin/$ARCH/stub/bin touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server if: ${{ startsWith(matrix.os, 'macos-') }} - - uses: golangci/golangci-lint-action@v4 + - uses: golangci/golangci-lint-action@v6 with: - args: --timeout 8m0s -v + args: --timeout 8m0s -v ${{ startsWith(matrix.os, 'windows-') && '' || '--disable gofmt --disable goimports' }} test: strategy: matrix: diff --git a/.golangci.yaml b/.golangci.yaml index 56a40df1..cfe06e07 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -9,8 +9,9 @@ linters: - contextcheck - exportloopref - gocheckcompilerdirectives - - gofmt - - goimports + # conditionally enable this on linux/macos + # - gofmt + # - goimports - intrange - misspell - nilerr diff --git a/llm/memory.go b/llm/memory.go index 8b5d8541..1c2e476b 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -5,9 +5,9 @@ import ( "log/slog" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" - "github.com/ollama/ollama/envconfig" ) // This algorithm looks for a complete fit to determine if we need to unload other models diff --git a/server/images.go b/server/images.go index 6f23f0c1..d1d49061 100644 --- a/server/images.go +++ b/server/images.go @@ -24,10 +24,10 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/auth" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/parser" - "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" diff --git a/types/model/name_test.go b/types/model/name_test.go index 117d6333..66ce4c33 100644 --- a/types/model/name_test.go +++ b/types/model/name_test.go @@ -386,8 +386,8 @@ func FuzzName(f *testing.F) { func TestIsValidNamespace(t *testing.T) { cases := []struct { - username string - expected bool + username string + expected bool }{ {"", false}, {"a", true}, From 4a048715b617403aa5b7b8cc2e94959500d1de8a Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 4 Jun 2024 13:25:25 -0700 Subject: [PATCH 17/26] local wording was confusing people local wording was confusing people -- Ollama runs on cloud providers --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 19c8adb1..2de6eca1 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![Discord](https://dcbadge.vercel.app/api/server/ollama?style=flat&compact=true)](https://discord.gg/ollama) -Get up and running with large language models locally. +Get up and running with large language models. 
### macOS From d61ef8b954027b6bc17840f200d2dc80958205d8 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 8 May 2024 14:36:08 -0700 Subject: [PATCH 18/26] update create handler to use model.Name --- server/images.go | 41 ++-- server/manifest.go | 54 ++--- server/model.go | 10 +- server/routes.go | 33 ++-- server/routes_create_test.go | 368 +++++++++++++++++++++++++++++++++++ server/routes_test.go | 5 +- 6 files changed, 441 insertions(+), 70 deletions(-) diff --git a/server/images.go b/server/images.go index d1d49061..c0fdec59 100644 --- a/server/images.go +++ b/server/images.go @@ -314,7 +314,7 @@ func realpath(rel, from string) string { return abspath } -func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) { +func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) { config := ConfigV2{ OS: "linux", Architecture: "amd64", @@ -439,19 +439,27 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m layers = append(layers, baseLayer.Layer) } case "license", "template", "system": + if c.Name != "license" { + // replace + layers = slices.DeleteFunc(layers, func(layer *Layer) bool { + if layer.MediaType != mediatype { + return false + } + + if err := layer.Remove(); err != nil { + return false + } + + return true + }) + } + blob := strings.NewReader(c.Args) layer, err := NewLayer(blob, mediatype) if err != nil { return err } - if c.Name != "license" { - // replace - layers = slices.DeleteFunc(layers, func(layer *Layer) bool { - return layer.MediaType == mediatype - }) - } - layers = append(layers, layer) case "message": role, content, ok := strings.Cut(c.Args, ": ") @@ -570,26 +578,15 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m } } - unref := make(map[string]struct{}) - if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil { - for _, layer := range manifest.Layers { - if !slices.Contains(digests, layer.Digest) { - unref[layer.Digest] = struct{}{} - } - } - - if manifest.Config.Digest != layer.Digest { - unref[manifest.Config.Digest] = struct{}{} - } - } + old, _ := ParseNamedManifest(name) fn(api.ProgressResponse{Status: "writing manifest"}) if err := WriteManifest(name, layer, layers); err != nil { return err } - if !envconfig.NoPrune { - if err := deleteUnusedLayers(nil, unref); err != nil { + if !envconfig.NoPrune && old != nil { + if err := old.RemoveLayers(); err != nil { return err } } diff --git a/server/manifest.go b/server/manifest.go index a5251298..d0675724 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -1,7 +1,6 @@ package server import ( - "bytes" "crypto/sha256" "encoding/json" "fmt" @@ -34,12 +33,6 @@ func (m *Manifest) Remove() error { return err } - for _, layer := range append(m.Layers, m.Config) { - if err := layer.Remove(); err != nil { - return err - } - } - manifests, err := GetManifestPath() if err != nil { return err @@ -48,6 +41,16 @@ func (m *Manifest) Remove() error { return PruneDirectory(manifests) } +func (m *Manifest) RemoveLayers() error { + for _, layer := range append(m.Layers, m.Config) { + if err := layer.Remove(); err != nil { + return err + } + } + + return nil +} + func ParseNamedManifest(n model.Name) (*Manifest, error) { if !n.IsFullyQualified() { return nil, model.Unqualified(n) @@ -85,30 +88,31 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) 
{ }, nil } -func WriteManifest(name string, config *Layer, layers []*Layer) error { - manifest := ManifestV2{ +func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { + manifests, err := GetManifestPath() + if err != nil { + return err + } + + p := filepath.Join(manifests, name.Filepath()) + if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil { + return err + } + + f, err := os.Create(p) + if err != nil { + return err + } + defer f.Close() + + m := ManifestV2{ SchemaVersion: 2, MediaType: "application/vnd.docker.distribution.manifest.v2+json", Config: config, Layers: layers, } - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(manifest); err != nil { - return err - } - - modelpath := ParseModelPath(name) - manifestPath, err := modelpath.GetManifestPath() - if err != nil { - return err - } - - if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil { - return err - } - - return os.WriteFile(manifestPath, b.Bytes(), 0o644) + return json.NewEncoder(f).Encode(m) } func Manifests() (map[model.Name]*Manifest, error) { diff --git a/server/model.go b/server/model.go index 4f76284d..ee2ae080 100644 --- a/server/model.go +++ b/server/model.go @@ -25,16 +25,14 @@ type layerWithGGML struct { } func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { - modelpath := ParseModelPath(name.String()) - manifest, _, err := GetManifest(modelpath) + m, err := ParseNamedManifest(name) switch { case errors.Is(err, os.ErrNotExist): if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil { return nil, err } - modelpath = ParseModelPath(name.String()) - manifest, _, err = GetManifest(modelpath) + m, err = ParseNamedManifest(name) if err != nil { return nil, err } @@ -42,8 +40,8 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return nil, err } - for _, layer := range manifest.Layers { - layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) + for _, layer := range m.Layers { + layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest()) if err != nil { return nil, err } diff --git a/server/routes.go b/server/routes.go index cc718156..56d4307e 100644 --- a/server/routes.go +++ b/server/routes.go @@ -523,8 +523,8 @@ func checkNameExists(name model.Name) error { } func (s *Server) CreateModelHandler(c *gin.Context) { - var req api.CreateRequest - if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { + var r api.CreateRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return } else if err != nil { @@ -532,7 +532,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) { return } - name := model.ParseName(cmp.Or(req.Model, req.Name)) + name := model.ParseName(cmp.Or(r.Model, r.Name)) if !name.IsValid() { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) return @@ -543,24 +543,24 @@ func (s *Server) CreateModelHandler(c *gin.Context) { return } - if req.Path == "" && req.Modelfile == "" { + if r.Path == "" && r.Modelfile == "" { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) return } - var r io.Reader = strings.NewReader(req.Modelfile) - if req.Path != "" && req.Modelfile == "" { - f, err := os.Open(req.Path) + var sr io.Reader = strings.NewReader(r.Modelfile) + if r.Path != "" && r.Modelfile == "" 
{ + f, err := os.Open(r.Path) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) return } defer f.Close() - r = f + sr = f } - modelfile, err := parser.ParseFile(r) + f, err := parser.ParseFile(sr) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -576,17 +576,13 @@ func (s *Server) CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - quantization := req.Quantization - if req.Quantize != "" { - quantization = req.Quantize - } - - if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil { + quantization := cmp.Or(r.Quantize, r.Quantization) + if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() - if req.Stream != nil && !*req.Stream { + if r.Stream != nil && !*r.Stream { waitForStream(c, ch) return } @@ -620,6 +616,11 @@ func (s *Server) DeleteModelHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + + if err := m.RemoveLayers(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } } func (s *Server) ShowModelHandler(c *gin.Context) { diff --git a/server/routes_create_test.go b/server/routes_create_test.go index e5af1ded..19bf19ed 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -158,3 +158,371 @@ func TestCreateFromModel(t *testing.T) { filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"), }) } + +func TestCreateRemovesLayers(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-b507b9c2f6ca642bffcd06665ea7c91f235fd32daeefdf875a0f938db05fb315"), + filepath.Join(p, "blobs", "sha256-bc80b03733773e0728011b2f4adf34c458b400e1aad48cb28d61170f3a2ad2d6"), + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"), + }) +} + +func 
TestCreateUnsetsSystem(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-8585df945d1069bc78b79bd10bb73ba07fbc29b0f5479a31a601c0d12731416e"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-f29e82a8284dbdf5910b1555580ff60b04238b8da9d5e51159ada67a4d0d5851"), + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-67d4b8d106af2a5b100a46e9bdc038c71eef2a35c9abac784092654212f97cf5"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), + }) + + bts, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")) + if err != nil { + t.Fatal(err) + } + + if string(bts) != "" { + t.Fatalf("expected empty string, actual %s", string(bts)) + } +} + +func TestCreateMergeParameters(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"), + filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + }) + + // in order to merge parameters, the second model must be created FROM the first + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test2", + Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", 
"latest"), + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"), + filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"), + filepath.Join(p, "blobs", "sha256-4cd9d4ba6b734d9b4cbd1e5caa60374c00722e993fce5e1e2d15a33698f71187"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"), + }) + + actual, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba")) + if err != nil { + t.Fatal(err) + } + + expect, err := json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"USER:", "ASSISTANT:"}}) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) { + t.Errorf("expected %s, actual %s", string(expect), string(actual)) + } + + // slices are replaced + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test2", + Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7\nPARAMETER stop <|endoftext|>", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"), + filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"), + filepath.Join(p, "blobs", "sha256-257aa726584f24970a4f240765e75a7169bfbe7f4966c1f04513d6b6c860583a"), + filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + }) + + actual, err = os.ReadFile(filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35")) + if err != nil { + t.Fatal(err) + } + + expect, err = json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"<|endoftext|>"}}) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) { + t.Errorf("expected %s, actual %s", string(expect), string(actual)) + } +} + +func TestCreateReplacesMessages(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), 
[]string{ + filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"), + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test2", + Modelfile: "FROM test\nMESSAGE assistant \"You're a test, Harry.\"\nMESSAGE user \"I-I'm a what?\"\nMESSAGE assistant \"A test. And a thumping good one at that, I'd wager.\"", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"), + filepath.Join(p, "blobs", "sha256-4f48b25fe9969564c82f58eb1cedbdff6484cc0baf474bc6c2a9b37c8da3362a"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"), + filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"), + }) + + type message struct { + Role string `json:"role"` + Content string `json:"content"` + } + + f, err := os.Open(filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db")) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + var actual []message + if err := json.NewDecoder(f).Decode(&actual); err != nil { + t.Fatal(err) + } + + expect := []message{ + {Role: "assistant", Content: "You're a test, Harry."}, + {Role: "user", Content: "I-I'm a what?"}, + {Role: "assistant", Content: "A test. 
And a thumping good one at that, I'd wager."}, + } + + if !slices.Equal(actual, expect) { + t.Errorf("expected %s, actual %s", expect, actual) + } +} + +func TestCreateTemplateSystem(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-2b5e330885117c82f3fd75169ea323e141070a2947c11ddb9f79ee0b01c589c1"), + filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"), + }) + + template, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed")) + if err != nil { + t.Fatal(err) + } + + if string(template) != "{{ .System }} {{ .Prompt }}" { + t.Errorf("expected \"{{ .System }} {{ .Prompt }}\", actual %s", template) + } + + system, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86")) + if err != nil { + t.Fatal(err) + } + + if string(system) != "Say bye!" 
{ + t.Errorf("expected \"Say bye!\", actual %s", system) + } +} + +func TestCreateLicenses(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"), + filepath.Join(p, "blobs", "sha256-79a39c37536ddee29cbadd5d5e2dcba8ed7f03e431f626ff38432c1c866bb7e2"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"), + }) + + mit, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7")) + if err != nil { + t.Fatal(err) + } + + if string(mit) != "MIT" { + t.Errorf("expected MIT, actual %s", mit) + } + + apache, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608")) + if err != nil { + t.Fatal(err) + } + + if string(apache) != "Apache-2.0" { + t.Errorf("expected Apache-2.0, actual %s", apache) + } +} diff --git a/server/routes_test.go b/server/routes_test.go index 79354017..8dacaf0a 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -19,6 +19,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -54,6 +55,8 @@ func Test_Routes(t *testing.T) { } createTestModel := func(t *testing.T, name string) { + t.Helper() + fname := createTestFile(t, "ollama-model") r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) @@ -62,7 +65,7 @@ func Test_Routes(t *testing.T) { fn := func(resp api.ProgressResponse) { t.Logf("Status: %s", resp.Status) } - err = CreateModel(context.TODO(), name, "", "", modelfile, fn) + err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn) require.NoError(t, err) } From 1d8616d30fd398f48e5a4ca06054b1322d98ae57 Mon Sep 17 00:00:00 2001 From: Kartikeya Mishra <108652656+kartikm7@users.noreply.github.com> Date: Wed, 5 Jun 2024 03:13:59 +0530 Subject: [PATCH 19/26] docs: update to add LLocal.in to web & desktop integrations (#4719) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 2de6eca1..3c5117ba 100644 --- a/README.md +++ b/README.md @@ -285,6 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints. 
 - [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
 - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
+- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
 
 ### Terminal
 

From 28c7813ac4c5316fc65d83bece4f21aeaeb51d66 Mon Sep 17 00:00:00 2001
From: royjhan <65097070+royjhan@users.noreply.github.com>
Date: Wed, 5 Jun 2024 11:06:53 -0700
Subject: [PATCH 20/26] API PS Documentation (#4822)

* API PS Documentation
---
 docs/api.md | 46 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/docs/api.md b/docs/api.md
index 0f11c388..64bfbed8 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -12,6 +12,7 @@
 - [Pull a Model](#pull-a-model)
 - [Push a Model](#push-a-model)
 - [Generate Embeddings](#generate-embeddings)
+- [List Running Models](#list-running-models)
 
 ## Conventions
 
@@ -1035,3 +1036,48 @@ curl http://localhost:11434/api/embeddings -d '{
   ]
 }
 ```
+
+## List Running Models
+```shell
+GET /api/ps
+```
+
+List models that are currently loaded into memory.
+
+\* If a model is loaded completely into system memory, `size_vram` is omitted from the response.
+
+### Examples
+
+#### Request
+```shell
+curl http://localhost:11434/api/ps
+```
+
+#### Response
+
+A single JSON object will be returned.
+
+```json
+{
+  "models": [
+    {
+      "name": "mistral:latest",
+      "model": "mistral:latest",
+      "size": 5137025024,
+      "digest": "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8",
+      "details": {
+        "parent_model": "",
+        "format": "gguf",
+        "family": "llama",
+        "families": [
+          "llama"
+        ],
+        "parameter_size": "7.2B",
+        "quantization_level": "Q4_0"
+      },
+      "expires_at": "2024-06-04T14:38:31.83753-07:00",
+      "size_vram": 5137025024
+    }
+  ]
+}
+```
\ No newline at end of file

From 98e65929dcec0d9614fe8cded27ce95333ff347e Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 6 Jun 2024 09:13:39 +1200
Subject: [PATCH 21/26] docs(tools): add gollama (#4829)

---
 README.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 3c5117ba..9d5994a0 100644
--- a/README.md
+++ b/README.md
@@ -308,6 +308,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [ShellOracle](https://github.com/djcopley/ShellOracle)
 - [tlm](https://github.com/yusufcanb/tlm)
 - [podman-ollama](https://github.com/ericcurtin/podman-ollama)
+- [gollama](https://github.com/sammcj/gollama)
 
 ### Database
 
@@ -347,6 +348,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
 - [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) with an [example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama)
 - [LlamaScript](https://github.com/Project-Llama/llamascript)
+
 ### Mobile
 
 - [Enchanted](https://github.com/AugustDev/enchanted)
@@ -380,6 +382,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
 - [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
-### Supported backends +### Supported backends + - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. From de5beb06b314eb4950c5a0de8183dfadb325fc8b Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Fri, 24 May 2024 08:40:40 -0700 Subject: [PATCH 22/26] server: skip blob verification for already verified blobs --- server/download.go | 12 ++++++------ server/images.go | 20 ++++++++++++-------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/server/download.go b/server/download.go index 937b6754..d93cd3b4 100644 --- a/server/download.go +++ b/server/download.go @@ -340,17 +340,17 @@ type downloadOpts struct { } // downloadBlob downloads a blob from the registry and stores it in the blobs directory -func downloadBlob(ctx context.Context, opts downloadOpts) error { +func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) { fp, err := GetBlobsPath(opts.digest) if err != nil { - return err + return false, err } fi, err := os.Stat(fp) switch { case errors.Is(err, os.ErrNotExist): case err != nil: - return err + return false, err default: opts.fn(api.ProgressResponse{ Status: fmt.Sprintf("pulling %s", opts.digest[7:19]), @@ -359,7 +359,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error { Completed: fi.Size(), }) - return nil + return true, nil } data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest}) @@ -369,12 +369,12 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error { requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest) if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil { blobDownloadManager.Delete(opts.digest) - return err + return false, err } //nolint:contextcheck go download.Run(context.Background(), requestURL, opts.regOpts) } - return download.Wait(ctx, opts.fn) + return false, download.Wait(ctx, opts.fn) } diff --git a/server/images.go b/server/images.go index c0fdec59..529f7b6a 100644 --- a/server/images.go +++ b/server/images.go @@ -853,23 +853,27 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu layers = append(layers, manifest.Layers...) 
 	layers = append(layers, manifest.Config)
 
+	skipVerify := make(map[string]bool)
 	for _, layer := range layers {
-		if err := downloadBlob(
-			ctx,
-			downloadOpts{
-				mp:      mp,
-				digest:  layer.Digest,
-				regOpts: regOpts,
-				fn:      fn,
-			}); err != nil {
+		cacheHit, err := downloadBlob(ctx, downloadOpts{
+			mp:      mp,
+			digest:  layer.Digest,
+			regOpts: regOpts,
+			fn:      fn,
+		})
+		if err != nil {
 			return err
 		}
+		skipVerify[layer.Digest] = cacheHit
 		delete(deleteMap, layer.Digest)
 	}
 	delete(deleteMap, manifest.Config.Digest)
 
 	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
 	for _, layer := range layers {
+		if skipVerify[layer.Digest] {
+			continue
+		}
 		if err := verifyBlob(layer.Digest); err != nil {
 			if errors.Is(err, errDigestMismatch) {
 				// something went wrong, delete the blob

From 4bf1da49449536f23411e4f7768f8459541dfd94 Mon Sep 17 00:00:00 2001
From: royjhan <65097070+royjhan@users.noreply.github.com>
Date: Thu, 6 Jun 2024 10:11:45 -0700
Subject: [PATCH 23/26] Separate ListResponse and ModelResponse for api/tags vs api/ps (#4842)

* Remove false time fields

* Struct Separation for List and Process

* Remove Marshaler
---
 api/client.go         |  4 ++--
 api/types.go          | 26 ++++++++++++++++++++------
 server/routes.go      | 12 ++++++------
 server/routes_test.go |  2 ++
 4 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/api/client.go b/api/client.go
index d50b397d..dc099e95 100644
--- a/api/client.go
+++ b/api/client.go
@@ -355,8 +355,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 }
 
 // List running models.
-func (c *Client) ListRunning(ctx context.Context) (*ListResponse, error) {
-	var lr ListResponse
+func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
+	var lr ProcessResponse
 	if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
 		return nil, err
 	}
diff --git a/api/types.go b/api/types.go
index 230f58e8..caf2ad70 100644
--- a/api/types.go
+++ b/api/types.go
@@ -282,19 +282,33 @@ type PushRequest struct {
 
 // ListResponse is the response from [Client.List].
 type ListResponse struct {
-	Models []ModelResponse `json:"models"`
+	Models []ListModelResponse `json:"models"`
 }
 
-// ModelResponse is a single model description in [ListResponse].
-type ModelResponse struct {
+// ProcessResponse is the response from [Client.ListRunning].
+type ProcessResponse struct {
+	Models []ProcessModelResponse `json:"models"`
+}
+
+// ListModelResponse is a single model description in [ListResponse].
+type ListModelResponse struct {
 	Name       string       `json:"name"`
 	Model      string       `json:"model"`
-	ModifiedAt time.Time    `json:"modified_at,omitempty"`
+	ModifiedAt time.Time    `json:"modified_at"`
 	Size       int64        `json:"size"`
 	Digest     string       `json:"digest"`
 	Details    ModelDetails `json:"details,omitempty"`
-	ExpiresAt  time.Time    `json:"expires_at,omitempty"`
-	SizeVRAM   int64        `json:"size_vram,omitempty"`
+}
+
+// ProcessModelResponse is a single model description in [ProcessResponse].
+type ProcessModelResponse struct { + Name string `json:"name"` + Model string `json:"model"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Details ModelDetails `json:"details,omitempty"` + ExpiresAt time.Time `json:"expires_at"` + SizeVRAM int64 `json:"size_vram"` } type TokenResponse struct { diff --git a/server/routes.go b/server/routes.go index 56d4307e..ecd60081 100644 --- a/server/routes.go +++ b/server/routes.go @@ -730,7 +730,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) { return } - models := []api.ModelResponse{} + models := []api.ListModelResponse{} for n, m := range ms { f, err := m.Config.Open() if err != nil { @@ -746,7 +746,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) { } // tag should never be masked - models = append(models, api.ModelResponse{ + models = append(models, api.ListModelResponse{ Model: n.DisplayShortest(), Name: n.DisplayShortest(), Size: m.Size(), @@ -762,7 +762,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) { }) } - slices.SortStableFunc(models, func(i, j api.ModelResponse) int { + slices.SortStableFunc(models, func(i, j api.ListModelResponse) int { // most recently modified first return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix()) }) @@ -1139,7 +1139,7 @@ func streamResponse(c *gin.Context, ch chan any) { } func (s *Server) ProcessHandler(c *gin.Context) { - models := []api.ModelResponse{} + models := []api.ProcessModelResponse{} for _, v := range s.sched.loaded { model := v.model @@ -1151,7 +1151,7 @@ func (s *Server) ProcessHandler(c *gin.Context) { QuantizationLevel: model.Config.FileType, } - mr := api.ModelResponse{ + mr := api.ProcessModelResponse{ Model: model.ShortName, Name: model.ShortName, Size: int64(v.estimatedTotal), @@ -1171,7 +1171,7 @@ func (s *Server) ProcessHandler(c *gin.Context) { models = append(models, mr) } - c.JSON(http.StatusOK, api.ListResponse{Models: models}) + c.JSON(http.StatusOK, api.ProcessResponse{Models: models}) } // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model diff --git a/server/routes_test.go b/server/routes_test.go index 8dacaf0a..91ef625b 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -116,6 +116,8 @@ func Test_Routes(t *testing.T) { body, err := io.ReadAll(resp.Body) require.NoError(t, err) + assert.NotContains(t, string(body), "expires_at") + var modelList api.ListResponse err = json.Unmarshal(body, &modelList) require.NoError(t, err) From 1a29e9a879433fc55cf1b256572d2a88a2e63166 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Thu, 6 Jun 2024 15:19:03 -0700 Subject: [PATCH 24/26] API app/browser access (#4879) * API app/browser access * Add tauri (resolves #2291, #4791, #3799, #4388) --- envconfig/config.go | 6 ++++++ server/routes.go | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/envconfig/config.go b/envconfig/config.go index 77e3e789..ae4e9939 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -190,6 +190,12 @@ func LoadConfig() { ) } + AllowOrigins = append(AllowOrigins, + "app://*", + "file://*", + "tauri://*", + ) + maxRunners := clean("OLLAMA_MAX_LOADED_MODELS") if maxRunners != "" { m, err := strconv.Atoi(maxRunners) diff --git a/server/routes.go b/server/routes.go index ecd60081..188fe974 100644 --- a/server/routes.go +++ b/server/routes.go @@ -960,6 +960,10 @@ func (s *Server) GenerateRoutes() http.Handler { config.AllowWildcard = true config.AllowBrowserExtensions = true config.AllowHeaders = []string{"Authorization", 
"Content-Type", "User-Agent", "Accept", "X-Requested-With"} + openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"} + for _, prop := range openAIProperties { + config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop) + } config.AllowOrigins = envconfig.AllowOrigins r := gin.Default() From 9b6c2e6eb62c234f8a44556984bbb680d7065e01 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 3 Jun 2024 11:06:29 -0700 Subject: [PATCH 25/26] detect chat template from KV --- go.mod | 1 + go.sum | 6 ++ llm/ggml.go | 5 + server/images.go | 16 +++ templates/alfred.gotmpl | 1 + templates/alpaca.gotmpl | 7 ++ templates/chatml.gotmpl | 6 ++ templates/chatqa.gotmpl | 5 + templates/codellama-70b-instruct.gotmpl | 8 ++ templates/falcon-instruct.gotmpl | 3 + templates/gemma-instruct.gotmpl | 4 + templates/granite-instruct.gotmpl | 9 ++ templates/index.json | 138 ++++++++++++++++++++++++ templates/llama2-chat.gotmpl | 3 + templates/llama3-instruct.gotmpl | 7 ++ templates/magicoder.gotmpl | 7 ++ templates/mistral-instruct.gotmpl | 6 ++ templates/openchat.gotmpl | 1 + templates/phi-3.gotmpl | 6 ++ templates/solar-instruct.gotmpl | 8 ++ templates/starcoder2-instruct.gotmpl | 9 ++ templates/template.go | 69 ++++++++++++ templates/template_test.go | 59 ++++++++++ templates/testdata/templates.jsonl | 35 ++++++ templates/vicuna.gotmpl | 3 + templates/zephyr.gotmpl | 6 ++ 26 files changed, 428 insertions(+) create mode 100644 templates/alfred.gotmpl create mode 100644 templates/alpaca.gotmpl create mode 100644 templates/chatml.gotmpl create mode 100644 templates/chatqa.gotmpl create mode 100644 templates/codellama-70b-instruct.gotmpl create mode 100644 templates/falcon-instruct.gotmpl create mode 100644 templates/gemma-instruct.gotmpl create mode 100644 templates/granite-instruct.gotmpl create mode 100644 templates/index.json create mode 100644 templates/llama2-chat.gotmpl create mode 100644 templates/llama3-instruct.gotmpl create mode 100644 templates/magicoder.gotmpl create mode 100644 templates/mistral-instruct.gotmpl create mode 100644 templates/openchat.gotmpl create mode 100644 templates/phi-3.gotmpl create mode 100644 templates/solar-instruct.gotmpl create mode 100644 templates/starcoder2-instruct.gotmpl create mode 100644 templates/template.go create mode 100644 templates/template_test.go create mode 100644 templates/testdata/templates.jsonl create mode 100644 templates/vicuna.gotmpl create mode 100644 templates/zephyr.gotmpl diff --git a/go.mod b/go.mod index 2f3d4ca3..6807b9b4 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( ) require ( + github.com/agnivade/levenshtein v1.1.1 github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/mattn/go-runewidth v0.0.14 github.com/nlpodyssey/gopickle v0.3.0 diff --git a/go.sum b/go.sum index 9e1baebe..926ed26d 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,14 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7 gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= +github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/ajstarks/svgo 
v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
 github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
 github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6ICHXqG5hm0ZW5IHyeEJXoIJSOZeBLmWPNeIQ=
 github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
+github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
+github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
 github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
 github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
 github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
@@ -36,6 +40,8 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
+github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
 github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
 github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
diff --git a/llm/ggml.go b/llm/ggml.go
index 878800f3..645447d5 100644
--- a/llm/ggml.go
+++ b/llm/ggml.go
@@ -81,6 +81,11 @@ func (kv KV) ContextLength() uint64 {
 	return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
 }
 
+func (kv KV) ChatTemplate() string {
+	s, _ := kv["tokenizer.chat_template"].(string)
+	return s
+}
+
 type Tensors []*Tensor
 
 func (ts Tensors) Layers() map[string]Layer {
diff --git a/server/images.go b/server/images.go
index 529f7b6a..32207f20 100644
--- a/server/images.go
+++ b/server/images.go
@@ -28,6 +28,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/templates"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -434,6 +435,21 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 				config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
 				config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
 				config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
+
+				if s := baseLayer.GGML.KV().ChatTemplate(); s != "" {
+					t, err := templates.NamedTemplate(s)
+					if err != nil {
+						return err
+					}
+
+					layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
+					if err != nil {
+						return err
+					}
+
+					layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
+					layers = append(layers, layer)
+				}
 			}
 
 			layers = append(layers, baseLayer.Layer)
diff --git a/templates/alfred.gotmpl b/templates/alfred.gotmpl
new file mode 100644
index 00000000..cecb9d2c
--- /dev/null
+++ b/templates/alfred.gotmpl
@@ -0,0 +1 @@
+{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
.System }}{{ end }}{{ if .Prompt }}{{ .Prompt }}{{ end }}{{ .Response }} \ No newline at end of file diff --git a/templates/alpaca.gotmpl b/templates/alpaca.gotmpl new file mode 100644 index 00000000..440d0662 --- /dev/null +++ b/templates/alpaca.gotmpl @@ -0,0 +1,7 @@ +{{ if .System }}{{ .System }} + +{{ end }}{{ if .Prompt }}### Instruction: +{{ .Prompt }} + +{{ end }}### Response: +{{ .Response }} \ No newline at end of file diff --git a/templates/chatml.gotmpl b/templates/chatml.gotmpl new file mode 100644 index 00000000..dcf17285 --- /dev/null +++ b/templates/chatml.gotmpl @@ -0,0 +1,6 @@ +{{ if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ .Response }}<|im_end|> \ No newline at end of file diff --git a/templates/chatqa.gotmpl b/templates/chatqa.gotmpl new file mode 100644 index 00000000..1ede6227 --- /dev/null +++ b/templates/chatqa.gotmpl @@ -0,0 +1,5 @@ +{{ if .System }}System: {{ .System }} + +{{ end }}{{ if .Prompt }}User: {{ .Prompt }} + +{{ end }}Assistant: <|begin_of_text|>{{ .Response }} \ No newline at end of file diff --git a/templates/codellama-70b-instruct.gotmpl b/templates/codellama-70b-instruct.gotmpl new file mode 100644 index 00000000..3196bd6f --- /dev/null +++ b/templates/codellama-70b-instruct.gotmpl @@ -0,0 +1,8 @@ +{{ if .System }}Source: system + + {{ .System }} <step> {{ end }}Source: user + + {{ .Prompt }} <step> Source: assistant +Destination: user + + {{ .Response }}<step> \ No newline at end of file diff --git a/templates/falcon-instruct.gotmpl b/templates/falcon-instruct.gotmpl new file mode 100644 index 00000000..2309a1c5 --- /dev/null +++ b/templates/falcon-instruct.gotmpl @@ -0,0 +1,3 @@ +{{ if .System }}{{ .System }} +{{ end }}{{ if .Prompt }}User: {{ .Prompt }} +{{ end }}Assistant: {{ .Response }} \ No newline at end of file diff --git a/templates/gemma-instruct.gotmpl b/templates/gemma-instruct.gotmpl new file mode 100644 index 00000000..91b9883a --- /dev/null +++ b/templates/gemma-instruct.gotmpl @@ -0,0 +1,4 @@ +<start_of_turn>user +{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn> +<start_of_turn>model +{{ .Response }}<end_of_turn> \ No newline at end of file diff --git a/templates/granite-instruct.gotmpl b/templates/granite-instruct.gotmpl new file mode 100644 index 00000000..2ede647f --- /dev/null +++ b/templates/granite-instruct.gotmpl @@ -0,0 +1,9 @@ +{{ if .System }} +System: +{{ .System }} + +{{ end }}{{ if .Prompt }}Question: +{{ .Prompt }} + +{{ end }}Answer: +{{ .Response }} \ No newline at end of file diff --git a/templates/index.json b/templates/index.json new file mode 100644 index 00000000..e2d41893 --- /dev/null +++ b/templates/index.json @@ -0,0 +1,138 @@ +[ + { + "template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}", + "name": "chatml" + }, + { + "template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + 
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + "name": "zephyr" + }, + { + "template": "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}", + "name": "chatml" + }, + { + "template": "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", + "name": "openchat" + }, + { + "template": "{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + "name": "zephyr" + }, + { + "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + "name": "mistral-instruct" + }, + { + "template": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ 
raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'### Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response\n'}}", + "name": "starcoder2-instruct" + }, + { + "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}", + "name": "llama2-chat" + }, + { + "template": "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}", + "name": "codellama-70b-instruct" + }, + { + "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + "name": "mistral-instruct" + }, + { + "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'system' %}\n{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|im_start|>assistant' }}\n{% endif %}\n{% endfor %}", + "name": "chatml" + }, + { + "template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' 
%}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "name": "chatml" + }, + { + "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif 'system' not in messages[0]['role'] %}{% set loop_messages = messages %}{% set system_message = 'You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks \u2014 remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER\\'S QUERY.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% if system_message != false %}{{ '<|im_start|>system\n' + system_message | trim + '<|im_end|>\n'}}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% else %}{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% endif %}{% if (add_generation_prompt == true and loop.last) %}{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}{% endif %}{% endfor %}", + "name": "chatml" + }, + { + "template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", + "name": "alpaca" + }, + { + "template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", + "name": "chatqa" + }, + { + "template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", + "name": "gemma-instruct" + }, + { + "template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "name": "llama3-instruct" + }, + { + "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'Question:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'system' %}\n{{ 'System:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Answer:\n' + message['content'] + '\n\n' }}{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Answer:\n' }}{% endif %}{% endfor %}", + "name": "granite-instruct" + }, + { + "template": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'@@ Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'@@ Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'@@ Response\n'}}", + "name": "magicoder" + }, + { + "template": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '' + message['content'].strip() + '' }}{% elif message['role'] == 'system' %}{{ '' + message['content'].strip() + '' }}{% elif message['role'] == 'assistant' %}{{ '' + message['content'] + '' }}{% else %}{{ raise_exception('Only system, user and assistant roles are supported.') }}{% 
endif %}{% if loop.last and add_generation_prompt %}{{ '' }}{% endif %}{% endfor %}", + "name": "alfred" + }, + { + "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + "name": "llama2-chat" + }, + { + "template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "name": "phi-3" + }, + { + "template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "name": "phi-3" + }, + { + "template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", + "name": "phi-3" + }, + { + "template": "{{ bos_token }}{%- if messages[0]['role'] == 'system' -%}{% set loop_messages = messages[1:] %}{%- else -%}{% set loop_messages = messages %}{% endif %}System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. 
The assistant should also indicate when the answer cannot be found in the context.\n\n{% for message in loop_messages %}{%- if message['role'] == 'user' -%}User: {{ message['content'].strip() + '\n\n' }}{%- else -%}Assistant: {{ message['content'].strip() + '\n\n' }}{%- endif %}{% if loop.last and message['role'] == 'user' %}Assistant:{% endif %}{% endfor %}", + "name": "chatqa" + }, + { + "template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'User: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'System: ' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'Falcon:\n' + message['content']}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Falcon:' }}\n{% endif %}\n{% endfor %}", + "name": "falcon-instruct" + }, + { + "template": "{% for message in messages %}{% if not loop.first %}{{ '\n' }}{% endif %}{% if message['role'] == 'system' %}{{ 'System: ' }}{% elif message['role'] == 'user' %}{{ 'User: ' }}{% elif message['role'] == 'assistant' %}{{ 'Falcon: ' }}{% endif %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '\n' + 'Falcon:' }}{% endif %}", + "name": "falcon-instruct" + }, + { + "template": "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}", + "name": "solar-instruct" + } +] diff --git a/templates/llama2-chat.gotmpl b/templates/llama2-chat.gotmpl new file mode 100644 index 00000000..a739f690 --- /dev/null +++ b/templates/llama2-chat.gotmpl @@ -0,0 +1,3 @@ +[INST] <<SYS>>{{ .System }}<</SYS>> + +{{ .Prompt }} [/INST] {{ .Response }} \ No newline at end of file diff --git a/templates/llama3-instruct.gotmpl b/templates/llama3-instruct.gotmpl new file mode 100644 index 00000000..36d0218b --- /dev/null +++ b/templates/llama3-instruct.gotmpl @@ -0,0 +1,7 @@ +{{ if .System }}<|start_header_id|>system<|end_header_id|> + +{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|> + +{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|> + +{{ .Response }}<|eot_id|> \ No newline at end of file diff --git a/templates/magicoder.gotmpl b/templates/magicoder.gotmpl new file mode 100644 index 00000000..306972ec --- /dev/null +++ b/templates/magicoder.gotmpl @@ -0,0 +1,7 @@ +{{ if .System }}{{ .System }} + +{{ end }}{{ if .Prompt }}@@ Instruction +{{ .Prompt }} + +{{ end }}@@ Response +{{ .Response }} \ No newline at end of file diff --git a/templates/mistral-instruct.gotmpl b/templates/mistral-instruct.gotmpl new file mode 100644 index 00000000..dcf17285 --- /dev/null +++ b/templates/mistral-instruct.gotmpl @@ -0,0 +1,6 @@ +{{ if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant +{{ .Response }}<|im_end|> \ No newline at end of file diff --git a/templates/openchat.gotmpl b/templates/openchat.gotmpl new file mode 100644 index 00000000..d2ca3868 --- /dev/null +++ b/templates/openchat.gotmpl @@ -0,0 +1 @@ +{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|> \ No newline at end of file diff --git 
a/templates/phi-3.gotmpl b/templates/phi-3.gotmpl new file mode 100644 index 00000000..bf26dcee --- /dev/null +++ b/templates/phi-3.gotmpl @@ -0,0 +1,6 @@ +{{ if .System }}<|system|> +{{ .System }}<|end|> +{{ end }}{{ if .Prompt }}<|user|> +{{ .Prompt }}<|end|> +{{ end }}<|assistant|> +{{ .Response }}<|end|> \ No newline at end of file diff --git a/templates/solar-instruct.gotmpl b/templates/solar-instruct.gotmpl new file mode 100644 index 00000000..c275a26a --- /dev/null +++ b/templates/solar-instruct.gotmpl @@ -0,0 +1,8 @@ +{{ if .System }}### System: +{{ .System }} + +{{ end }}{{ if .Prompt }}### User: +{{ .Prompt }} + +{{ end }}### Assistant: +{{ .Response }} \ No newline at end of file diff --git a/templates/starcoder2-instruct.gotmpl b/templates/starcoder2-instruct.gotmpl new file mode 100644 index 00000000..33357e54 --- /dev/null +++ b/templates/starcoder2-instruct.gotmpl @@ -0,0 +1,9 @@ +{{ if .System }}{{ .System }} + +{{ end }}{{ if .Prompt }}### Instruction +{{ .Prompt }} + + +{{ end }}### Response +{{ .Response }}<|endoftext|> + diff --git a/templates/template.go b/templates/template.go new file mode 100644 index 00000000..87962695 --- /dev/null +++ b/templates/template.go @@ -0,0 +1,69 @@ +package templates + +import ( + "bytes" + "embed" + "encoding/json" + "errors" + "io" + "math" + "sync" + + "github.com/agnivade/levenshtein" +) + +//go:embed index.json +var indexBytes []byte + +//go:embed *.gotmpl +var templatesFS embed.FS + +var templatesOnce = sync.OnceValues(func() ([]*Template, error) { + var templates []*Template + if err := json.Unmarshal(indexBytes, &templates); err != nil { + return nil, err + } + + for _, t := range templates { + bts, err := templatesFS.ReadFile(t.Name + ".gotmpl") + if err != nil { + return nil, err + } + + t.Bytes = bts + } + + return templates, nil +}) + +type Template struct { + Name string `json:"name"` + Template string `json:"template"` + Bytes []byte +} + +func (t Template) Reader() io.Reader { + return bytes.NewReader(t.Bytes) +} + +func NamedTemplate(s string) (*Template, error) { + templates, err := templatesOnce() + if err != nil { + return nil, err + } + + var template *Template + score := math.MaxInt + for _, t := range templates { + if s := levenshtein.ComputeDistance(s, t.Template); s < score { + score = s + template = t + } + } + + if score < 100 { + return template, nil + } + + return nil, errors.New("no matching template found") +} diff --git a/templates/template_test.go b/templates/template_test.go new file mode 100644 index 00000000..61bc7837 --- /dev/null +++ b/templates/template_test.go @@ -0,0 +1,59 @@ +package templates + +import ( + "bufio" + "bytes" + "encoding/json" + "io" + "os" + "path/filepath" + "testing" + "text/template" + + "github.com/ollama/ollama/llm" +) + +func TestKVChatTemplate(t *testing.T) { + f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + var ss map[string]string + if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil { + t.Fatal(err) + } + + for k, v := range ss { + t.Run(k, func(t *testing.T) { + kv := llm.KV{"tokenizer.chat_template": v} + s := kv.ChatTemplate() + r, err := NamedTemplate(s) + if err != nil { + t.Fatal(err) + } + + if r.Name != k { + t.Errorf("expected %q, got %q", k, r.Name) + } + + var b bytes.Buffer + if _, err := io.Copy(&b, r.Reader()); err != nil { + t.Fatal(err) + } + + tmpl, err := template.New(s).Parse(b.String()) + if err != nil { + 
t.Fatal(err) + } + + if tmpl.Tree.Root.String() == "" { + t.Errorf("empty %s template", k) + } + }) + } + } +} diff --git a/templates/testdata/templates.jsonl b/templates/testdata/templates.jsonl new file mode 100644 index 00000000..41a8d0ef --- /dev/null +++ b/templates/testdata/templates.jsonl @@ -0,0 +1,35 @@ +{"chatml": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"} +{"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"zephyr": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"} +{"chatml": "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}"} +{"openchat": "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}"} +{"chatml": "{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"chatml": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"chatml": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"chatml": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"zephyr": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and 
add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"} +{"mistral-instruct": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"} +{"starcoder2-instruct": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'### Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response\n'}}"} +{"llama2-chat": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}"} +{"codellama-70b-instruct": "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}"} +{"mistral-instruct": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"} +{"chatml": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'system' %}\n{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|im_start|>assistant' }}\n{% endif %}\n{% endfor %}"} +{"chatml": "{% if not 
add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} +{"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif 'system' not in messages[0]['role'] %}{% set loop_messages = messages %}{% set system_message = 'You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks \u2014 remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER\\'S QUERY.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% if system_message != false %}{{ '<|im_start|>system\n' + system_message | trim + '<|im_end|>\n'}}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% else %}{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% endif %}{% if (add_generation_prompt == true and loop.last) %}{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}{% endif %}{% endfor %}"} +{"alpaca": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"} +{"chatqa": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"} +{"gemma-instruct": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}"} +{"llama3-instruct": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"} +{"granite-instruct": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'Question:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'system' %}\n{{ 'System:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Answer:\n' + message['content'] + '\n\n' }}{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Answer:\n' }}{% endif %}{% endfor %}"} +{"magicoder": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'@@ Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'@@ Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'@@ Response\n'}}"} +{"alfred": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '' + message['content'].strip() + '' }}{% elif message['role'] == 'system' %}{{ '' + message['content'].strip() + '' }}{% elif message['role'] == 'assistant' %}{{ '' + message['content'] + '' }}{% else %}{{ raise_exception('Only system, user and assistant roles are supported.') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '' }}{% endif %}{% endfor %}"} +{"llama2-chat": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] 
%}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"} +{"phi-3": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"} +{"phi-3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"} +{"phi-3": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}"} +{"chatqa": "{{ bos_token }}{%- if messages[0]['role'] == 'system' -%}{% set loop_messages = messages[1:] %}{%- else -%}{% set loop_messages = messages %}{% endif %}System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. 
The assistant should also indicate when the answer cannot be found in the context.\n\n{% for message in loop_messages %}{%- if message['role'] == 'user' -%}User: {{ message['content'].strip() + '\n\n' }}{%- else -%}Assistant: {{ message['content'].strip() + '\n\n' }}{%- endif %}{% if loop.last and message['role'] == 'user' %}Assistant:{% endif %}{% endfor %}"} +{"falcon-instruct": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'User: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'System: ' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'Falcon:\n' + message['content']}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Falcon:' }}\n{% endif %}\n{% endfor %}"} +{"falcon-instruct": "{% for message in messages %}{% if not loop.first %}{{ '\n' }}{% endif %}{% if message['role'] == 'system' %}{{ 'System: ' }}{% elif message['role'] == 'user' %}{{ 'User: ' }}{% elif message['role'] == 'assistant' %}{{ 'Falcon: ' }}{% endif %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '\n' + 'Falcon:' }}{% endif %}"} +{"solar-instruct": "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}"} +{"chatml": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"} diff --git a/templates/vicuna.gotmpl b/templates/vicuna.gotmpl new file mode 100644 index 00000000..174c1a35 --- /dev/null +++ b/templates/vicuna.gotmpl @@ -0,0 +1,3 @@ +{{ if .System }}{{ .System }} +{{ end }}{{ if .Prompt }}USER: {{ .Prompt }} +{{ end }}ASSISTANT: {{ .Response }} \ No newline at end of file diff --git a/templates/zephyr.gotmpl b/templates/zephyr.gotmpl new file mode 100644 index 00000000..aac0c7a1 --- /dev/null +++ b/templates/zephyr.gotmpl @@ -0,0 +1,6 @@ +{{ if .System }}<|system|> +{{ .System }} +{{ end }}{{ if .Prompt }}<|user|> +{{ .Prompt }} +{{ end }}<|assistant|> +{{ .Response }} \ No newline at end of file From ce0dc33cb809405fda18a8077da4058d1f7a5374 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Thu, 6 Jun 2024 23:14:33 -0700 Subject: [PATCH 26/26] llm: patch to fix qwen 2 temporarily on nvidia (#4897) --- llm/patches/06-qwen2.diff | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 llm/patches/06-qwen2.diff diff --git a/llm/patches/06-qwen2.diff b/llm/patches/06-qwen2.diff new file mode 100644 index 00000000..d7b0c155 --- /dev/null +++ b/llm/patches/06-qwen2.diff @@ -0,0 +1,13 @@ +diff --git a/llama.cpp b/llama.cpp +index 40d2ec2c..f34eb79a 100644 +--- a/llama.cpp ++++ b/llama.cpp +@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv( + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + +- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { ++ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + // for 
this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
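
Editor's note on patch 25/26: it introduces two small APIs that work together. `(llm.KV).ChatTemplate` reads the `tokenizer.chat_template` key out of GGUF metadata, and `templates.NamedTemplate` matches that Jinja source against the embedded index by Levenshtein distance. Below is a minimal sketch of how the two compose, assuming a build of the `ollama` module where both packages are importable; it is not part of the patch series itself, and the template literal is simply one of the chatml entries from `templates/index.json` above.

```go
package main

import (
	"fmt"
	"io"

	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/templates"
)

func main() {
	// Stand-in for metadata parsed from a GGUF file: the raw Jinja chat
	// template as it appears under the "tokenizer.chat_template" key.
	kv := llm.KV{
		"tokenizer.chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
	}

	// NamedTemplate returns the embedded template whose Jinja source is
	// closest by edit distance, or an error when nothing is close enough.
	t, err := templates.NamedTemplate(kv.ChatTemplate())
	if err != nil {
		fmt.Println("no matching template:", err)
		return
	}

	// t.Reader() yields the Go-template equivalent (here, chatml.gotmpl),
	// which CreateModel stores as an application/vnd.ollama.image.template
	// layer with an "using autodetected template" status message.
	body, err := io.ReadAll(t.Reader())
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("detected %q template:\n%s\n", t.Name, body)
}
```

For this input the edit distance to the stored chatml entry is zero, comfortably under the `score < 100` acceptance threshold in `NamedTemplate`, so the lookup resolves to the chatml template.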