From ebc529cbb3f0b27f6c154fa90e724db8243a7614 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 5 Jul 2024 17:31:23 -0700
Subject: [PATCH 1/5] autodetect stop parameters from template

---
 server/model.go                      | 21 ++++++++++++++++++---
 server/routes_create_test.go         |  3 ++-
 template/alfred.json                 |  8 ++++++++
 template/alpaca.json                 |  6 ++++++
 template/chatml.json                 |  6 ++++++
 template/chatqa.json                 |  8 ++++++++
 template/codellama-70b-instruct.json |  7 +++++++
 template/falcon-instruct.json        |  6 ++++++
 template/gemma-instruct.json         |  6 ++++++
 template/granite-instruct.json       |  7 +++++++
 template/llama2-chat.json            |  8 ++++++++
 template/llama3-instruct.json        |  7 +++++++
 template/magicoder.json              |  6 ++++++
 template/mistral-instruct.json       |  6 ++++++
 template/openchat.json               |  5 +++++
 template/phi-3.json                  |  8 ++++++++
 template/solar-instruct.json         |  7 +++++++
 template/starcoder2-instruct.json    |  7 +++++++
 template/template.go                 | 14 ++++++++++++++
 template/vicuna.json                 |  6 ++++++
 template/zephyr.json                 |  8 ++++++++
 21 files changed, 156 insertions(+), 4 deletions(-)
 create mode 100644 template/alfred.json
 create mode 100644 template/alpaca.json
 create mode 100644 template/chatml.json
 create mode 100644 template/chatqa.json
 create mode 100644 template/codellama-70b-instruct.json
 create mode 100644 template/falcon-instruct.json
 create mode 100644 template/gemma-instruct.json
 create mode 100644 template/granite-instruct.json
 create mode 100644 template/llama2-chat.json
 create mode 100644 template/llama3-instruct.json
 create mode 100644 template/magicoder.json
 create mode 100644 template/mistral-instruct.json
 create mode 100644 template/openchat.json
 create mode 100644 template/phi-3.json
 create mode 100644 template/solar-instruct.json
 create mode 100644 template/starcoder2-instruct.json
 create mode 100644 template/vicuna.json
 create mode 100644 template/zephyr.json

diff --git a/server/model.go b/server/model.go
index a79f549a..d33ffaec 100644
--- a/server/model.go
+++ b/server/model.go
@@ -4,6 +4,7 @@ import (
 	"archive/zip"
 	"bytes"
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -259,13 +260,27 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
 			if t, err := template.Named(s); err != nil {
 				slog.Debug("template detection", "error", err)
 			} else {
-				tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
+				layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
 				if err != nil {
 					return nil, err
 				}
 
-				tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name)
-				layers = append(layers, &layerGGML{tmpl, nil})
+				layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
+				layers = append(layers, &layerGGML{layer, nil})
+
+				if t.Parameters != nil {
+					var b bytes.Buffer
+					if err := json.NewEncoder(&b).Encode(t.Parameters); err != nil {
+						return nil, err
+					}
+
+					layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
+					if err != nil {
+						return nil, err
+					}
+
+					layers = append(layers, &layerGGML{layer, nil})
+				}
 			}
 		}
 	}

diff --git a/server/routes_create_test.go b/server/routes_create_test.go
index 04174b92..84672087 100644
--- a/server/routes_create_test.go
+++ b/server/routes_create_test.go
@@ -545,9 +545,10 @@ func TestCreateDetectTemplate(t *testing.T) {
 	}
 
 	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+		filepath.Join(p, "blobs", "sha256-0d79f567714c62c048378f2107fb332dabee0135d080c302d884317da9433cc5"),
 		filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
"sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"), - filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"), + filepath.Join(p, "blobs", "sha256-ea34c57ba5b78b740aafe2aeb74dc6507fc3ad14170b64c26a04fb9e36c88d75"), }) }) diff --git a/template/alfred.json b/template/alfred.json new file mode 100644 index 00000000..edac21af --- /dev/null +++ b/template/alfred.json @@ -0,0 +1,8 @@ +{ + "stop": [ + "", + "", + "", + "" + ] +} diff --git a/template/alpaca.json b/template/alpaca.json new file mode 100644 index 00000000..eafe2b8a --- /dev/null +++ b/template/alpaca.json @@ -0,0 +1,6 @@ +{ + "stop": [ + "### Instruction:", + "### Response" + ] +} diff --git a/template/chatml.json b/template/chatml.json new file mode 100644 index 00000000..7afeb3de --- /dev/null +++ b/template/chatml.json @@ -0,0 +1,6 @@ +{ + "stop": [ + "<|im_start|>", + "<|im_end|>" + ] +} diff --git a/template/chatqa.json b/template/chatqa.json new file mode 100644 index 00000000..64dd0f33 --- /dev/null +++ b/template/chatqa.json @@ -0,0 +1,8 @@ +{ + "stop": [ + "System:", + "User:", + "Assistant:", + "<|begin_of_text|>" + ] +} diff --git a/template/codellama-70b-instruct.json b/template/codellama-70b-instruct.json new file mode 100644 index 00000000..a56a63f1 --- /dev/null +++ b/template/codellama-70b-instruct.json @@ -0,0 +1,7 @@ +{ + "stop": [ + "Source:", + "Destination:", + "" + ] +} diff --git a/template/falcon-instruct.json b/template/falcon-instruct.json new file mode 100644 index 00000000..a0da0e81 --- /dev/null +++ b/template/falcon-instruct.json @@ -0,0 +1,6 @@ +{ + "stop": [ + "User:", + "Assistant:" + ] +} diff --git a/template/gemma-instruct.json b/template/gemma-instruct.json new file mode 100644 index 00000000..f4ad415c --- /dev/null +++ b/template/gemma-instruct.json @@ -0,0 +1,6 @@ +{ + "stop": [ + "", + "" + ] +} diff --git a/template/granite-instruct.json b/template/granite-instruct.json new file mode 100644 index 00000000..0933e4b5 --- /dev/null +++ b/template/granite-instruct.json @@ -0,0 +1,7 @@ +{ + "stop": [ + "System:", + "Question:", + "Answer:" + ] +} diff --git a/template/llama2-chat.json b/template/llama2-chat.json new file mode 100644 index 00000000..17590ab4 --- /dev/null +++ b/template/llama2-chat.json @@ -0,0 +1,8 @@ +{ + "stop": [ + "[INST]", + "[/INST]", + "<>", + "<>" + ] +} diff --git a/template/llama3-instruct.json b/template/llama3-instruct.json new file mode 100644 index 00000000..c4e9d448 --- /dev/null +++ b/template/llama3-instruct.json @@ -0,0 +1,7 @@ +{ + "stop": [ + "<|start_header_id|>", + "<|end_header_id|>", + "<|eot_id|>" + ] +} diff --git a/template/magicoder.json b/template/magicoder.json new file mode 100644 index 00000000..6f67cab0 --- /dev/null +++ b/template/magicoder.json @@ -0,0 +1,6 @@ +{ + "stop": [ + "@@ Instruction", + "@@ Response" + ] +} diff --git a/template/mistral-instruct.json b/template/mistral-instruct.json new file mode 100644 index 00000000..7afeb3de --- /dev/null +++ b/template/mistral-instruct.json @@ -0,0 +1,6 @@ +{ + "stop": [ + "<|im_start|>", + "<|im_end|>" + ] +} diff --git a/template/openchat.json b/template/openchat.json new file mode 100644 index 00000000..0edc341f --- /dev/null +++ b/template/openchat.json @@ -0,0 +1,5 @@ +{ + "stop": [ + "<|end_of_turn|>" + ] +} diff --git a/template/phi-3.json b/template/phi-3.json new file mode 100644 index 00000000..27bf7664 --- /dev/null +++ b/template/phi-3.json @@ -0,0 +1,8 @@ +{ + "stop": [ + "<|end|>", + "<|system|>", + "<|user|>", + 
"<|assistant|>" + ] +} diff --git a/template/solar-instruct.json b/template/solar-instruct.json new file mode 100644 index 00000000..7b7a9050 --- /dev/null +++ b/template/solar-instruct.json @@ -0,0 +1,7 @@ +{ + "stop": [ + "### System:", + "### User:", + "### Assistant" + ] +} diff --git a/template/starcoder2-instruct.json b/template/starcoder2-instruct.json new file mode 100644 index 00000000..31348908 --- /dev/null +++ b/template/starcoder2-instruct.json @@ -0,0 +1,7 @@ +{ + "stop": [ + "### Instruction", + "### Response", + "<|endoftext|>" + ] +} diff --git a/template/template.go b/template/template.go index 9b351666..9bb6a399 100644 --- a/template/template.go +++ b/template/template.go @@ -23,6 +23,7 @@ import ( var indexBytes []byte //go:embed *.gotmpl +//go:embed *.json var templatesFS embed.FS var templatesOnce = sync.OnceValues(func() ([]*named, error) { @@ -39,6 +40,15 @@ var templatesOnce = sync.OnceValues(func() ([]*named, error) { // normalize line endings t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n")) + + params, err := templatesFS.ReadFile(t.Name + ".json") + if err != nil { + continue + } + + if err := json.Unmarshal(params, &t.Parameters); err != nil { + return nil, err + } } return templates, nil @@ -48,6 +58,10 @@ type named struct { Name string `json:"name"` Template string `json:"template"` Bytes []byte + + Parameters *struct { + Stop []string `json:"stop"` + } } func (t named) Reader() io.Reader { diff --git a/template/vicuna.json b/template/vicuna.json new file mode 100644 index 00000000..ed7bfb0f --- /dev/null +++ b/template/vicuna.json @@ -0,0 +1,6 @@ +{ + "stop": [ + "USER:", + "ASSISTANT:" + ] +} diff --git a/template/zephyr.json b/template/zephyr.json new file mode 100644 index 00000000..f9c0115c --- /dev/null +++ b/template/zephyr.json @@ -0,0 +1,8 @@ +{ + "stop": [ + "<|system|>", + "", + "<|user|>", + "<|assistant|>" + ] +} From e12fff8810e37bfabe4416f7f41902387ff3aae1 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 15 Jul 2024 09:25:56 -0700 Subject: [PATCH 2/5] Enable windows error dialog for subprocess startup Make sure if something goes wrong spawning the process, the user gets enough info to be able to try to self correct, or at least file a bug with details so we can fix it. Once the process starts, we immediately change back to the recommended setting to prevent the blocking dialog. This ensures if the model fails to load (OOM, unsupported model type, etc.) the process will exit quickly and we can scan the stdout/stderr of the subprocess for the reason to report via API. --- llm/ext_server/server.cpp | 4 ++++ llm/llm_darwin_amd64.go | 3 +++ llm/llm_darwin_arm64.go | 3 +++ llm/llm_linux.go | 7 ++++++- llm/llm_windows.go | 16 +++++++++++++++- llm/server.go | 1 + 6 files changed, 32 insertions(+), 2 deletions(-) diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index e8a076c4..14d921c0 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -41,6 +41,7 @@ #if defined(_WIN32) #include +#include #endif #include @@ -2737,6 +2738,9 @@ int wmain(int argc, wchar_t **wargv) { for (int i = 0; i < argc; ++i) { argv[i] = wchar_to_char(wargv[i]); } + + // Adjust error mode to avoid error dialog after we start. 
+  SetErrorMode(SEM_FAILCRITICALERRORS);
 #else
 int main(int argc, char **argv) {
 #endif

diff --git a/llm/llm_darwin_amd64.go b/llm/llm_darwin_amd64.go
index 3093e1ad..60eed719 100644
--- a/llm/llm_darwin_amd64.go
+++ b/llm/llm_darwin_amd64.go
@@ -2,7 +2,10 @@ package llm
 
 import (
 	"embed"
+	"syscall"
 )
 
 //go:embed build/darwin/x86_64/*/bin/*
 var libEmbed embed.FS
+
+var LlamaServerSysProcAttr = &syscall.SysProcAttr{}

diff --git a/llm/llm_darwin_arm64.go b/llm/llm_darwin_arm64.go
index 928f0b82..20ce8552 100644
--- a/llm/llm_darwin_arm64.go
+++ b/llm/llm_darwin_arm64.go
@@ -2,7 +2,10 @@ package llm
 
 import (
 	"embed"
+	"syscall"
 )
 
 //go:embed build/darwin/arm64/*/bin/*
 var libEmbed embed.FS
+
+var LlamaServerSysProcAttr = &syscall.SysProcAttr{}

diff --git a/llm/llm_linux.go b/llm/llm_linux.go
index c2c5c4cb..928b4e79 100644
--- a/llm/llm_linux.go
+++ b/llm/llm_linux.go
@@ -1,6 +1,11 @@
 package llm
 
-import "embed"
+import (
+	"embed"
+	"syscall"
+)
 
 //go:embed build/linux/*/*/bin/*
 var libEmbed embed.FS
+
+var LlamaServerSysProcAttr = &syscall.SysProcAttr{}

diff --git a/llm/llm_windows.go b/llm/llm_windows.go
index e44f4b95..763cccf9 100644
--- a/llm/llm_windows.go
+++ b/llm/llm_windows.go
@@ -1,6 +1,20 @@
 package llm
 
-import "embed"
+import (
+	"embed"
+	"syscall"
+)
 
 // unused on windows
 var libEmbed embed.FS
+
+const CREATE_DEFAULT_ERROR_MODE = 0x04000000
+
+var LlamaServerSysProcAttr = &syscall.SysProcAttr{
+	// Wire up the default error handling logic. If for some reason a DLL is
+	// missing in the path, this will pop up a GUI dialog explaining the fault
+	// so the user can either fix their PATH or report a bug. Without this
+	// setting, the process exits immediately with a generic exit status but no
+	// way to (easily) figure out what the actual missing DLL was.
+	CreationFlags: CREATE_DEFAULT_ERROR_MODE,
+}

diff --git a/llm/server.go b/llm/server.go
index 08463ef0..55732773 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -346,6 +346,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	s.cmd.Env = os.Environ()
 	s.cmd.Stdout = os.Stdout
 	s.cmd.Stderr = s.status
+	s.cmd.SysProcAttr = LlamaServerSysProcAttr
 
 	envWorkarounds := [][2]string{}
 	for _, gpu := range gpus {

From a622c47bd32e4c7d8d6cd12ba8c7556fcc492524 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 26 Jul 2024 14:10:18 -0700
Subject: [PATCH 3/5] fix nil deref in auth.go

---
 server/auth.go   | 2 +-
 server/upload.go | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/server/auth.go b/server/auth.go
index e92a5b65..dcef5bf9 100644
--- a/server/auth.go
+++ b/server/auth.go
@@ -67,7 +67,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
 
 	headers.Add("Authorization", signature)
 
-	response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
+	response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, &registryOptions{})
 	if err != nil {
 		return "", err
 	}

diff --git a/server/upload.go b/server/upload.go
index 73ce78ce..c4078c22 100644
--- a/server/upload.go
+++ b/server/upload.go
@@ -254,7 +254,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 
 	// retry uploading to the redirect URL
 	for try := range maxRetries {
-		err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
+		err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, &registryOptions{})
 		switch {
 		case errors.Is(err, context.Canceled):
 			return err

From 750c1c55f7ea65219e4e24d6107a4a3ad519b53f Mon Sep 17 00:00:00 2001
From: Blake Mizerany
Date: Fri, 26 Jul 2024 14:24:24 -0700
Subject: [PATCH 4/5] server: fix race conditions during download (#5994)

This fixes various data races scattered throughout the download/pull
client where the client was accessing the download state concurrently.

This commit is mostly a hot-fix and will be replaced by a new client
one day soon.

Also, remove the unnecessary opts argument from downloadChunk.
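
For reference, the synchronization shape these changes move to, as a
minimal, self-contained Go sketch (illustrative only; the types and
names here are simplified stand-ins, not the actual client code):
completion is published by closing a channel exactly once, the error is
written before the close so the close/receive happens-before edge makes
it safe to read afterwards, and shared progress counters only move
through sync/atomic:

    package main

    import (
        "context"
        "errors"
        "fmt"
        "sync/atomic"
    )

    type download struct {
        completed atomic.Int64  // progress; safe for concurrent Add/Load
        done      chan struct{} // closed exactly once when Run finishes
        err       error         // written before close(done), read after <-done
    }

    func (d *download) Run() {
        defer close(d.done) // publishes d.err to every waiter
        d.completed.Add(100)
        d.err = errors.New("example failure")
    }

    func (d *download) Wait(ctx context.Context) error {
        select {
        case <-d.done: // no polling of an unsynchronized done bool
            return d.err
        case <-ctx.Done():
            return ctx.Err()
        }
    }

    func main() {
        d := &download{done: make(chan struct{})}
        go d.Run()
        fmt.Println(d.Wait(context.Background())) // example failure
        fmt.Println(d.completed.Load())           // 100
    }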
---
 server/download.go | 59 ++++++++++++++++++++++++++++------------------
 1 file changed, 36 insertions(+), 23 deletions(-)

diff --git a/server/download.go b/server/download.go
index 8b5b577f..45483ba6 100644
--- a/server/download.go
+++ b/server/download.go
@@ -44,17 +44,19 @@ type blobDownload struct {
 
 	context.CancelFunc
 
-	done bool
+	done chan struct{}
 	err  error
 
 	references atomic.Int32
 }
 
 type blobDownloadPart struct {
-	N           int
-	Offset      int64
-	Size        int64
-	Completed   int64
-	lastUpdated time.Time
+	N         int
+	Offset    int64
+	Size      int64
+	Completed atomic.Int64
+
+	lastUpdatedMu sync.Mutex
+	lastUpdated   time.Time
 
 	*blobDownload `json:"-"`
 }
@@ -72,7 +74,7 @@ func (p *blobDownloadPart) Name() string {
 }
 
 func (p *blobDownloadPart) StartsAt() int64 {
-	return p.Offset + p.Completed
+	return p.Offset + p.Completed.Load()
 }
 
 func (p *blobDownloadPart) StopsAt() int64 {
@@ -82,7 +84,9 @@ func (p *blobDownloadPart) StopsAt() int64 {
 func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
 	n = len(b)
 	p.blobDownload.Completed.Add(int64(n))
+	p.lastUpdatedMu.Lock()
 	p.lastUpdated = time.Now()
+	p.lastUpdatedMu.Unlock()
 	return n, nil
 }
 
@@ -92,6 +96,8 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
 		return err
 	}
 
+	b.done = make(chan struct{})
+
 	for _, partFilePath := range partFilePaths {
 		part, err := b.readPart(partFilePath)
 		if err != nil {
@@ -99,7 +105,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
 		}
 
 		b.Total += part.Size
-		b.Completed.Add(part.Completed)
+		b.Completed.Add(part.Completed.Load())
 		b.Parts = append(b.Parts, part)
 	}
@@ -139,6 +145,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
 }
 
 func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
+	defer close(b.done)
 	b.err = b.run(ctx, requestURL, opts)
 }
@@ -230,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	g.SetLimit(numDownloadParts)
 	for i := range b.Parts {
 		part := b.Parts[i]
-		if part.Completed == part.Size {
+		if part.Completed.Load() == part.Size {
 			continue
 		}
 
@@ -238,7 +245,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 			var err error
 			for try := 0; try < maxRetries; try++ {
 				w := io.NewOffsetWriter(file, part.StartsAt())
-				err = b.downloadChunk(inner, directURL, w, part, opts)
+				err = b.downloadChunk(inner, directURL, w, part)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					// return immediately if the context is canceled or the device is out of space
@@ -279,29 +286,31 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 		return err
 	}
 
-	b.done = true
 	return nil
 }
 
-func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
+func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
 	g, ctx := errgroup.WithContext(ctx)
 	g.Go(func() error {
-		headers := make(http.Header)
-		headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
-		resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
+		if err != nil {
+			return err
+		}
+		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
+		resp, err := http.DefaultClient.Do(req)
 		if err != nil {
 			return err
 		}
 		defer resp.Body.Close()
 
-		n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed)
+		n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
 		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
 			// rollback progress
 			b.Completed.Add(-n)
 			return err
 		}
 
-		part.Completed += n
+		part.Completed.Add(n)
 		if err := b.writePart(part.Name(), part); err != nil {
 			return err
 		}
@@ -315,15 +324,21 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
 	for {
 		select {
 		case <-ticker.C:
-			if part.Completed >= part.Size {
+			if part.Completed.Load() >= part.Size {
 				return nil
 			}
 
-			if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
+			part.lastUpdatedMu.Lock()
+			lastUpdated := part.lastUpdated
+			part.lastUpdatedMu.Unlock()
+
+			if !lastUpdated.IsZero() && time.Since(lastUpdated) > 5*time.Second {
 				const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
 				slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
 				// reset last updated
+				part.lastUpdatedMu.Lock()
 				part.lastUpdated = time.Time{}
+				part.lastUpdatedMu.Unlock()
 				return errPartStalled
 			}
 		case <-ctx.Done():
@@ -388,6 +403,8 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
 	ticker := time.NewTicker(60 * time.Millisecond)
 	for {
 		select {
+		case <-b.done:
+			return b.err
 		case <-ticker.C:
 			fn(api.ProgressResponse{
 				Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
 				Digest:    b.Digest,
 				Total:     b.Total,
 				Completed: b.Completed.Load(),
 			})
-
-			if b.done || b.err != nil {
-				return b.err
-			}
 		case <-ctx.Done():
 			return ctx.Err()
 		}
 	}

From f2a96c7d778249a7f911471b6a1532339e42fcf5 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan
Date: Fri, 26 Jul 2024 18:20:52 -0400
Subject: [PATCH 5/5] llm: keep patch for llama 3 rope factors (#5987)

---
 llm/patches/10-llama3-rope.diff | 70 +++++++++++++++++++++++++++++++++
 1 file changed, 70 insertions(+)
 create mode 100644 llm/patches/10-llama3-rope.diff

diff --git a/llm/patches/10-llama3-rope.diff b/llm/patches/10-llama3-rope.diff
new file mode 100644
index 00000000..39f38fea
--- /dev/null
+++ b/llm/patches/10-llama3-rope.diff
@@ -0,0 +1,70 @@
+From 2f872f294fb6f5c6e8f983b68c40ea656053dd92 Mon Sep 17 00:00:00 2001
+From: Michael Yang
+Date: Tue, 23 Jul 2024 14:33:29 -0700
+Subject: [PATCH] llama 3.1 rope scaling
+
+---
+ src/llama.cpp | 14 ++++++++++++--
+ 1 file changed, 12 insertions(+), 2 deletions(-)
+
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 8fe51971..a9969df8 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -2472,6 +2472,7 @@ struct llama_layer {
+     // long rope factors
+     struct ggml_tensor * rope_long  = nullptr;
+     struct ggml_tensor * rope_short = nullptr;
++    struct ggml_tensor * rope_freqs = nullptr;
+ 
+     // bitnet scale
+     struct ggml_tensor * wq_scale;
+@@ -6143,6 +6144,8 @@ static bool llm_load_tensors(
+ 
+         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+ 
++        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), { n_embd/n_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
++
+         if (n_expert == 0) {
+             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+             layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
+@@ -8620,6 +8623,10 @@ struct llm_build_context {
+         // choose long/short freq factors based on the context size
+         const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
+ 
++        if (model.layers[il].rope_freqs != nullptr) {
++            return model.layers[il].rope_freqs;
++        }
++
+         if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
+             return model.layers[il].rope_long;
+         }
+@@ -8814,6 +8821,9 @@ struct llm_build_context {
+ 
+         // self-attention
+         {
++            // rope freq factors for llama3; may return nullptr for llama2 and other models
++            struct ggml_tensor * rope_factors = build_rope_factors(il);
++
+             // compute Q and K and RoPE them
+             struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+             cb(Qcur, "Qcur", il);
+@@ -8837,14 +8847,14 @@ struct llm_build_context {
+         }
+ 
+         Qcur = ggml_rope_ext(
+-            ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
++            ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+             ext_factor, attn_factor, beta_fast, beta_slow
+         );
+         cb(Qcur, "Qcur", il);
+ 
+         Kcur = ggml_rope_ext(
+-            ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
++            ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+             ext_factor, attn_factor, beta_fast, beta_slow
+         );
-- 
2.45.2
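
Series note, not part of any commit: a minimal sketch of how a caller
might consume the stop parameters autodetected by PATCH 1/5. The module
path and the matching behavior of template.Named are assumptions read
off the diffs above, not a verified API:

    package main

    import (
        "fmt"
        "log"

        "github.com/ollama/ollama/template" // assumed import path
    )

    func main() {
        // Named matches a model's raw chat template against the embedded,
        // known set; signature assumed from its use in server/model.go.
        t, err := template.Named("{{ .System }} USER: {{ .Prompt }} ASSISTANT:")
        if err != nil {
            log.Fatal(err)
        }

        // Parameters stays nil when the template ships no <name>.json sidecar.
        if t.Parameters != nil {
            fmt.Println("autodetected stop tokens:", t.Parameters.Stop)
        }
    }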