diff --git a/.gitattributes b/.gitattributes index f7192096..f1c8bcb4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ llm/ext_server/* linguist-vendored -* text eol=lf +* text=auto +*.go text eol=lf diff --git a/README.md b/README.md index 3e4455a7..68e56eff 100644 --- a/README.md +++ b/README.md @@ -343,6 +343,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [tlm](https://github.com/yusufcanb/tlm) - [podman-ollama](https://github.com/ericcurtin/podman-ollama) - [gollama](https://github.com/sammcj/gollama) +- [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/) ### Database diff --git a/cmd/cmd.go b/cmd/cmd.go index d47db65b..2356110e 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -22,6 +22,7 @@ import ( "runtime" "slices" "strings" + "sync/atomic" "syscall" "time" @@ -78,6 +79,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { status := "transferring model data" spinner := progress.NewSpinner(status) p.Add(status, spinner) + defer p.Stop() for i := range modelfile.Commands { switch modelfile.Commands[i].Name { @@ -112,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { path = tempfile } - digest, err := createBlob(cmd, client, path) + digest, err := createBlob(cmd, client, path, spinner) if err != nil { return err } @@ -263,13 +265,20 @@ func tempZipFiles(path string) (string, error) { return tempfile.Name(), nil } -func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { +func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) { bin, err := os.Open(path) if err != nil { return "", err } defer bin.Close() + // Get file info to retrieve the size + fileInfo, err := bin.Stat() + if err != nil { + return "", err + } + fileSize := fileInfo.Size() + hash := sha256.New() if _, err := io.Copy(hash, bin); err != nil { return "", err @@ -279,13 +288,43 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er return "", err } + var pw progressWriter + status := "transferring model data 0%" + spinner.SetMessage(status) + + done := make(chan struct{}) + defer close(done) + + go func() { + ticker := time.NewTicker(60 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize))) + case <-done: + spinner.SetMessage("transferring model data 100%") + return + } + } + }() + digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) - if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { + if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { return "", err } return digest, nil } +type progressWriter struct { + n atomic.Int64 +} + +func (w *progressWriter) Write(p []byte) (n int, err error) { + w.n.Add(int64(len(p))) + return len(p), nil +} + func RunHandler(cmd *cobra.Command, args []string) error { interactive := true diff --git a/convert/convert.go b/convert/convert.go index b9461e4f..24c19aa4 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -27,6 +27,10 @@ func (Parameters) KV(t *Tokenizer) llm.KV { "tokenizer.ggml.token_type": t.Vocabulary.Types, } + if len(t.Merges) > 0 { + kv["tokenizer.ggml.merges"] = t.Merges + } + if t.Template != "" { kv["tokenizer.chat_template"] = t.Template } @@ -89,6 +93,8 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error { conv = &mixtral{} case "GemmaForCausalLM": conv = &gemma{} + case 
"Phi3ForCausalLM": + conv = &phi3{} default: return errors.New("unsupported architecture") } diff --git a/convert/convert_llama.go b/convert/convert_llama.go index 0383a85e..178b13f3 100644 --- a/convert/convert_llama.go +++ b/convert/convert_llama.go @@ -90,10 +90,6 @@ func (p *llama) KV(t *Tokenizer) llm.KV { kv["llama.attention.value_length"] = p.HeadDim } - if len(t.Merges) > 0 { - kv["tokenizer.ggml.merges"] = t.Merges - } - return kv } diff --git a/convert/convert_phi3.go b/convert/convert_phi3.go new file mode 100644 index 00000000..0f645217 --- /dev/null +++ b/convert/convert_phi3.go @@ -0,0 +1,125 @@ +package convert + +import ( + "cmp" + "encoding/binary" + "io" + "math" + "strings" + "sync" + + "github.com/ollama/ollama/llm" +) + +type phi3 struct { + Parameters + NumHiddenLayers uint32 `json:"num_hidden_layers"` + NLayers uint32 `json:"n_layers"` + HiddenSize uint32 `json:"hidden_size"` + NEmbd uint32 `json:"n_embd"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NHead uint32 `json:"n_head"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + NHeadKV uint32 `json:"n_head_kv"` + RopeTheta float32 `json:"rope_theta"` + RopeScaling struct { + Type string `json:"type"` + LongFactor ropeFactor `json:"long_factor"` + ShortFactor ropeFactor `json:"short_factor"` + } `json:"rope_scaling"` + RMSNormEPS float32 `json:"rms_norm_eps"` + NPositions uint32 `json:"n_positions"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` + SlidingWindow uint32 `json:"sliding_window"` +} + +var _ Converter = (*phi3)(nil) + +func (p *phi3) KV(t *Tokenizer) llm.KV { + kv := p.Parameters.KV(t) + kv["general.architecture"] = "phi3" + kv["general.name"] = "phi3" + kv["phi3.context_length"] = p.MaxPositionEmbeddings + kv["phi3.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd) + kv["phi3.feed_forward_length"] = p.IntermediateSize + kv["phi3.block_count"] = cmp.Or(p.NumHiddenLayers, p.NLayers) + kv["phi3.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead) + kv["phi3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NHeadKV) + kv["phi3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS + kv["phi3.rope.dimension_count"] = p.HiddenSize / cmp.Or(p.NumAttentionHeads, p.NHead) + kv["phi3.rope.freq_base"] = p.RopeTheta + kv["phi3.rope.scaling.original_context_length"] = p.OriginalMaxPositionEmbeddings + kv["phi3.attention.sliding_window"] = p.SlidingWindow + + scale := float64(p.MaxPositionEmbeddings) / float64(p.OriginalMaxPositionEmbeddings) + + switch p.RopeScaling.Type { + case "": + // no scaling + case "su", "longrope": + kv["phi3.rope.scaling.attn_factor"] = float32(max(math.Sqrt(1+math.Log(scale)/math.Log(float64(p.OriginalMaxPositionEmbeddings))), 1.0)) + case "yarn": + kv["phi3.rope.scaling.attn_factor"] = float32(max(0.1*math.Log(scale)+1.0, 1.0)) + default: + panic("unknown rope scaling type") + } + + return kv +} + +func (p *phi3) Tensors(ts []Tensor) []llm.Tensor { + var addRopeFactors sync.Once + + out := make([]llm.Tensor, 0, len(ts)+2) + for _, t := range ts { + name := p.tensorName(t.Name()) + if strings.HasPrefix(name, "blk.0.") { + addRopeFactors.Do(func() { + out = append(out, llm.Tensor{ + Name: "rope_factors_long.weight", + Kind: 0, + Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))}, + WriterTo: p.RopeScaling.LongFactor, + }, llm.Tensor{ + Name: "rope_factors_short.weight", + Kind: 0, + Shape: 
[]uint64{uint64(len(p.RopeScaling.ShortFactor))}, + WriterTo: p.RopeScaling.ShortFactor, + }) + }) + } + + out = append(out, llm.Tensor{ + Name: name, + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + + return out +} + +func (p *phi3) tensorName(n string) string { + return strings.NewReplacer( + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.norm", "output_norm", + "model.layers", "blk", + "input_layernorm", "attn_norm", + "self_attn.qkv_proj", "attn_qkv", + "self_attn.o_proj", "attn_output", + "mlp.down_proj", "ffn_down", + "mlp.gate_up_proj", "ffn_up", + "post_attention_layernorm", "ffn_norm", + ).Replace(n) +} + +type ropeFactor []float32 + +func (r ropeFactor) WriteTo(w io.Writer) (int64, error) { + err := binary.Write(w, binary.LittleEndian, r) + return 0, err +} diff --git a/convert/convert_test.go b/convert/convert_test.go index 88f38494..cb2c585e 100644 --- a/convert/convert_test.go +++ b/convert/convert_test.go @@ -65,6 +65,8 @@ func TestConvertFull(t *testing.T) { "Mistral-7B-Instruct-v0.2", "Mixtral-8x7B-Instruct-v0.1", "gemma-2b-it", + // microsoft/Phi-3-mini-128-instruct@d548c233192db00165d842bf8edff054bb3212f8 + "Phi-3-mini-128k-instruct", } for i := range cases { diff --git a/convert/testdata/Phi-3-mini-128k-instruct.json b/convert/testdata/Phi-3-mini-128k-instruct.json new file mode 100644 index 00000000..19296f5a --- /dev/null +++ b/convert/testdata/Phi-3-mini-128k-instruct.json @@ -0,0 +1,225 @@ +{ + "general.architecture": "phi3", + "general.file_type": "1", + "general.quantization_version": "2", + "phi3.block_count": "32", + "phi3.context_length": "131072", + "phi3.embedding_length": "3072", + "phi3.feed_forward_length": "8192", + "phi3.rope.scaling.original_context_length": "4096", + "phi3.rope.dimension_count": "96", + "phi3.rope.freq_base": "10000", + "phi3.rope.scaling.attn_factor": "1.1902381", + "phi3.attention.head_count": "32", + "phi3.attention.head_count_kv": "32", + "phi3.attention.layer_norm_rms_epsilon": "1e-05", + "phi3.attention.sliding_window": "262144", + "tokenizer.ggml.model": "llama", + "tokenizer.ggml.pre": "default", + "tokenizer.ggml.add_bos_token": "false", + "tokenizer.ggml.add_eos_token": "false", + "tokenizer.ggml.bos_token_id": "1", + "tokenizer.ggml.eos_token_id": "32000", + "tokenizer.ggml.unknown_token_id": "0", + "tokenizer.ggml.padding_token_id": "32000", + "tokenizer.ggml.scores": "6e37bcde2adc7e350e87c496eddd7a2124329c1dc66c5bf3ad3997253e4f7a62", + "tokenizer.ggml.token_type": "b6ecf55ec64ee67d87750bdb8d757a2c58bf78377e9f4219f5689a6c4dea57ce", + "tokenizer.ggml.tokens": "d168da3ddd3eee820916945fcb9baf24dd3cde42f606cffa2d19e7c8a8743918", + "blk.0.attn_norm.weight": "216aeb2c9e0c271f899e1ef2a63cceeb8f41e97642e84fada54b1d3c1c11cf25", + "blk.0.attn_output.weight": "b597d56f7188ffc1fafc273fadc59d41738cffd677ae98c61a62c3285b3a3099", + "blk.0.attn_qkv.weight": "d28a6b44e13f59be5483e4be2bedb544e346168d720aca27f47d1a5a722be91e", + "blk.0.ffn_down.weight": "4a691370e5a61fcbbf540fbcbf4c0f1d15dec0364528c0e916d0744f6262b63b", + "blk.0.ffn_norm.weight": "0c00af2b4a3128bec64a0cbb1084b042fdbe13d9ad0d03bd577f9449dfead338", + "blk.0.ffn_up.weight": "b32b52f790c1c083bfb8a3126dc1111cfeeb28dc8c584a930a1e5334cb176bf4", + "blk.1.attn_norm.weight": "68748011503c6c029e8e69a84a8e5a89338f378769627b6dbf7f93d715c292e1", + "blk.1.attn_output.weight": "2267344add13b048ca59e4377c86dc512be8046a57156901fa32a20fa74e4ee0", + "blk.1.attn_qkv.weight": "9109d2e3d7a2eacfda5226587b8be124a3bf44b972da7ebb17aa15795897eacc", + 
"blk.1.ffn_down.weight": "d675df4df4dd039c0c339ad6445d39eddd2004db6bf35bed6314c7497245a633", + "blk.1.ffn_norm.weight": "3b5767ae977bc8baaa06b06efdbea193b6b3ba605ce76d77a76ce317e935500c", + "blk.1.ffn_up.weight": "80dfd6d9d234b00334c89b8e0a02f81899c2efd377321c34ba5ba51a5f61b5ff", + "blk.2.attn_norm.weight": "6a6743b057e5088f145bc179e92c9bfb41163e7295d7b81c62e23dd89d2b59c4", + "blk.2.attn_output.weight": "bc5491ea54e0db81462d7d9b7d25cbdda380c2db8de041bd1c4ab7b76a1d19c3", + "blk.2.attn_qkv.weight": "a61287a9852e2f5aca9c100b471d98398b2913a3497c743de3c70ec9ddd7087f", + "blk.2.ffn_down.weight": "4fddcc382c8dceeab027fe43d8d44e67edb5e8ce4b9a1b7f773c87770380ade1", + "blk.2.ffn_norm.weight": "07e05f82b3f63f711db3b684ca79aed25c0657917e66f88af47348a82065c227", + "blk.2.ffn_up.weight": "4835a682ef1826c12df01ae7663fc45f9c82bc8e64b665f13fb7da8e201ec0fb", + "blk.3.attn_norm.weight": "f22aba7c03999ba7136f39cda747a39715e498699dc1716cd97fc5dfc58d1b1c", + "blk.3.attn_output.weight": "53b579855366fd786c5126b2b30aac4d583ca7bda56833c4865f5cadb5c18c6d", + "blk.3.attn_qkv.weight": "bb56aba78158123140fcea59c69ac562ca208f6d3086819417cdad8c50f333ad", + "blk.3.ffn_down.weight": "97280897a7cd86db2830c004bccc5bc094f50e293baded0189159a2019145a6e", + "blk.3.ffn_norm.weight": "10a8c99f8b57a960e8e0a1133c4a26f9148403d1b9bff2eff114917de996f3b5", + "blk.3.ffn_up.weight": "7324046c915e75d621b2043597a245a428d8eea31869135e6257a861491d8dcc", + "blk.4.attn_norm.weight": "507d8e164de94646edbfe33def8e8fbf7c9a6ee3fbaedb5000f72d9f51ec5e36", + "blk.4.attn_output.weight": "bbb3429e6efa98c150e0fdbf48c16180cbf0d0cbc1b3c253c6c319d78f4593a2", + "blk.4.attn_qkv.weight": "b95ee5be0786d3901273d806c339fe6c20e6bfffd2a20672a9f56af80921e8ab", + "blk.4.ffn_down.weight": "806bbf91df92a5a22bd5aa1ffb7fc2869f7293ffc7704771c290ecc583b27975", + "blk.4.ffn_norm.weight": "cfc2930a81df7aee3a5e7f726a15c1182233e868bf0d9d37f6b6ae6d8c15c234", + "blk.4.ffn_up.weight": "c3390c69533de2c8424e8069323ccc5d0c4543111535da04cf2c7d26745576aa", + "blk.5.attn_norm.weight": "0d71c4fbcefabbd021569442853d2fe90668b19409ae2805a718a829ca60beab", + "blk.5.attn_output.weight": "10ebd93629112bf2df5c30dd0953a4a5e9020306768283181ed426934d47e14f", + "blk.5.attn_qkv.weight": "5cb05633369f12d4b00e0ff787736bd846856682115720ebc6cce05270c334f6", + "blk.5.ffn_down.weight": "e28bcc5094212eafc7476dbc5b7a520d25b79578cbf4229d698e2655956a80ad", + "blk.5.ffn_norm.weight": "b6f2c4cf9f34bb4d59989f96165c14a67dc1e266ad0a6d0fcc49f1add929e6ff", + "blk.5.ffn_up.weight": "0f9ef99423cc07ebedc0e9cfa95809f2d7108d910bb4ef97ebc0b0309c440750", + "blk.6.attn_norm.weight": "b3edcc47a42218234f7564d7470611b49401a41ae8cd42123f86557c69f5d7f2", + "blk.6.attn_output.weight": "eb9b7d257b388bb5b8fe0515e5c6873317239cb94cda236e4b6ada2a6c57c65c", + "blk.6.attn_qkv.weight": "eb968081f478c52f07bd9c2761741e982dba33cc4eeadeea3557d391b9ac2106", + "blk.6.ffn_down.weight": "1b8588bb7463206290322695577dcfced300895d6e6f4b26966c53a9ae2f0f84", + "blk.6.ffn_norm.weight": "1219c04b7770983c77814200eefe743f46d15328ea2b12711e44f8103eab08d3", + "blk.6.ffn_up.weight": "197ef287239fec47c55677f0fbb66eaf0644f775bc382de843971730721394f6", + "blk.7.attn_norm.weight": "b630ad08c80d564ed1c024384818e9fd3f22a36cd7a14aa96e7e2759a8285099", + "blk.7.attn_output.weight": "970255aa750828a47d6b9d399f9612b5bf25aefe7dadbcba41fc416d0d4067c1", + "blk.7.attn_qkv.weight": "ebb157c880293e6de8d629f263ba8853ed1dbdc02c311d43432bb8cfbb310739", + "blk.7.ffn_down.weight": "24bcd4db4cba844c89f878b81843c373dbbc0675e889d32c5b12e63384a7b670", + "blk.7.ffn_norm.weight": 
"b9c6f71001808ee873ce7db8056e4b53fb4cccec8b7f0f312899b575fae39d39", + "blk.7.ffn_up.weight": "979f1828d227455c26015a2a11afe9dd05f2bb97a8ba6b38c8dab3f50e627401", + "blk.8.attn_norm.weight": "4e8e347e3775010b7112ee630f2f4f2383be7ff64e6ca6154b9b22566552eaa6", + "blk.8.attn_output.weight": "65a44babf44a435a1829945211b3168f9ec78ac3cb7a049a733e93d11f0d6659", + "blk.8.attn_qkv.weight": "343ed07671da400b040812a4058482fa38284b5d9af9becfed07417fe26ce747", + "blk.8.ffn_down.weight": "7fb7e073e3c2c503c4e9d60efa0988fed7398d900cc003695fe3fffd3e188b82", + "blk.8.ffn_norm.weight": "b07c1f655d8593e3892a2cf73f8a0c19ce8e5cb613fafbe7cbd430da8ce4c57d", + "blk.8.ffn_up.weight": "8b26e14de54b3fdc2e2d3ea41720f9d9c236a93688c3b7fd7bf43f5fbb327c9b", + "blk.9.attn_norm.weight": "46394d408a8e316916177e6aa261de32e137a82d729c0b1800b072f0c38c39b6", + "blk.9.attn_output.weight": "d57f3d46107947a7073373a0b35d6ecf7759b5df15406f4a3590a60666af6b16", + "blk.9.attn_qkv.weight": "14bb8ace8c5453148f4b536e9f4279c813f31136716947256f5cca333448639c", + "blk.9.ffn_down.weight": "2b8d98e2b5ed68338f6e4de43bf7de0c4858cc69103cd5177725f7444eec7694", + "blk.9.ffn_norm.weight": "41a499dfd418cc4c6b8c12313f673f7e2cd4a3f9c4065eb6c4feb5eed02fb542", + "blk.9.ffn_up.weight": "143aab7533a64b17fbe201490a6f674bc7f0bd370c094500b2e100419073d1c2", + "blk.10.attn_norm.weight": "ebb670aafd36816a794347287269d8f1a5b19c1e3c0a1e38023bc19fdba9b073", + "blk.10.attn_output.weight": "b5d65bbc0ed5e49fdd9d754bc18163cd042a285024d0cf6f954c503bc8c877cb", + "blk.10.attn_qkv.weight": "f06b15bac88da798fa34a62b03eaac0dbe8b846020516603c387541f2d8dd672", + "blk.10.ffn_down.weight": "fb091fcd1b4de25d1bea94d1755e255cb02914a030d23e3a234e57b8d46bde6e", + "blk.10.ffn_norm.weight": "eb347bdf9c40414af87e13a8e72e40b31f004b50f7cb366f1a219ced60a61355", + "blk.10.ffn_up.weight": "ed2d52fc881a173f404fe8a1067862c9856d6c3e0d2e90a330a7aa394e3f84d1", + "blk.11.attn_norm.weight": "64e252603cf010a0e502ca39fdf8d0a196a79aec67c0d2bb9213fc0cb80c47d4", + "blk.11.attn_output.weight": "228e33e21c69f52efc74fdfc831bc9af271e44b2a29a3dced1d64e667ce36eb5", + "blk.11.attn_qkv.weight": "ab9ce6d4ef9e42ee0da3f20a7708a3bbc5e79e967b05fa86ba946a05e2eb63eb", + "blk.11.ffn_down.weight": "0ca133b7835c98dc77c25d64e4eb7873778bdb5e4d22d8b80f920f46865b43bd", + "blk.11.ffn_norm.weight": "02455741a0dfd161c79aa1ecc381901721f229fdcda5615622a629631fb61cfd", + "blk.11.ffn_up.weight": "9fecdcc099fbb8e23c6b1ea9294702a027f4a58d265543ec5e7be79b8f63b354", + "blk.12.attn_norm.weight": "783bb459911b1b3609a9b2bdfe272f1670add73b5471da738e07ac47e2e07dfd", + "blk.12.attn_output.weight": "1e1a914c9e48b857206ac5a1f7cead994bc1ea91d5d4fff8c834d73f2e38ef5d", + "blk.12.attn_qkv.weight": "5953e7185ccb87fb4dae8f9426ec86315d4c7794326e8ab59b3a95d4af2189f0", + "blk.12.ffn_down.weight": "a3eecf0f394f86e2cfb48a5940a5c50ca86d71883b2f79fcc642a935fabce0d4", + "blk.12.ffn_norm.weight": "0a4272e41373c23bd72f10d2d82930aa3a1480aac75832bfbf01cebf0b86b6a4", + "blk.12.ffn_up.weight": "06f42776de3a7ceac3025f26a7a8bd20e062233cce2bdaa2183470dc4b30b87d", + "blk.13.attn_norm.weight": "5915da60fb03e201fa649faba780e5fdf1c761c262b206e5415cf83181f65780", + "blk.13.attn_output.weight": "4dbf6eab074fa3835fd32bd631a8208e511037d5056d2fd3015735cca7674ef7", + "blk.13.attn_qkv.weight": "d3d8339a1c4782d9e73d77fdebe154d3c5b83ac40c9175b3e91a4977d08f876b", + "blk.13.ffn_down.weight": "de6772b46a55e1fd42b007637dfbf68b6598e5d5b61622da0935002e1e192d3a", + "blk.13.ffn_norm.weight": "5a640ea3b8c7be49c95a58a2327e10d8e8d9d142504bde5c8091613e5b961d7a", + "blk.13.ffn_up.weight": 
"f35e3545e4bd3531b2e843b5efd31dee0c13c807ee6386e65473ba67bbec30d0", + "blk.14.attn_norm.weight": "9b34986450b7c98b4927e81e61a816f9e84b1addc7c14926402100037aad6678", + "blk.14.attn_output.weight": "155d52efb23d366016d861a251d4d1f4a0c13699188c50d50dba016a0d8bfcd9", + "blk.14.attn_qkv.weight": "8e1415084e1f33c73a777f19e752489f4dd312cca047733e5ea643cd4a955e04", + "blk.14.ffn_down.weight": "a2a142226b94baa01ccb65bdea2b7418e49085c1d9c3c63e544e3112c58a25da", + "blk.14.ffn_norm.weight": "8aecfd9b0ae6affaea31a80c5c9a4a14b31deaa0db7bd8f6da2a64d23447921c", + "blk.14.ffn_up.weight": "0c1407237b8c1bd02f193346b5681926fe698a5055eac6a7450451b0f991707c", + "blk.15.attn_norm.weight": "e037bd19880bfa83d983200fb0c7866f8ad16c3ff5cc4b4f3a37ca7373870ff6", + "blk.15.attn_output.weight": "045fe4fc95cc129a1b92771b179c11b12845c4c088786c607f17bd98857e68e1", + "blk.15.attn_qkv.weight": "7621b7559705cab1d4dea1c69f76dbf9dc1c8837a203b656f484703b9c1b70ce", + "blk.15.ffn_down.weight": "7e5ac20e290bc60761e1cd972354fde225b7fa861048d44d9a0dd9b046d55f58", + "blk.15.ffn_norm.weight": "b6d830d88f1db1825687973c8c2b1a24c6fa84f07af8d0e3ef9c86009baca0b2", + "blk.15.ffn_up.weight": "dcda0957cd04fc45476774dba2bbf9aa89d6b05d5ca7b10ae6f73ad2c49b1cd3", + "blk.16.attn_norm.weight": "4ee9b70ba15cb2a08240f93990e90f5068c48fceb481f8e2186bec8b7214eb3f", + "blk.16.attn_output.weight": "315cfe5536658d2498192b2980eade15b2c9a4ff220e4011911457b1727fa103", + "blk.16.attn_qkv.weight": "3c8122e3ad637583b9dcde8ff3a323267d3014bb1f0f9771e5322260ca9ecc8d", + "blk.16.ffn_down.weight": "3b5fbebd5ee2b86cad96fb8a9b45a8770d08f82c1c8b74d7061e866f7020a18d", + "blk.16.ffn_norm.weight": "ffab69f20bda372de6e5878f0539163e2fc6ba113621ded95705fc3b1465c9f0", + "blk.16.ffn_up.weight": "0935ea3d258da42d6258406365f39f58ddaabfe97ea5977580db3635188f24a1", + "blk.17.attn_norm.weight": "f030441733f3d147b4a06a1eb4aeb8465c7c24d9c53bf4c48fe7e134d3629803", + "blk.17.attn_output.weight": "07a955ef09e8dc766ac0df647d0b2c69f23c4c69a7137654b4aad80303ed0eda", + "blk.17.attn_qkv.weight": "1c10688061e21e2fe12ad0cb54bf03895c1f83c3b0df743a42f548b52cbca1b2", + "blk.17.ffn_down.weight": "ebb9cc9836f41d88fdae2aa9a4355514e4edaec8d1577ffeb947a35204e77f52", + "blk.17.ffn_norm.weight": "50aff44f6528b13db5389f2ddcdb7676244947610bd7ffbff3f881c968c2a0d4", + "blk.17.ffn_up.weight": "d716537949582be33bde6b02e38f5a70081c9642a9fb05a61312126718b8d148", + "blk.18.attn_norm.weight": "0ea695c4e53d637902f46663a6ee42adc493c36794476acc7dbddaa05b13840d", + "blk.18.attn_output.weight": "5fd35b500221a612eb4f4bddf0e9b6b7db4d7733032a75f8802fb2d884647c2e", + "blk.18.attn_qkv.weight": "b0da37fd030fe69581f990bf23bfd35467a1bbe558af6de7c0924f6b72e92317", + "blk.18.ffn_down.weight": "b355c33f44b328f4bb977567de8f7544db4b005d7a8fbded658518ecf3c5a153", + "blk.18.ffn_norm.weight": "58b3fe9094079989a86e0387143259e1cc35952d24dc3df290c4ba6df44f5c51", + "blk.18.ffn_up.weight": "2ce530954c342c30ed2ead5353f931960bfae1d278868504c0efb973560fabbe", + "blk.19.attn_norm.weight": "533e9aed66feea8f0392aa81f9e293240e1f009a5334253915fb60c2749b615d", + "blk.19.attn_output.weight": "84f2d00f98a4113a779d3b5d1c3e7c914eb47784d3ab13b290367c124c2994aa", + "blk.19.attn_qkv.weight": "fbe6b9f53b07fa7537d3b3d452d20a9bc666f9fd41ec2091dd28bc2f70fc668f", + "blk.19.ffn_down.weight": "b30199e098c8bb3f890183d8b18471e80b62b604729b277ad62488dd71e1206b", + "blk.19.ffn_norm.weight": "c81373e41cd340b7badb19f9517c77c4250b4eb9a02dc758b8b49b652487d7ff", + "blk.19.ffn_up.weight": "5a5cb083ca7725720e3a890f7fa46354760e8007a8188849a092e305694a75e3", + 
"blk.20.attn_norm.weight": "4953091b4477e354357a8e743ba0a1900633e52f1599ee082a0c9b0b2b5cd978", + "blk.20.attn_output.weight": "62d54f7749cd6856097b2632066a322b0296df915fe66f382c5b5981be0d4f23", + "blk.20.attn_qkv.weight": "406de9e35b0729ebe902d7a47905cc7fb29a921431ed35dbef0c03e5690a1329", + "blk.20.ffn_down.weight": "62fb678b0d1261e19a4903a2b347d67afcc8acff01feb33a687a35a2d1e6f9a5", + "blk.20.ffn_norm.weight": "cd9d36b7e71e55c8925b97bb09c28219f182626bcff094878ae39c3db887a14b", + "blk.20.ffn_up.weight": "b9276771d79d3e932e73ccc520c3f8476342b9ef312ed2ee1e0da822e6e3ad18", + "blk.21.attn_norm.weight": "66d8c8a35e13ce9c2a0e75b670150e2c31484a55c2316df46075312196178ed3", + "blk.21.attn_output.weight": "12ab46c9382648f9b3350fdd92a6be6352743d62d6b520d7e2024e0c838588f5", + "blk.21.attn_qkv.weight": "a7909676ee1675ca23cd29a5fdd226df8dd9d68f94c6c9bbb51dd9fd38504008", + "blk.21.ffn_down.weight": "6fb317279c6542e82f97d5a12a60fac1bd0fa0405154f9fbe265e2fe39bd49cc", + "blk.21.ffn_norm.weight": "c0f703eb3ff161b5ba4490d87d8684b8a6c47a8f433e12f418333b9db439010a", + "blk.21.ffn_up.weight": "6dbdb80ef0c35e364bbce12d40d5e74c7963c7b55d58d9579567a07ffce7b863", + "blk.22.attn_norm.weight": "f94237433bf03d675cb2f655b81ca91a1ce2447bc6b00b13d6b0ccfe2d411eff", + "blk.22.attn_output.weight": "e821f95995ce497c01e63ca64f737713b1b65f11df1903e51d444aa516f33f71", + "blk.22.attn_qkv.weight": "1b0f717c73afb5eb4c82a1708c4e85c969e8a2a8770d9ddb78b1870a2d8a781e", + "blk.22.ffn_down.weight": "0f33f7a3cdc685484be99aa0c03642b0b20850a27d1fddbe054b13a9382f3ccb", + "blk.22.ffn_norm.weight": "9df285cf211ddd7df2b36a50489af574755c7d4d98b29a05cd04566ae613c8dc", + "blk.22.ffn_up.weight": "63ac300e1efb34041dd0136cf43ea622fac6f0caccce1cd9262f5e08d2cf179c", + "blk.23.attn_norm.weight": "5f72d9e88689b4027b28f5f8f26cd3abb03635ceea7ec98a4c91a9fc691f6707", + "blk.23.attn_output.weight": "6ecf04ff61125c5fc768f8656497152149373daf321ee9c957e8f7245a1184d1", + "blk.23.attn_qkv.weight": "a9d9978806724c2959f2cf386c233831f08e1e933dbf2b32665e788d9d512ea4", + "blk.23.ffn_down.weight": "72c7d17886a3da17fa0daa456aa5e877b2ef5b8b403182b870d9ca5ca9c70347", + "blk.23.ffn_norm.weight": "971e4b712e3025a13419b5b57d674b5e4ab7f18f74b57b9afc4671623da90c4b", + "blk.23.ffn_up.weight": "df2b5c7dbd5834545b815073af0c7355b065124e6d6f0fee78d8fa5b2076dc3e", + "blk.24.attn_norm.weight": "c41957c4a79ad3b16f6e11daec1c7f530b9f3f4b618e1e4367c3b67787ac4ab6", + "blk.24.attn_output.weight": "ef7d61f5fc88ac6f31bf60cb5f4d2d6b8df42d38825807112361a7224b0dee3b", + "blk.24.attn_qkv.weight": "3e6a58fe7d49c90bb6971efbad3371c32256881173ea5aee4b0c296cb206490f", + "blk.24.ffn_down.weight": "f43619144047de42fed81dfa495f1815d3cb771330e574043e2b67620819292c", + "blk.24.ffn_norm.weight": "5501d4a2a98c8ca6b42e77b53b221dbc08f530f6a067256d787534ec6fe028bd", + "blk.24.ffn_up.weight": "d64c8b0e509e2b1118f6000176f8956cacecdbb200c7e95ed93fb78b6e26c84a", + "blk.25.attn_norm.weight": "502fa3c302d371f61c5791f4615b73018ffb1daa09b6499b227116581244c5d4", + "blk.25.attn_output.weight": "ad8391d4e9c980856f2547aa945b2b6a407a6382158dc1ddd4f08d94ecc24be6", + "blk.25.attn_qkv.weight": "42e8983780d4a01a02c54ad23d4df21eea437f119a10af5a9c12a76a42d308c1", + "blk.25.ffn_down.weight": "302dd010d4e0ab4eeaee89090409ea0dddeeeed3236415eb8f97c942497eea91", + "blk.25.ffn_norm.weight": "fb34c1ee5bca96986c08834df0a0c047ba041c1123ac1f563e9d64312bf82d6a", + "blk.25.ffn_up.weight": "10739a8de156816d93c92b935386540bfa976bdbef204f0312960f6fc657582f", + "blk.26.attn_norm.weight": 
"7036c711609128c4e55968ff3681d3043338879a5737efd6c2ac9e1a2a61f1a0", + "blk.26.attn_output.weight": "db5db45dead5cb911fa01da59832f121b7c18b2d167bf53741c40819f24d346c", + "blk.26.attn_qkv.weight": "cae34c6b7f82ed14348d5ed30a79919c383737c1694a9cb9c0de609d3b0c1d0a", + "blk.26.ffn_down.weight": "491ec3a4da9b4f49f8ebc6be658ce397a9b801ae9fb35e82177e47808c65e5d0", + "blk.26.ffn_norm.weight": "fd7059d75d7f0e5288511ddeeb0f772eb3cae3ccfe4226b877015834edc3c386", + "blk.26.ffn_up.weight": "ea1ee1274c56458ce056d2205e5bb6e5422ce4cb0ad58006b8141749b97a0c39", + "blk.27.attn_norm.weight": "cc362c9a937609265052cd38544af17a1a7448cea086d4c801139e1fc865832d", + "blk.27.attn_output.weight": "ba757a81dabde9cb1b069d1bb616fe79649a1724f756567ec61caed1304fe6cf", + "blk.27.attn_qkv.weight": "1ab8d7d02d87756c12c2275636823aa5ede3d683178225c4cac4bd892c319bd4", + "blk.27.ffn_down.weight": "deb1c711c8a66acf4dcd2d088e1548f8e08f296f755e4067d6557fa55afde88c", + "blk.27.ffn_norm.weight": "fc6242d8cb8a4a37a8ddb7e41e7e60a63d4a89edf36acb35df052f10b9c91ece", + "blk.27.ffn_up.weight": "8df39b09c4801f343aca78f2918a1f6db78c8c55e591eda4c69eadb74c26e180", + "blk.28.attn_norm.weight": "75b539308f77e3cefdc6d98484d8b5cbf0538f0c2869a77b7373a145a18bc850", + "blk.28.attn_output.weight": "ae128940eb60a6d2e121762ef4b3e9dcf9eb3e105b249507fa7f12de0e19822c", + "blk.28.attn_qkv.weight": "bdda781c288e9326c240e33905f8e621b6a2ad902e620739d34f93fcd6f933de", + "blk.28.ffn_down.weight": "f1d6e6d1c286b1138bfd7e53fe477f399ae93bc2c04e35416f84218ed7247965", + "blk.28.ffn_norm.weight": "3f837ce82c8b9bde0d61d08b6f5fe5574886ea5328dbdc53f2929f18da8b4087", + "blk.28.ffn_up.weight": "2af027002e31d1b6cfedbdb30a2b9d7213f3aa691167c353913adfd48fda31e4", + "blk.29.attn_norm.weight": "61e8003b5329462ffe0fe172f2b160260de006aed858332d49d75504b6b6aa7a", + "blk.29.attn_output.weight": "ca44542a72a37476dc73dbdcc01f5b7497cb3ebc4ea230a55c9634ccd8e56ad4", + "blk.29.attn_qkv.weight": "abb3d9d6abe57872ae3daa51935d43264093ded5ce63b49d1e280ee5758be0e4", + "blk.29.ffn_down.weight": "6764b895fce881df097489c263446f0106de36217997660c15984b3ee22a5a06", + "blk.29.ffn_norm.weight": "89e03e9a33fc0e6e31ba9f0c2bd7c5734a118c5602bb90148793e08a80e8d0ae", + "blk.29.ffn_up.weight": "fa7ad57a84954f4121653152efed1a871d8adb20a1ea9086e3e849ce359d7d2e", + "blk.30.attn_norm.weight": "91a697aca1e42af54f806a20211031c3369e8d0bd58df1b0147fe24954e1f5a4", + "blk.30.attn_output.weight": "36063fcf766c89ac75be56f688cc63cefe5f2c733fbf4378ea9956ad386fa148", + "blk.30.attn_qkv.weight": "2cacd1161f1121a2c0b979930134f4666f73fb8d7237b3b0659ae091b15955a6", + "blk.30.ffn_down.weight": "9f3fcb6217100595850c05dc98f9ab2a263afdb6ab28df2fcb08aeff512057d7", + "blk.30.ffn_norm.weight": "6c600bc1fc7de39d4f8917b81fc7d1d5ed2a9b56492234c13a4bd6028c30d880", + "blk.30.ffn_up.weight": "73cabd1bb011956b2689ea3338bb76642ef3a57c197377d666d2ab5f56317668", + "blk.31.attn_norm.weight": "72d3e1cc771380645fa75a899858c95f39857a4f3f1ed60fe1578df383b8bc53", + "blk.31.attn_output.weight": "40089cdd29994dc19a1d89fa15902a89cfeca3540f12dc9bf4d00ef82506e456", + "blk.31.attn_qkv.weight": "1d0bb40e9258071ae14290a53c619a8e331dda07354d2a02ef45766c029ae5e4", + "blk.31.ffn_down.weight": "8defa0e06335b793fa8be03883f0a322d6c5b33f52c69c943c35c60d16e42c0a", + "blk.31.ffn_norm.weight": "33c55d9d0c496ccfb130361fe131649346e098abaaac39c0519507e5d846721d", + "blk.31.ffn_up.weight": "599f6503f61c692c1f82001973d35119f9688db5e6be9d9c298411491c93f09b", + "output.weight": "14b8dc662bfa3308ebb2e102c562d8e52c15670e538f20f3216a9c310ca9dd41", + "output_norm.weight": 
"7f2294ba94ce65681df6c7ddd8698799199b9d77dc83c10bdad5c3999f0fdb82", + "rope_factors_long.weight": "e34d378664e354652c38f47d10dafb0498ccc2fb042d39ff7fef768146fff22b", + "rope_factors_short.weight": "9379146a4988f373d362fe47b06c75e7fe7c54aa4dc9558758df79b7a87471fd", + "token_embd.weight": "19a03c1fb5ac0baee93b0a7d8b0f26e9a9b011e229b694afc50ebfc13d84f8bf" +} diff --git a/docs/import.md b/docs/import.md index f34f09ac..82ea9ba5 100644 --- a/docs/import.md +++ b/docs/import.md @@ -16,7 +16,9 @@ If the model being imported is one of these architectures, it can be imported di - LlamaForCausalLM - MistralForCausalLM + - MixtralForCausalLM - GemmaForCausalLM + - Phi3ForCausalLM ```dockerfile FROM /path/to/safetensors/directory diff --git a/gpu/assets.go b/gpu/assets.go index a35b6630..6d62d0dc 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -49,13 +49,9 @@ func PayloadsDir() (string, error) { } // Track our pid so we can clean up orphaned tmpdirs - pidFilePath := filepath.Join(tmpDir, "ollama.pid") - pidFile, err := os.OpenFile(pidFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.ModePerm) - if err != nil { - return "", err - } - if _, err := pidFile.Write([]byte(strconv.Itoa(os.Getpid()))); err != nil { - return "", err + n := filepath.Join(tmpDir, "ollama.pid") + if err := os.WriteFile(n, []byte(strconv.Itoa(os.Getpid())), 0o644); err != nil { + return "", fmt.Errorf("failed to write pid file %s: %w", n, err) } // We create a distinct subdirectory for payloads within the tmpdir @@ -67,37 +63,44 @@ func PayloadsDir() (string, error) { // Best effort to clean up prior tmpdirs func cleanupTmpDirs() { - dirs, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*")) + matches, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*", "ollama.pid")) if err != nil { return } - for _, d := range dirs { - info, err := os.Stat(d) - if err != nil || !info.IsDir() { + + for _, match := range matches { + raw, err := os.ReadFile(match) + if errors.Is(err, os.ErrNotExist) { + slog.Debug("not a ollama runtime directory, skipping", "path", match) continue - } - raw, err := os.ReadFile(filepath.Join(d, "ollama.pid")) - if err != nil { - slog.Warn("failed to read ollama.pid", "path", d, "error", err) - // No pid, ignore this tmpdir + } else if err != nil { + slog.Warn("could not read ollama.pid, skipping", "path", match, "error", err) continue } pid, err := strconv.Atoi(string(raw)) if err != nil { - slog.Warn("failed to parse pid", "path", d, "error", err) + slog.Warn("invalid pid, skipping", "path", match, "error", err) continue } - proc, err := os.FindProcess(pid) - if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { - slog.Warn("found running ollama", "pid", pid, "path", d) - // Another running ollama, ignore this tmpdir + p, err := os.FindProcess(pid) + if err == nil && !errors.Is(p.Signal(syscall.Signal(0)), os.ErrProcessDone) { + slog.Warn("process still running, skipping", "pid", pid, "path", match) continue } - if err := os.Remove(d); err != nil { - slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err) + if err := os.Remove(match); err != nil { + slog.Warn("could not cleanup stale pidfile", "path", match, "error", err) + } + + runners := filepath.Join(filepath.Dir(match), "runners") + if err := os.RemoveAll(runners); err != nil { + slog.Warn("could not cleanup stale runners", "path", runners, "error", err) + } + + if err := os.Remove(filepath.Dir(match)); err != nil { + slog.Warn("could not cleanup stale tmpdir", "path", filepath.Dir(match), "error", err) 
} } } diff --git a/gpu/gpu.go b/gpu/gpu.go index 7ae8fbec..dc124a3e 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -305,38 +305,41 @@ func GetGPUInfo() GpuInfoList { // Intel if envconfig.IntelGPU() { oHandles = initOneAPIHandles() - // On windows we bundle the oneapi library one level above the runner dir - depPath = "" - if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" { - depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi") - } + if oHandles != nil && oHandles.oneapi != nil { - for d := range oHandles.oneapi.num_drivers { - if oHandles.oneapi == nil { - // shouldn't happen - slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) - continue + // On windows we bundle the oneapi library one level above the runner dir + depPath = "" + if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" { + depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi") } - devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) - for i := range devCount { - gpuInfo := OneapiGPUInfo{ - GpuInfo: GpuInfo{ - Library: "oneapi", - }, - driverIndex: int(d), - gpuIndex: int(i), + + for d := range oHandles.oneapi.num_drivers { + if oHandles.oneapi == nil { + // shouldn't happen + slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) + continue + } + devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) + for i := range devCount { + gpuInfo := OneapiGPUInfo{ + GpuInfo: GpuInfo{ + Library: "oneapi", + }, + driverIndex: int(d), + gpuIndex: int(i), + } + // TODO - split bootstrapping from updating free memory + C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) + // TODO - convert this to MinimumMemory based on testing... + var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. + memInfo.free = C.uint64_t(totalFreeMem) + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) + gpuInfo.DependencyPath = depPath + oneapiGPUs = append(oneapiGPUs, gpuInfo) } - // TODO - split bootstrapping from updating free memory - C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) - // TODO - convert this to MinimumMemory based on testing... - var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. 
- memInfo.free = C.uint64_t(totalFreeMem) - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.DependencyPath = depPath - oneapiGPUs = append(oneapiGPUs, gpuInfo) } } } diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index c65901c7..5717c17a 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -1223,9 +1223,7 @@ struct llama_server_context res.result_json = json { - {"id", res.id}, {"embedding", std::vector(embd, embd + n_embd)}, - {"timings", slot.get_formated_timings()}, }; } } @@ -3194,41 +3192,17 @@ int main(int argc, char **argv) { prompt = ""; } - if (prompt.size() == 1) { - prompt = prompt[0]; - } - // create and queue the task - json responses; - { - const int id_task = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, {{"prompt", prompt}}, true, -1); + const int task_id = llama.queue_tasks.get_new_id(); + llama.queue_results.add_waiting_task_id(task_id); + llama.request_completion(task_id, {{"prompt", prompt}}, true, -1); - // get the result - task_result result = llama.queue_results.recv(id_task); - llama.queue_results.remove_waiting_task_id(id_task); - if (result.error) { - return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); - } + // get the result + task_result result = llama.queue_results.recv(task_id); + llama.queue_results.remove_waiting_task_id(task_id); - responses = result.result_json.value("results", std::vector{result.result_json}); - std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) { - return a["id"] < b["id"]; - }); - - json embeddings = json::array(); - - int prompt_n = 0; - for (auto & elem : responses) { - embeddings.push_back(elem.at("embedding")); - prompt_n += elem.at("timings").at("prompt_n").get(); - } - - // send the result - json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}}; - return res.set_content(embedding_res.dump(), "application/json; charset=utf-8"); - } + // send the result + return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); }); // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? 
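The createBlob change in cmd/cmd.go earlier in this diff reports upload progress by routing the blob through an io.TeeReader into a progressWriter, whose atomic counter a ticker goroutine polls to update the spinner. Below is a minimal standalone sketch of that pattern; the in-memory reader, the 1 MiB size, and io.Copy into io.Discard are stand-ins for the model file and client.CreateBlob:

```go
package main

import (
	"fmt"
	"io"
	"os"
	"strings"
	"sync/atomic"
)

// progressWriter counts bytes written through it; the atomic counter is
// safe to read from another goroutine (the spinner ticker in cmd/cmd.go).
type progressWriter struct {
	n atomic.Int64
}

func (w *progressWriter) Write(p []byte) (int, error) {
	w.n.Add(int64(len(p)))
	return len(p), nil
}

func main() {
	src := strings.NewReader(strings.Repeat("x", 1<<20)) // stand-in for the model blob
	size := int64(src.Len())

	var pw progressWriter
	// Everything read from tee is mirrored into pw, so pw.n tracks
	// transfer progress without touching the upload path itself.
	tee := io.TeeReader(src, &pw)

	if _, err := io.Copy(io.Discard, tee); err != nil { // stand-in for client.CreateBlob
		fmt.Fprintln(os.Stderr, err)
		return
	}
	fmt.Printf("transferring model data %d%%\n", 100*pw.n.Load()/size)
}
```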
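The llm/ggml.go and llm/gguf.go hunks below replace name-based tensor ordering with a numeric block index: a new Tensor.block() helper parses N out of "blk.N." names (returning -1 for non-block tensors such as "output.weight"), and WriteGGUF switches to slices.SortStableFunc over that index. A small self-contained sketch of that ordering follows; as an assumption about the intent it uses >= 0 guards so that blk.0 groups with the other block tensors, whereas the hunk itself compares with strict > 0 and < 0:

```go
package main

import (
	"cmp"
	"fmt"
	"slices"
)

// block parses the leading index from a tensor name such as
// "blk.12.attn_norm.weight"; non-block tensors yield -1.
func block(name string) (n int) {
	if _, err := fmt.Sscanf(name, "blk.%d.", &n); err != nil {
		return -1
	}
	return
}

func main() {
	names := []string{
		"output.weight",
		"blk.10.ffn_up.weight",
		"token_embd.weight",
		"blk.2.attn_qkv.weight",
	}

	slices.SortStableFunc(names, func(a, b string) int {
		i, j := block(a), block(b)
		switch {
		case i < 0 && j >= 0:
			return 1 // non-block tensors sort after blk.N tensors
		case i >= 0 && j < 0:
			return -1
		default:
			return cmp.Compare(i, j) // numeric, so blk.10 follows blk.2
		}
	})

	// Prints: [blk.2.attn_qkv.weight blk.10.ffn_up.weight output.weight token_embd.weight]
	fmt.Println(names)
}
```

Because the sort is stable, non-block tensors (which all compare equal to one another) keep their input order, here "output.weight" before "token_embd.weight".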
diff --git a/llm/ggml.go b/llm/ggml.go index d7f2eef7..4c68adf9 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -157,6 +157,14 @@ type Tensor struct { io.WriterTo `json:"-"` } +func (t Tensor) block() (n int) { + if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil { + return -1 + } + + return +} + func (t Tensor) blockSize() uint64 { switch t.Kind { case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16 diff --git a/llm/gguf.go b/llm/gguf.go index 98158313..2e6bc542 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -532,15 +532,14 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error { } } - slices.SortFunc(ts, func(a, b Tensor) int { - var i, j int - if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 { - return cmp.Compare(a.Name, b.Name) - } else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 { - return cmp.Compare(a.Name, b.Name) + slices.SortStableFunc(ts, func(a, b Tensor) int { + if i, j := a.block(), b.block(); i < 0 && j > 0 { + return 1 + } else if i > 0 && j < 0 { + return -1 + } else { + return cmp.Compare(i, j) } - - return cmp.Compare(i, j) }) var s uint64 diff --git a/llm/server.go b/llm/server.go index 41736068..d2b8db9b 100644 --- a/llm/server.go +++ b/llm/server.go @@ -33,7 +33,7 @@ type LlamaServer interface { Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error - Embed(ctx context.Context, input []string) (*EmbedResponse, error) + Embedding(ctx context.Context, input string) ([]float32, error) Tokenize(ctx context.Context, content string) ([]int, error) Detokenize(ctx context.Context, tokens []int) (string, error) Close() error @@ -125,8 +125,9 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr } } - // On linux, over-allocating CPU memory will almost always result in an error - if runtime.GOOS == "linux" { + // On linux and windows, over-allocating CPU memory will almost always result in an error + // Darwin has fully dynamic swap so has no direct concept of free swap space + if runtime.GOOS != "darwin" { systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize available := systemFreeMemory + systemSwapFreeMemory if systemMemoryRequired > available { @@ -882,24 +883,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return nil } -type EmbedRequest struct { - Content []string `json:"content"` +type EmbeddingRequest struct { + Content string `json:"content"` } -type EmbedResponse struct { - Embedding [][]float32 `json:"embedding"` - PromptEvalCount int `json:"prompt_n"` +type EmbeddingResponse struct { + Embedding []float32 `json:"embedding"` } -func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) { - // each input will use a slot, so we need to acquire the semaphore for - // the number of inputs up to numParallel - slots := int64(min(len(input), s.numParallel)) - if err := s.sem.Acquire(ctx, slots); err != nil { +func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) { + if err := s.sem.Acquire(ctx, 1); err != nil { slog.Error("Failed to acquire semaphore", "error", err) return nil, err } - defer s.sem.Release(slots) + defer s.sem.Release(1) // Make sure the server is ready status, err := s.getServerStatusRetry(ctx) @@ -909,18 +906,18 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, return nil, fmt.Errorf("unexpected 
server status: %s", status.ToString()) } - data, err := json.Marshal(EmbedRequest{Content: input}) + data, err := json.Marshal(EmbeddingRequest{Content: input}) if err != nil { return nil, fmt.Errorf("error marshaling embed data: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) + r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) if err != nil { return nil, fmt.Errorf("error creating embed request: %w", err) } - req.Header.Set("Content-Type", "application/json") + r.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + resp, err := http.DefaultClient.Do(r) if err != nil { return nil, fmt.Errorf("do embedding request: %w", err) } @@ -936,12 +933,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, return nil, fmt.Errorf("%s", body) } - var e EmbedResponse + var e EmbeddingResponse if err := json.Unmarshal(body, &e); err != nil { return nil, fmt.Errorf("unmarshal tokenize response: %w", err) } - return &e, nil + return e.Embedding, nil } type TokenizeRequest struct { diff --git a/llm/status.go b/llm/status.go index d9f36115..604fe9e0 100644 --- a/llm/status.go +++ b/llm/status.go @@ -26,6 +26,7 @@ var errorPrefixes = []string{ "cudaMalloc failed", "\"ERR\"", "error loading model", + "GGML_ASSERT", } func (w *StatusWriter) Write(b []byte) (int, error) { diff --git a/openai/openai_test.go b/openai/openai_test.go index e08a96c9..c7e9f384 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -7,27 +7,22 @@ import ( "io" "net/http" "net/http/httptest" + "reflect" "strings" "testing" "time" "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" "github.com/ollama/ollama/api" ) const ( - prefix = `data:image/jpeg;base64,` - image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` - imageURL = prefix + image + prefix = `data:image/jpeg;base64,` + image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` ) -func prepareRequest(req *http.Request, body any) { - bodyBytes, _ := json.Marshal(body) - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") -} +var False = false func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { return func(c *gin.Context) { @@ -43,134 +38,136 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { func TestChatMiddleware(t *testing.T) { type testCase struct { - Name string - Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) + name string + body string + req api.ChatRequest + err ErrorResponse } var capturedRequest *api.ChatRequest testCases := []testCase{ { - Name: "chat handler", - Setup: func(t *testing.T, req *http.Request) { - body := ChatCompletionRequest{ - Model: "test-model", - Messages: []Message{{Role: "user", Content: "Hello"}}, - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.Code) - } - - if req.Messages[0].Role != "user" { - t.Fatalf("expected 'user', got %s", req.Messages[0].Role) - } - - if req.Messages[0].Content != "Hello" { - t.Fatalf("expected 'Hello', got %s", 
req.Messages[0].Content) - } + name: "chat handler", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, }, }, { - Name: "chat handler with image content", - Setup: func(t *testing.T, req *http.Request) { - body := ChatCompletionRequest{ - Model: "test-model", - Messages: []Message{ - { - Role: "user", Content: []map[string]any{ - {"type": "text", "text": "Hello"}, - {"type": "image_url", "image_url": map[string]string{"url": imageURL}}, + name: "chat handler with image content", + body: `{ + "model": "test-model", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Hello" + }, + { + "type": "image_url", + "image_url": { + "url": "` + prefix + image + `" + } + } + ] + } + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + { + Role: "user", + Images: []api.ImageData{ + func() []byte { + img, _ := base64.StdEncoding.DecodeString(image) + return img + }(), + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with tools", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]interface{}{ + "location": "Paris, France", + "format": "celsius", + }, + }, }, }, }, - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.Code) - } - - if req.Messages[0].Role != "user" { - t.Fatalf("expected 'user', got %s", req.Messages[0].Role) - } - - if req.Messages[0].Content != "Hello" { - t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content) - } - - img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):]) - - if req.Messages[1].Role != "user" { - t.Fatalf("expected 'user', got %s", req.Messages[1].Role) - } - - if !bytes.Equal(req.Messages[1].Images[0], img) { - t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0]) - } + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, }, }, + { - Name: "chat handler with tools", - Setup: func(t *testing.T, req *http.Request) { - body := ChatCompletionRequest{ - Model: "test-model", - Messages: []Message{ - {Role: "user", Content: "What's the weather like in Paris Today?"}, - {Role: "assistant", ToolCalls: []ToolCall{{ - ID: "id", - Type: "function", - Function: struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - }{ - Name: "get_current_weather", - Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}", - }, - }}}, - }, - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req 
*api.ChatRequest, resp *httptest.ResponseRecorder) { - if resp.Code != 200 { - t.Fatalf("expected 200, got %d", resp.Code) - } - - if req.Messages[0].Content != "What's the weather like in Paris Today?" { - t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content) - } - - if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" { - t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"]) - } - - if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" { - t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"]) - } - }, - }, - { - Name: "chat handler error forwarding", - Setup: func(t *testing.T, req *http.Request) { - body := ChatCompletionRequest{ - Model: "test-model", - Messages: []Message{{Role: "user", Content: 2}}, - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.Code) - } - - if !strings.Contains(resp.Body.String(), "invalid message content type") { - t.Fatalf("error was not forwarded") - } + name: "chat handler error forwarding", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": 2} + ] + }`, + err: ErrorResponse{ + Error: Error{ + Message: "invalid message content type: float64", + Type: "invalid_request_error", + }, }, }, } @@ -185,16 +182,26 @@ func TestChatMiddleware(t *testing.T) { router.Handle(http.MethodPost, "/api/chat", endpoint) for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil) - - tc.Setup(t, req) + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - tc.Expected(t, capturedRequest, resp) + var errResp ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + } + if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { + t.Fatal("requests did not match") + } + if !reflect.DeepEqual(tc.err, errResp) { + t.Fatal("errors did not match") + } capturedRequest = nil }) } @@ -202,71 +209,52 @@ func TestChatMiddleware(t *testing.T) { func TestCompletionsMiddleware(t *testing.T) { type testCase struct { - Name string - Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) + name string + body string + req api.GenerateRequest + err ErrorResponse } var capturedRequest *api.GenerateRequest testCases := []testCase{ { - Name: "completions handler", - Setup: func(t *testing.T, req *http.Request) { - temp := float32(0.8) - body := CompletionRequest{ - Model: "test-model", - Prompt: "Hello", - Temperature: &temp, - Stop: []string{"\n", "stop"}, - Suffix: "suffix", - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { - if req.Prompt != "Hello" { - t.Fatalf("expected 'Hello', got %s", req.Prompt) - } - - if req.Options["temperature"] != 1.6 { - t.Fatalf("expected 1.6, got %f", req.Options["temperature"]) - } - - stopTokens, ok := req.Options["stop"].([]any) - - if !ok { - t.Fatalf("expected stop tokens to be a list") - } - - if 
stopTokens[0] != "\n" || stopTokens[1] != "stop" { - t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens) - } - - if req.Suffix != "suffix" { - t.Fatalf("expected 'suffix', got %s", req.Suffix) - } + name: "completions handler", + body: `{ + "model": "test-model", + "prompt": "Hello", + "temperature": 0.8, + "stop": ["\n", "stop"], + "suffix": "suffix" + }`, + req: api.GenerateRequest{ + Model: "test-model", + Prompt: "Hello", + Options: map[string]any{ + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 1.6, + "top_p": 1.0, + "stop": []any{"\n", "stop"}, + }, + Suffix: "suffix", + Stream: &False, }, }, { - Name: "completions handler error forwarding", - Setup: func(t *testing.T, req *http.Request) { - body := CompletionRequest{ - Model: "test-model", - Prompt: "Hello", - Temperature: nil, - Stop: []int{1, 2}, - Suffix: "suffix", - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.Code) - } - - if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") { - t.Fatalf("error was not forwarded") - } + name: "completions handler error forwarding", + body: `{ + "model": "test-model", + "prompt": "Hello", + "temperature": null, + "stop": [1, 2], + "suffix": "suffix" + }`, + err: ErrorResponse{ + Error: Error{ + Message: "invalid type for 'stop' field: float64", + Type: "invalid_request_error", + }, }, }, } @@ -281,15 +269,27 @@ func TestCompletionsMiddleware(t *testing.T) { router.Handle(http.MethodPost, "/api/generate", endpoint) for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil) - - tc.Setup(t, req) + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - tc.Expected(t, capturedRequest, resp) + var errResp ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + } + + if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { + t.Fatal("requests did not match") + } + + if !reflect.DeepEqual(tc.err, errResp) { + t.Fatal("errors did not match") + } capturedRequest = nil }) @@ -298,78 +298,47 @@ func TestCompletionsMiddleware(t *testing.T) { func TestEmbeddingsMiddleware(t *testing.T) { type testCase struct { - Name string - Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) + name string + body string + req api.EmbedRequest + err ErrorResponse } var capturedRequest *api.EmbedRequest testCases := []testCase{ { - Name: "embed handler single input", - Setup: func(t *testing.T, req *http.Request) { - body := EmbedRequest{ - Input: "Hello", - Model: "test-model", - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { - if req.Input != "Hello" { - t.Fatalf("expected 'Hello', got %s", req.Input) - } - - if req.Model != "test-model" { - t.Fatalf("expected 'test-model', got %s", req.Model) - } + name: "embed handler single input", + body: `{ + "input": "Hello", + "model": "test-model" + }`, + req: api.EmbedRequest{ + Input: "Hello", + Model: "test-model", }, }, { - Name: "embed handler batch input", - 
Setup: func(t *testing.T, req *http.Request) { - body := EmbedRequest{ - Input: []string{"Hello", "World"}, - Model: "test-model", - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { - input, ok := req.Input.([]any) - - if !ok { - t.Fatalf("expected input to be a list") - } - - if input[0].(string) != "Hello" { - t.Fatalf("expected 'Hello', got %s", input[0]) - } - - if input[1].(string) != "World" { - t.Fatalf("expected 'World', got %s", input[1]) - } - - if req.Model != "test-model" { - t.Fatalf("expected 'test-model', got %s", req.Model) - } + name: "embed handler batch input", + body: `{ + "input": ["Hello", "World"], + "model": "test-model" + }`, + req: api.EmbedRequest{ + Input: []any{"Hello", "World"}, + Model: "test-model", }, }, { - Name: "embed handler error forwarding", - Setup: func(t *testing.T, req *http.Request) { - body := EmbedRequest{ - Model: "test-model", - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.Code) - } - - if !strings.Contains(resp.Body.String(), "invalid input") { - t.Fatalf("error was not forwarded") - } + name: "embed handler error forwarding", + body: `{ + "model": "test-model" + }`, + err: ErrorResponse{ + Error: Error{ + Message: "invalid input", + Type: "invalid_request_error", + }, }, }, } @@ -384,116 +353,167 @@ func TestEmbeddingsMiddleware(t *testing.T) { router.Handle(http.MethodPost, "/api/embed", endpoint) for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil) - - tc.Setup(t, req) + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - tc.Expected(t, capturedRequest, resp) + var errResp ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + } + + if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { + t.Fatal("requests did not match") + } + + if !reflect.DeepEqual(tc.err, errResp) { + t.Fatal("errors did not match") + } capturedRequest = nil }) } } -func TestMiddlewareResponses(t *testing.T) { +func TestListMiddleware(t *testing.T) { type testCase struct { - Name string - Method string - Path string - TestPath string - Handler func() gin.HandlerFunc - Endpoint func(c *gin.Context) - Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, resp *httptest.ResponseRecorder) + name string + endpoint func(c *gin.Context) + resp string } testCases := []testCase{ { - Name: "list handler", - Method: http.MethodGet, - Path: "/api/tags", - TestPath: "/api/tags", - Handler: ListMiddleware, - Endpoint: func(c *gin.Context) { + name: "list handler", + endpoint: func(c *gin.Context) { c.JSON(http.StatusOK, api.ListResponse{ Models: []api.ListModelResponse{ { - Name: "Test Model", + Name: "test-model", + ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), }, }, }) }, - Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { - var listResp ListCompletion - if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { - t.Fatal(err) - } - - if listResp.Object != "list" { - t.Fatalf("expected list, got %s", listResp.Object) - } - - if 
-				if len(listResp.Data) != 1 {
-					t.Fatalf("expected 1, got %d", len(listResp.Data))
-				}
-
-				if listResp.Data[0].Id != "Test Model" {
-					t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
-				}
-			},
+			resp: `{
+				"object": "list",
+				"data": [
+					{
+						"id": "test-model",
+						"object": "model",
+						"created": 1686935002,
+						"owned_by": "library"
+					}
+				]
+			}`,
 		},
 		{
-			Name:     "retrieve model",
-			Method:   http.MethodGet,
-			Path:     "/api/show/:model",
-			TestPath: "/api/show/test-model",
-			Handler:  RetrieveMiddleware,
-			Endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusOK, api.ShowResponse{
-					ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
-				})
-			},
-			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				var retrieveResp Model
-				if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
-					t.Fatal(err)
-				}
-
-				if retrieveResp.Object != "model" {
-					t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
-				}
-
-				if retrieveResp.Id != "test-model" {
-					t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
-				}
+			name: "list handler empty output",
+			endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.ListResponse{})
 			},
+			resp: `{
+				"object": "list",
+				"data": null
+			}`,
 		},
 	}
 
 	gin.SetMode(gin.TestMode)
-	router := gin.New()
 
 	for _, tc := range testCases {
-		t.Run(tc.Name, func(t *testing.T) {
-			router = gin.New()
-			router.Use(tc.Handler())
-			router.Handle(tc.Method, tc.Path, tc.Endpoint)
-			req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
+		router := gin.New()
+		router.Use(ListMiddleware())
+		router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
+		req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
 
-			if tc.Setup != nil {
-				tc.Setup(t, req)
-			}
+		resp := httptest.NewRecorder()
+		router.ServeHTTP(resp, req)
 
-			resp := httptest.NewRecorder()
-			router.ServeHTTP(resp, req)
+		var expected, actual map[string]any
+		err := json.Unmarshal([]byte(tc.resp), &expected)
+		if err != nil {
+			t.Fatalf("failed to unmarshal expected response: %v", err)
+		}
 
-			assert.Equal(t, http.StatusOK, resp.Code)
+		err = json.Unmarshal(resp.Body.Bytes(), &actual)
+		if err != nil {
+			t.Fatalf("failed to unmarshal actual response: %v", err)
+		}
 
-			tc.Expected(t, resp)
-		})
+		if !reflect.DeepEqual(expected, actual) {
+			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+		}
 	}
 }
+
+func TestRetrieveMiddleware(t *testing.T) {
+	type testCase struct {
+		name     string
+		endpoint func(c *gin.Context)
+		resp     string
+	}
+
+	testCases := []testCase{
+		{
+			name: "retrieve handler",
+			endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.ShowResponse{
+					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
+				})
+			},
+			resp: `{
+				"id":"test-model",
+				"object":"model",
+				"created":1686935002,
+				"owned_by":"library"}
+			`,
+		},
+		{
+			name: "retrieve handler error forwarding",
+			endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
+			},
+			resp: `{
+				"error": {
+					"code": null,
+					"message": "model not found",
+					"param": null,
+					"type": "api_error"
+				}
+			}`,
+		},
+	}
+
+	gin.SetMode(gin.TestMode)
+
+	for _, tc := range testCases {
+		router := gin.New()
+		router.Use(RetrieveMiddleware())
+		router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
+		req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
+
+		resp := httptest.NewRecorder()
+		router.ServeHTTP(resp, req)
+
+		var expected, actual map[string]any
+		err := json.Unmarshal([]byte(tc.resp), &expected)
+		if err != nil {
+			t.Fatalf("failed to unmarshal expected response: %v", err)
+		}
+
+		err = json.Unmarshal(resp.Body.Bytes(), &actual)
+		if err != nil {
+			t.Fatalf("failed to unmarshal actual response: %v", err)
+		}
+
+		if !reflect.DeepEqual(expected, actual) {
+			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+		}
+	}
+}
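The rewritten middleware tests above stop spot-checking individual fields and instead compare whole JSON documents: both the expected literal and the recorded body are unmarshaled into `map[string]any` before a `reflect.DeepEqual`, so key order and whitespace cannot cause false failures. A minimal, self-contained sketch of that comparison idiom (the `jsonEqual` helper and the package name are illustrative, not part of the patch):

```go
package openai

import (
	"encoding/json"
	"reflect"
	"testing"
)

// jsonEqual reports whether two JSON documents are semantically equal.
// Unmarshaling into map[string]any normalizes key order and whitespace
// before the deep comparison.
func jsonEqual(t *testing.T, want, got []byte) bool {
	t.Helper()

	var w, g map[string]any
	if err := json.Unmarshal(want, &w); err != nil {
		t.Fatalf("failed to unmarshal expected response: %v", err)
	}
	if err := json.Unmarshal(got, &g); err != nil {
		t.Fatalf("failed to unmarshal actual response: %v", err)
	}
	return reflect.DeepEqual(w, g)
}
```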
t.Fatalf("failed to unmarshal expected response: %v", err) + } + + err = json.Unmarshal(resp.Body.Bytes(), &actual) + if err != nil { + t.Fatalf("failed to unmarshal actual response: %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + } } } diff --git a/progress/spinner.go b/progress/spinner.go index 02f3f9fb..e39a45ee 100644 --- a/progress/spinner.go +++ b/progress/spinner.go @@ -3,11 +3,12 @@ package progress import ( "fmt" "strings" + "sync/atomic" "time" ) type Spinner struct { - message string + message atomic.Value messageWidth int parts []string @@ -21,20 +22,25 @@ type Spinner struct { func NewSpinner(message string) *Spinner { s := &Spinner{ - message: message, parts: []string{ "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏", }, started: time.Now(), } + s.SetMessage(message) go s.start() return s } +func (s *Spinner) SetMessage(message string) { + s.message.Store(message) +} + func (s *Spinner) String() string { var sb strings.Builder - if len(s.message) > 0 { - message := strings.TrimSpace(s.message) + + if message, ok := s.message.Load().(string); ok && len(message) > 0 { + message := strings.TrimSpace(message) if s.messageWidth > 0 && len(message) > s.messageWidth { message = message[:s.messageWidth] } diff --git a/scripts/install.sh b/scripts/install.sh index aa8b3e5e..03af5a69 100644 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -209,15 +209,15 @@ install_cuda_driver_yum() { case $PACKAGE_MANAGER in yum) $SUDO $PACKAGE_MANAGER -y install yum-utils - if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then - $SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo + if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then + $SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo else error $CUDA_REPO_ERR_MSG fi ;; dnf) - if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then - $SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo + if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then + $SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo else error $CUDA_REPO_ERR_MSG fi @@ -245,8 +245,8 @@ install_cuda_driver_yum() { # ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian install_cuda_driver_apt() { status 'Installing NVIDIA repository...' 
diff --git a/scripts/install.sh b/scripts/install.sh
index aa8b3e5e..03af5a69 100644
--- a/scripts/install.sh
+++ b/scripts/install.sh
@@ -209,15 +209,15 @@ install_cuda_driver_yum() {
     case $PACKAGE_MANAGER in
         yum)
             $SUDO $PACKAGE_MANAGER -y install yum-utils
-            if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
-                $SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
+            if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
+                $SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
             else
                 error $CUDA_REPO_ERR_MSG
             fi
             ;;
         dnf)
-            if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
-                $SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
+            if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
+                $SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
             else
                 error $CUDA_REPO_ERR_MSG
             fi
@@ -245,8 +245,8 @@ install_cuda_driver_yum() {
 # ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
 install_cuda_driver_apt() {
     status 'Installing NVIDIA repository...'
-    if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
-        curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
+    if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
+        curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb
     else
         error $CUDA_REPO_ERR_MSG
     fi
diff --git a/server/download.go b/server/download.go
index 38d24a6b..1bca86bf 100644
--- a/server/download.go
+++ b/server/download.go
@@ -216,9 +216,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 		return err
 	}
 	defer file.Close()
-	if err := setSparse(file); err != nil {
-		return err
-	}
+	setSparse(file)
 
 	_ = file.Truncate(b.Total)
 
@@ -235,7 +233,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
 		if len(via) > 10 {
-			return errors.New("maxium redirects exceeded (10) for directURL")
+			return errors.New("maximum redirects exceeded (10) for directURL")
 		}
 
 		// if the hostname is the same, allow the redirect
diff --git a/server/images.go b/server/images.go
index 7ed35995..0e753f56 100644
--- a/server/images.go
+++ b/server/images.go
@@ -373,7 +373,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 	var messages []*api.Message
 	parameters := make(map[string]any)
 
-	var layers []*Layer
+	var layers []Layer
 
 	for _, c := range modelfile.Commands {
 		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
@@ -499,7 +499,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 
 			if c.Name != "license" {
 				// replace
-				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+				layers = slices.DeleteFunc(layers, func(layer Layer) bool {
 					if layer.MediaType != mediatype {
 						return false
 					}
@@ -545,7 +545,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 	}
 
 	var err2 error
-	layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+	layers = slices.DeleteFunc(layers, func(layer Layer) bool {
 		switch layer.MediaType {
 		case "application/vnd.ollama.image.message":
 			// if there are new messages, remove the inherited ones
@@ -625,12 +625,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 		return err
 	}
 
-	layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+	configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
 	if err != nil {
 		return err
 	}
 
-	for _, layer := range append(layers, layer) {
+	for _, layer := range append(layers, configLayer) {
 		if layer.status != "" {
 			fn(api.ProgressResponse{Status: layer.status})
 		}
@@ -639,7 +639,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 	old, _ := ParseNamedManifest(name)
 
 	fn(api.ProgressResponse{Status: "writing manifest"})
-	if err := WriteManifest(name, layer, layers); err != nil {
+	if err := WriteManifest(name, configLayer, layers); err != nil {
 		return err
 	}
 
@@ -839,10 +839,10 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 		return err
 	}
 
-	var layers []*Layer
+	var layers []Layer
 	layers = append(layers, manifest.Layers...)
 	if manifest.Config.Digest != "" {
-		layers = append(layers, &manifest.Config)
+		layers = append(layers, manifest.Config)
 	}
 
 	for _, layer := range layers {
@@ -911,10 +911,10 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 		return fmt.Errorf("pull model manifest: %s", err)
 	}
 
-	var layers []*Layer
+	var layers []Layer
 	layers = append(layers, manifest.Layers...)
 	if manifest.Config.Digest != "" {
-		layers = append(layers, &manifest.Config)
+		layers = append(layers, manifest.Config)
 	}
 
 	skipVerify := make(map[string]bool)
diff --git a/server/layer.go b/server/layer.go
index a2b66782..c666bd10 100644
--- a/server/layer.go
+++ b/server/layer.go
@@ -16,15 +16,15 @@ type Layer struct {
 	status string
 }
 
-func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
+func NewLayer(r io.Reader, mediatype string) (Layer, error) {
 	blobs, err := GetBlobsPath("")
 	if err != nil {
-		return nil, err
+		return Layer{}, err
 	}
 
 	temp, err := os.CreateTemp(blobs, "sha256-")
 	if err != nil {
-		return nil, err
+		return Layer{}, err
 	}
 	defer temp.Close()
 	defer os.Remove(temp.Name())
@@ -32,28 +32,28 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 	sha256sum := sha256.New()
 	n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
 	if err != nil {
-		return nil, err
+		return Layer{}, err
 	}
 
 	if err := temp.Close(); err != nil {
-		return nil, err
+		return Layer{}, err
 	}
 
 	digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
 	blob, err := GetBlobsPath(digest)
 	if err != nil {
-		return nil, err
+		return Layer{}, err
 	}
 
 	status := "using existing layer"
 	if _, err := os.Stat(blob); err != nil {
 		status = "creating new layer"
 		if err := os.Rename(temp.Name(), blob); err != nil {
-			return nil, err
+			return Layer{}, err
 		}
 	}
 
-	return &Layer{
+	return Layer{
 		MediaType: mediatype,
 		Digest:    digest,
 		Size:      n,
@@ -61,22 +61,22 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 	}, nil
 }
 
-func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
+func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
 	if digest == "" {
-		return nil, errors.New("creating new layer from layer with empty digest")
+		return Layer{}, errors.New("creating new layer from layer with empty digest")
 	}
 
 	blob, err := GetBlobsPath(digest)
 	if err != nil {
-		return nil, err
+		return Layer{}, err
 	}
 
 	fi, err := os.Stat(blob)
 	if err != nil {
-		return nil, err
+		return Layer{}, err
 	}
 
-	return &Layer{
+	return Layer{
 		MediaType: mediatype,
 		Digest:    digest,
 		Size:      fi.Size(),
@@ -109,7 +109,7 @@ func (l *Layer) Remove() error {
 	}
 
 	for _, m := range ms {
-		for _, layer := range append(m.Layers, &m.Config) {
+		for _, layer := range append(m.Layers, m.Config) {
 			if layer.Digest == l.Digest {
 				// something is using this layer
 				return nil
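The `layer.go` change converts `Layer` from a pointer type to a plain value throughout the server package: error paths return the zero value `Layer{}` instead of `nil`, manifests hold `[]Layer`, and `m.Config` no longer needs an address-of. One practical consequence is that absence is now signaled by an empty `Digest` rather than a nil pointer, which the `model.go` hunk below relies on. A hypothetical caller under the new signatures (this helper is a sketch and does not exist in the patch):

```go
package server

import "io"

// ensureLayer illustrates the value semantics: reuse an existing blob
// by digest when possible, otherwise write a fresh one.
func ensureLayer(r io.Reader, digest, mediatype, from string) (Layer, error) {
	layer, err := NewLayerFromLayer(digest, mediatype, from)
	if err != nil {
		// The zero Layer comes back on failure; no nil check is needed
		// before falling back to a full copy via NewLayer.
		return NewLayer(r, mediatype)
	}
	return layer, nil
}
```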
diff --git a/server/manifest.go b/server/manifest.go
index b966ddbe..6a5d7b88 100644
--- a/server/manifest.go
+++ b/server/manifest.go
@@ -14,10 +14,10 @@ import (
 )
 
 type Manifest struct {
-	SchemaVersion int      `json:"schemaVersion"`
-	MediaType     string   `json:"mediaType"`
-	Config        Layer    `json:"config"`
-	Layers        []*Layer `json:"layers"`
+	SchemaVersion int     `json:"schemaVersion"`
+	MediaType     string  `json:"mediaType"`
+	Config        Layer   `json:"config"`
+	Layers        []Layer `json:"layers"`
 
 	filepath string
 	fi       os.FileInfo
@@ -25,7 +25,7 @@ type Manifest struct {
 }
 
 func (m *Manifest) Size() (size int64) {
-	for _, layer := range append(m.Layers, &m.Config) {
+	for _, layer := range append(m.Layers, m.Config) {
 		size += layer.Size
 	}
 
@@ -46,7 +46,7 @@ func (m *Manifest) Remove() error {
 }
 
 func (m *Manifest) RemoveLayers() error {
-	for _, layer := range append(m.Layers, &m.Config) {
+	for _, layer := range append(m.Layers, m.Config) {
 		if layer.Digest != "" {
 			if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
 				slog.Debug("layer does not exist", "digest", layer.Digest)
@@ -95,7 +95,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 	return &m, nil
 }
 
-func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
+func WriteManifest(name model.Name, config Layer, layers []Layer) error {
 	manifests, err := GetManifestPath()
 	if err != nil {
 		return err
@@ -115,7 +115,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
 	m := Manifest{
 		SchemaVersion: 2,
 		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
-		Config:        *config,
+		Config:        config,
 		Layers:        layers,
 	}
diff --git a/server/model.go b/server/model.go
index f2946a0b..b17bf0e3 100644
--- a/server/model.go
+++ b/server/model.go
@@ -26,7 +26,7 @@ import (
 var intermediateBlobs map[string]string = make(map[string]string)
 
 type layerGGML struct {
-	*Layer
+	Layer
 	*llm.GGML
 }
 
@@ -176,9 +176,20 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
 			mediatype = "application/vnd.ollama.image.projector"
 		}
 
-		layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
-		if err != nil {
-			return nil, err
+		var layer Layer
+		if digest != "" && n == stat.Size() && offset == 0 {
+			layer, err = NewLayerFromLayer(digest, mediatype, file.Name())
+			if err != nil {
+				slog.Debug("could not create new layer from layer", "error", err)
+			}
+		}
+
+		// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
+		if layer.Digest == "" {
+			layer, err = NewLayer(io.NewSectionReader(file, offset, n), mediatype)
+			if err != nil {
+				return nil, err
+			}
 		}
 
 		layers = append(layers, &layerGGML{layer, ggml})
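The `parseFromFile` change above removes a redundant blob copy during model creation: when the caller already knows the file's digest and the GGUF spans the entire file (offset 0, length equal to the file size), the existing blob is adopted via `NewLayerFromLayer`; any other shape, multi-GGUF files included, still flows through `NewLayer`'s copying path, and a failed adoption falls back to it because the zero `Layer` carries an empty `Digest`. The two tests added in `model_test.go` below exercise exactly these paths. The adoption condition, restated as a standalone sketch (this helper is not in the patch):

```go
package server

// canAdoptBlob condenses the condition the new code uses: the caller
// supplied a digest and the GGUF section covers the whole file, so the
// existing on-disk blob can be reused byte-for-byte.
func canAdoptBlob(digest string, offset, n, fileSize int64) bool {
	return digest != "" && offset == 0 && n == fileSize
}
```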
diff --git a/server/model_test.go b/server/model_test.go
index aa214d3d..63fc408d 100644
--- a/server/model_test.go
+++ b/server/model_test.go
@@ -2,8 +2,10 @@ package server
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"os"
 	"path/filepath"
 	"testing"
@@ -11,6 +13,7 @@ import (
 	"github.com/google/go-cmp/cmp"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/template"
 )
 
@@ -133,3 +136,82 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
 		})
 	}
 }
+
+func TestParseFromFileFromLayer(t *testing.T) {
+	tempModels := t.TempDir()
+
+	file, err := os.CreateTemp(tempModels, "")
+	if err != nil {
+		t.Fatalf("failed to open file: %v", err)
+	}
+	defer file.Close()
+	if err := llm.WriteGGUF(file, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
+		t.Fatalf("failed to write gguf: %v", err)
+	}
+
+	if _, err := file.Seek(0, io.SeekStart); err != nil {
+		t.Fatalf("failed to seek to start: %v", err)
+	}
+
+	layers, err := parseFromFile(context.Background(), file, "", func(api.ProgressResponse) {})
+	if err != nil {
+		t.Fatalf("failed to parse from file: %v", err)
+	}
+
+	if len(layers) != 1 {
+		t.Fatalf("got %d != want 1", len(layers))
+	}
+
+	if _, err := file.Seek(0, io.SeekStart); err != nil {
+		t.Fatalf("failed to seek to start: %v", err)
+	}
+
+	layers2, err := parseFromFile(context.Background(), file, layers[0].Digest, func(api.ProgressResponse) {})
+	if err != nil {
+		t.Fatalf("failed to parse from file: %v", err)
+	}
+	if len(layers2) != 1 {
+		t.Fatalf("got %d != want 1", len(layers2))
+	}
+
+	if layers[0].Digest != layers2[0].Digest {
+		t.Fatalf("got %s != want %s", layers[0].Digest, layers2[0].Digest)
+	}
+
+	if layers[0].Size != layers2[0].Size {
+		t.Fatalf("got %d != want %d", layers[0].Size, layers2[0].Size)
+	}
+
+	if layers[0].MediaType != layers2[0].MediaType {
+		t.Fatalf("got %v != want %v", layers[0].MediaType, layers2[0].MediaType)
+	}
+}
+
+func TestParseLayerFromCopy(t *testing.T) {
+	tempModels := t.TempDir()
+
+	file2, err := os.CreateTemp(tempModels, "")
+	if err != nil {
+		t.Fatalf("failed to open file: %v", err)
+	}
+	defer file2.Close()
+
+	for range 5 {
+		if err := llm.WriteGGUF(file2, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
+			t.Fatalf("failed to write gguf: %v", err)
+		}
+	}
+
+	if _, err := file2.Seek(0, io.SeekStart); err != nil {
+		t.Fatalf("failed to seek to start: %v", err)
+	}
+
+	layers, err := parseFromFile(context.Background(), file2, "", func(api.ProgressResponse) {})
+	if err != nil {
+		t.Fatalf("failed to parse from file: %v", err)
+	}
+
+	if len(layers) != 5 {
+		t.Fatalf("got %d != want 5", len(layers))
+	}
+}
diff --git a/server/routes.go b/server/routes.go
index e55eaa9d..e5a31002 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -23,6 +23,7 @@ import (
 
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
@@ -346,6 +347,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
+	var count int
 	for i, s := range input {
 		tokens, err := r.Tokenize(c.Request.Context(), s)
 		if err != nil {
@@ -368,25 +370,36 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			}
 		}
 
+		count += len(tokens)
+
 		input[i] = s
 	}
 
-	embeddings, err := r.Embed(c.Request.Context(), input)
-	if err != nil {
-		slog.Error("embedding generation failed", "error", err)
-		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
-		return
+
+	var g errgroup.Group
+	embeddings := make([][]float32, len(input))
+	for i, text := range input {
+		g.Go(func() error {
+			embedding, err := r.Embedding(c.Request.Context(), text)
+			if err != nil {
+				return err
+			}
+			embeddings[i] = normalize(embedding)
+			return nil
+		})
 	}
 
-	for i, e := range embeddings.Embedding {
-		embeddings.Embedding[i] = normalize(e)
+	if err := g.Wait(); err != nil {
+		slog.Error("embedding generation failed", "error", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
+		return
 	}
 
 	resp := api.EmbedResponse{
 		Model:           req.Model,
-		Embeddings:      embeddings.Embedding,
+		Embeddings:      embeddings,
 		TotalDuration:   time.Since(checkpointStart),
 		LoadDuration:    checkpointLoaded.Sub(checkpointStart),
-		PromptEvalCount: embeddings.PromptEvalCount,
+		PromptEvalCount: count,
 	}
 	c.JSON(http.StatusOK, resp)
 }
@@ -430,21 +443,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		return
 	}
 
-	embedding := make([]float64, len(embeddings.Embedding[0]))
-
-	for i, v := range embeddings.Embedding[0] {
-		embedding[i] = float64(v)
+	var e []float64
+	for _, v := range embedding {
+		e = append(e, float64(v))
 	}
 
 	resp := api.EmbeddingResponse{
-		Embedding: embedding,
+		Embedding: e,
 	}
 	c.JSON(http.StatusOK, resp)
 }
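In `routes.go` above, `EmbedHandler` stops issuing one batched `Embed` call and instead fans each input out to its own `Embedding` call through an `errgroup.Group`. Results are written into a slice pre-sized to `len(input)` and indexed by the loop variable, so output order matches input order no matter which goroutine finishes first; note the code depends on Go 1.22+ per-iteration loop variables for the closure captures. The same pattern in isolation, with `embed` standing in for the runner's `Embedding` method:

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// embedAll fans inputs out concurrently while preserving order: each
// goroutine writes only to its own index, and g.Wait returns the first
// error, if any.
func embedAll(ctx context.Context, embed func(context.Context, string) ([]float32, error), inputs []string) ([][]float32, error) {
	var g errgroup.Group
	out := make([][]float32, len(inputs))
	for i, text := range inputs {
		g.Go(func() error {
			e, err := embed(ctx, text)
			if err != nil {
				return err
			}
			out[i] = e
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return nil, err
	}
	return out, nil
}

func main() {
	fake := func(_ context.Context, s string) ([]float32, error) {
		return []float32{float32(len(s))}, nil
	}
	res, _ := embedAll(context.Background(), fake, []string{"a", "bb", "ccc"})
	fmt.Println(res) // [[1] [2] [3]]
}
```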
diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go
index 1c950d66..82fac9f5 100644
--- a/server/routes_delete_test.go
+++ b/server/routes_delete_test.go
@@ -98,7 +98,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
 	}
 
 	// create a manifest with duplicate layers
-	if err := WriteManifest(n, config, []*Layer{config}); err != nil {
+	if err := WriteManifest(n, config, []Layer{config}); err != nil {
 		t.Fatal(err)
 	}
 
diff --git a/server/sched_test.go b/server/sched_test.go
index c8717430..713b9259 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -708,8 +708,8 @@ type mockLlm struct {
 	pingResp          error
 	waitResp          error
 	completionResp    error
-	embedResp         *llm.EmbedResponse
-	embedRespErr      error
+	embeddingResp     []float32
+	embeddingRespErr  error
 	tokenizeResp      []int
 	tokenizeRespErr   error
 	detokenizeResp    string
@@ -727,8 +727,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
 	return s.completionResp
 }
 
-func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
-	return s.embedResp, s.embedRespErr
+func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
+	return s.embeddingResp, s.embeddingRespErr
 }
 
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
diff --git a/server/sparse_common.go b/server/sparse_common.go
index f25627fc..c88b2da0 100644
--- a/server/sparse_common.go
+++ b/server/sparse_common.go
@@ -4,6 +4,5 @@ package server
 
 import "os"
 
-func setSparse(file *os.File) error {
-	return nil
+func setSparse(*os.File) {
 }
diff --git a/server/sparse_windows.go b/server/sparse_windows.go
index cdad379e..f21cbbda 100644
--- a/server/sparse_windows.go
+++ b/server/sparse_windows.go
@@ -6,8 +6,9 @@ import (
 	"golang.org/x/sys/windows"
 )
 
-func setSparse(file *os.File) error {
-	return windows.DeviceIoControl(
+func setSparse(file *os.File) {
+	// exFat (and other FS types) don't support sparse files, so ignore errors
+	windows.DeviceIoControl( //nolint:errcheck
 		windows.Handle(file.Fd()),
 		windows.FSCTL_SET_SPARSE,
 		nil, 0,
 		nil, 0,
diff --git a/server/upload.go b/server/upload.go
index b5a244ea..2f115436 100644
--- a/server/upload.go
+++ b/server/upload.go
@@ -26,7 +26,7 @@ import (
 var blobUploadManager sync.Map
 
 type blobUpload struct {
-	*Layer
+	Layer
 
 	Total     int64
 	Completed atomic.Int64
@@ -362,7 +362,7 @@ func (p *progressWriter) Rollback() {
 	p.written = 0
 }
 
-func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
+func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
 	requestURL := mp.BaseURL()
 	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)