mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-24 07:28:27 +00:00
Merge branch 'ollama:main' into main
This commit is contained in:
4
llm/ext_server/server.cpp
vendored
4
llm/ext_server/server.cpp
vendored
@@ -41,6 +41,7 @@
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <windows.h>
|
||||
#include <errhandlingapi.h>
|
||||
#endif
|
||||
|
||||
#include <cstddef>
|
||||
@@ -2737,6 +2738,9 @@ int wmain(int argc, wchar_t **wargv) {
|
||||
for (int i = 0; i < argc; ++i) {
|
||||
argv[i] = wchar_to_char(wargv[i]);
|
||||
}
|
||||
|
||||
// Adjust error mode to avoid error dialog after we start.
|
||||
SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
#else
|
||||
int main(int argc, char **argv) {
|
||||
#endif
|
||||
|
||||
@@ -2,7 +2,10 @@ package llm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
//go:embed build/darwin/x86_64/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
|
||||
var LlamaServerSysProcAttr = &syscall.SysProcAttr{}
|
||||
|
||||
@@ -2,7 +2,10 @@ package llm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
//go:embed build/darwin/arm64/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
|
||||
var LlamaServerSysProcAttr = &syscall.SysProcAttr{}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package llm
|
||||
|
||||
import "embed"
|
||||
import (
|
||||
"embed"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
//go:embed build/linux/*/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
|
||||
var LlamaServerSysProcAttr = &syscall.SysProcAttr{}
|
||||
|
||||
@@ -1,6 +1,20 @@
|
||||
package llm
|
||||
|
||||
import "embed"
|
||||
import (
|
||||
"embed"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// unused on windows
|
||||
var libEmbed embed.FS
|
||||
|
||||
const CREATE_DEFAULT_ERROR_MODE = 0x04000000
|
||||
|
||||
var LlamaServerSysProcAttr = &syscall.SysProcAttr{
|
||||
// Wire up the default error handling logic If for some reason a DLL is
|
||||
// missing in the path this will pop up a GUI Dialog explaining the fault so
|
||||
// the user can either fix their PATH, or report a bug. Without this
|
||||
// setting, the process exits immediately with a generic exit status but no
|
||||
// way to (easily) figure out what the actual missing DLL was.
|
||||
CreationFlags: CREATE_DEFAULT_ERROR_MODE,
|
||||
}
|
||||
|
||||
70
llm/patches/10-llama3-rope.diff
Normal file
70
llm/patches/10-llama3-rope.diff
Normal file
@@ -0,0 +1,70 @@
|
||||
From 2f872f294fb6f5c6e8f983b68c40ea656053dd92 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Yang <mxyng@pm.me>
|
||||
Date: Tue, 23 Jul 2024 14:33:29 -0700
|
||||
Subject: [PATCH] llama 3.1 rope scaling
|
||||
|
||||
---
|
||||
src/llama.cpp | 14 ++++++++++++--
|
||||
1 file changed, 12 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||
index 8fe51971..a9969df8 100644
|
||||
--- a/src/llama.cpp
|
||||
+++ b/src/llama.cpp
|
||||
@@ -2472,6 +2472,7 @@ struct llama_layer {
|
||||
// long rope factors
|
||||
struct ggml_tensor * rope_long = nullptr;
|
||||
struct ggml_tensor * rope_short = nullptr;
|
||||
+ struct ggml_tensor * rope_freqs = nullptr;
|
||||
|
||||
// bitnet scale
|
||||
struct ggml_tensor * wq_scale;
|
||||
@@ -6143,6 +6144,8 @@ static bool llm_load_tensors(
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
+ layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), { n_embd/n_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
+
|
||||
if (n_expert == 0) {
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||
@@ -8620,6 +8623,10 @@ struct llm_build_context {
|
||||
// choose long/short freq factors based on the context size
|
||||
const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
|
||||
+ if (model.layers[il].rope_freqs != nullptr) {
|
||||
+ return model.layers[il].rope_freqs;
|
||||
+ }
|
||||
+
|
||||
if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
|
||||
return model.layers[il].rope_long;
|
||||
}
|
||||
@@ -8814,6 +8821,9 @@ struct llm_build_context {
|
||||
|
||||
// self-attention
|
||||
{
|
||||
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||
+ struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||
+
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
@@ -8837,14 +8847,14 @@ struct llm_build_context {
|
||||
}
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
- ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
- ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
--
|
||||
2.45.2
|
||||
@@ -346,6 +346,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
s.cmd.Env = os.Environ()
|
||||
s.cmd.Stdout = os.Stdout
|
||||
s.cmd.Stderr = s.status
|
||||
s.cmd.SysProcAttr = LlamaServerSysProcAttr
|
||||
|
||||
envWorkarounds := [][2]string{}
|
||||
for _, gpu := range gpus {
|
||||
|
||||
@@ -67,7 +67,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
|
||||
|
||||
headers.Add("Authorization", signature)
|
||||
|
||||
response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
|
||||
response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, ®istryOptions{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -44,17 +44,19 @@ type blobDownload struct {
|
||||
|
||||
context.CancelFunc
|
||||
|
||||
done bool
|
||||
done chan struct{}
|
||||
err error
|
||||
references atomic.Int32
|
||||
}
|
||||
|
||||
type blobDownloadPart struct {
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
Completed int64
|
||||
lastUpdated time.Time
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
Completed atomic.Int64
|
||||
|
||||
lastUpdatedMu sync.Mutex
|
||||
lastUpdated time.Time
|
||||
|
||||
*blobDownload `json:"-"`
|
||||
}
|
||||
@@ -72,7 +74,7 @@ func (p *blobDownloadPart) Name() string {
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) StartsAt() int64 {
|
||||
return p.Offset + p.Completed
|
||||
return p.Offset + p.Completed.Load()
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) StopsAt() int64 {
|
||||
@@ -82,7 +84,9 @@ func (p *blobDownloadPart) StopsAt() int64 {
|
||||
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
p.blobDownload.Completed.Add(int64(n))
|
||||
p.lastUpdatedMu.Lock()
|
||||
p.lastUpdated = time.Now()
|
||||
p.lastUpdatedMu.Unlock()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@@ -92,6 +96,8 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
||||
return err
|
||||
}
|
||||
|
||||
b.done = make(chan struct{})
|
||||
|
||||
for _, partFilePath := range partFilePaths {
|
||||
part, err := b.readPart(partFilePath)
|
||||
if err != nil {
|
||||
@@ -99,7 +105,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
||||
}
|
||||
|
||||
b.Total += part.Size
|
||||
b.Completed.Add(part.Completed)
|
||||
b.Completed.Add(part.Completed.Load())
|
||||
b.Parts = append(b.Parts, part)
|
||||
}
|
||||
|
||||
@@ -139,6 +145,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
||||
}
|
||||
|
||||
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
|
||||
defer close(b.done)
|
||||
b.err = b.run(ctx, requestURL, opts)
|
||||
}
|
||||
|
||||
@@ -230,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
g.SetLimit(numDownloadParts)
|
||||
for i := range b.Parts {
|
||||
part := b.Parts[i]
|
||||
if part.Completed == part.Size {
|
||||
if part.Completed.Load() == part.Size {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -238,7 +245,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
var err error
|
||||
for try := 0; try < maxRetries; try++ {
|
||||
w := io.NewOffsetWriter(file, part.StartsAt())
|
||||
err = b.downloadChunk(inner, directURL, w, part, opts)
|
||||
err = b.downloadChunk(inner, directURL, w, part)
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
||||
// return immediately if the context is canceled or the device is out of space
|
||||
@@ -279,29 +286,31 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
return err
|
||||
}
|
||||
|
||||
b.done = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed)
|
||||
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
}
|
||||
|
||||
part.Completed += n
|
||||
part.Completed.Add(n)
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -315,15 +324,21 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if part.Completed >= part.Size {
|
||||
if part.Completed.Load() >= part.Size {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
|
||||
part.lastUpdatedMu.Lock()
|
||||
lastUpdated := part.lastUpdated
|
||||
part.lastUpdatedMu.Unlock()
|
||||
|
||||
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 5*time.Second {
|
||||
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
|
||||
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
|
||||
// reset last updated
|
||||
part.lastUpdatedMu.Lock()
|
||||
part.lastUpdated = time.Time{}
|
||||
part.lastUpdatedMu.Unlock()
|
||||
return errPartStalled
|
||||
}
|
||||
case <-ctx.Done():
|
||||
@@ -388,6 +403,8 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
ticker := time.NewTicker(60 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-b.done:
|
||||
return b.err
|
||||
case <-ticker.C:
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
||||
@@ -395,10 +412,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
})
|
||||
|
||||
if b.done || b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
@@ -263,13 +263,27 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||
if t, err := template.Named(s); err != nil {
|
||||
slog.Debug("template detection", "error", err)
|
||||
} else {
|
||||
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||
layers = append(layers, &layerGGML{tmpl, nil})
|
||||
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||
layers = append(layers, &layerGGML{layer, nil})
|
||||
|
||||
if t.Parameters != nil {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(t.Parameters); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, nil})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -599,9 +599,10 @@ func TestCreateDetectTemplate(t *testing.T) {
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-0d79f567714c62c048378f2107fb332dabee0135d080c302d884317da9433cc5"),
|
||||
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
|
||||
filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
|
||||
filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
|
||||
filepath.Join(p, "blobs", "sha256-ea34c57ba5b78b740aafe2aeb74dc6507fc3ad14170b64c26a04fb9e36c88d75"),
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -254,7 +254,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||
|
||||
// retry uploading to the redirect URL
|
||||
for try := range maxRetries {
|
||||
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
|
||||
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, ®istryOptions{})
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
return err
|
||||
|
||||
8
template/alfred.json
Normal file
8
template/alfred.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"stop": [
|
||||
"<start_system>",
|
||||
"<end_message>",
|
||||
"<start_user>",
|
||||
"<start_assistant>"
|
||||
]
|
||||
}
|
||||
6
template/alpaca.json
Normal file
6
template/alpaca.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"stop": [
|
||||
"### Instruction:",
|
||||
"### Response"
|
||||
]
|
||||
}
|
||||
6
template/chatml.json
Normal file
6
template/chatml.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"stop": [
|
||||
"<|im_start|>",
|
||||
"<|im_end|>"
|
||||
]
|
||||
}
|
||||
8
template/chatqa.json
Normal file
8
template/chatqa.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"stop": [
|
||||
"System:",
|
||||
"User:",
|
||||
"Assistant:",
|
||||
"<|begin_of_text|>"
|
||||
]
|
||||
}
|
||||
7
template/codellama-70b-instruct.json
Normal file
7
template/codellama-70b-instruct.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"stop": [
|
||||
"Source:",
|
||||
"Destination:",
|
||||
"<step>"
|
||||
]
|
||||
}
|
||||
6
template/falcon-instruct.json
Normal file
6
template/falcon-instruct.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"stop": [
|
||||
"User:",
|
||||
"Assistant:"
|
||||
]
|
||||
}
|
||||
6
template/gemma-instruct.json
Normal file
6
template/gemma-instruct.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"stop": [
|
||||
"<start_of_turn>",
|
||||
"<end_of_turn>"
|
||||
]
|
||||
}
|
||||
7
template/granite-instruct.json
Normal file
7
template/granite-instruct.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"stop": [
|
||||
"System:",
|
||||
"Question:",
|
||||
"Answer:"
|
||||
]
|
||||
}
|
||||
8
template/llama2-chat.json
Normal file
8
template/llama2-chat.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"stop": [
|
||||
"[INST]",
|
||||
"[/INST]",
|
||||
"<<SYS>>",
|
||||
"<</SYS>>"
|
||||
]
|
||||
}
|
||||
7
template/llama3-instruct.json
Normal file
7
template/llama3-instruct.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"stop": [
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"<|eot_id|>"
|
||||
]
|
||||
}
|
||||
6
template/magicoder.json
Normal file
6
template/magicoder.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"stop": [
|
||||
"@@ Instruction",
|
||||
"@@ Response"
|
||||
]
|
||||
}
|
||||
6
template/mistral-instruct.json
Normal file
6
template/mistral-instruct.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"stop": [
|
||||
"<|im_start|>",
|
||||
"<|im_end|>"
|
||||
]
|
||||
}
|
||||
5
template/openchat.json
Normal file
5
template/openchat.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"stop": [
|
||||
"<|end_of_turn|>"
|
||||
]
|
||||
}
|
||||
8
template/phi-3.json
Normal file
8
template/phi-3.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"stop": [
|
||||
"<|end|>",
|
||||
"<|system|>",
|
||||
"<|user|>",
|
||||
"<|assistant|>"
|
||||
]
|
||||
}
|
||||
7
template/solar-instruct.json
Normal file
7
template/solar-instruct.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"stop": [
|
||||
"### System:",
|
||||
"### User:",
|
||||
"### Assistant"
|
||||
]
|
||||
}
|
||||
7
template/starcoder2-instruct.json
Normal file
7
template/starcoder2-instruct.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"stop": [
|
||||
"### Instruction",
|
||||
"### Response",
|
||||
"<|endoftext|>"
|
||||
]
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
var indexBytes []byte
|
||||
|
||||
//go:embed *.gotmpl
|
||||
//go:embed *.json
|
||||
var templatesFS embed.FS
|
||||
|
||||
var templatesOnce = sync.OnceValues(func() ([]*named, error) {
|
||||
@@ -39,6 +40,15 @@ var templatesOnce = sync.OnceValues(func() ([]*named, error) {
|
||||
|
||||
// normalize line endings
|
||||
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
|
||||
|
||||
params, err := templatesFS.ReadFile(t.Name + ".json")
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(params, &t.Parameters); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return templates, nil
|
||||
@@ -48,6 +58,10 @@ type named struct {
|
||||
Name string `json:"name"`
|
||||
Template string `json:"template"`
|
||||
Bytes []byte
|
||||
|
||||
Parameters *struct {
|
||||
Stop []string `json:"stop"`
|
||||
}
|
||||
}
|
||||
|
||||
func (t named) Reader() io.Reader {
|
||||
|
||||
6
template/vicuna.json
Normal file
6
template/vicuna.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"stop": [
|
||||
"USER:",
|
||||
"ASSISTANT:"
|
||||
]
|
||||
}
|
||||
8
template/zephyr.json
Normal file
8
template/zephyr.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"stop": [
|
||||
"<|system|>",
|
||||
"</s>",
|
||||
"<|user|>",
|
||||
"<|assistant|>"
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user