diff --git a/README.md b/README.md
index 54315277..4f980375 100644
--- a/README.md
+++ b/README.md
@@ -258,6 +258,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Open WebUI](https://github.com/open-webui/open-webui)
 - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
+- [Hollama](https://github.com/fmaclen/hollama)
 - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
 - [LibreChat](https://github.com/danny-avila/LibreChat)
 - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
@@ -297,6 +298,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
 - [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
 - [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
+- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
 
 ### Terminal
 
@@ -350,6 +352,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
 - [Elixir LangChain](https://github.com/brainlid/langchain)
 - [Ollama for R - rollama](https://github.com/JBGruber/rollama)
+- [Ollama for R - ollama-r](https://github.com/hauselin/ollama-r)
 - [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
 - [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama)
 - [Testcontainers](https://testcontainers.com/modules/ollama/)
diff --git a/api/client.go b/api/client.go
index 074103cc..5b1fc796 100644
--- a/api/client.go
+++ b/api/client.go
@@ -1,9 +1,16 @@
 // Package api implements the client-side API for code wishing to interact
 // with the ollama service. The methods of the [Client] type correspond to
-// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
-//
+// the ollama REST API as described in [the API documentation].
 // The ollama command-line client itself uses this package to interact with
 // the backend service.
+//
+// # Examples
+//
+// Several examples of using this package are available [in the GitHub
+// repository].
+//
+// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
+// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
 package api
 
 import (
@@ -299,8 +306,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
 	})
 }
 
+// PushProgressFunc is a function that [Client.Push] invokes when progress is
+// made.
+// It's similar to other progress function types like [PullProgressFunc].
 type PushProgressFunc func(ProgressResponse) error
 
+// Push uploads a model to the model library; requires registering for ollama.ai
+// and adding a public key first. fn is called each time progress is made on
+// the request and can be used to display a progress bar, etc.
 func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
 		var resp ProgressResponse
@@ -312,8 +325,15 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc
 	})
 }
 
+// CreateProgressFunc is a function that [Client.Create] invokes when progress
+// is made.
+// It's similar to other progress function types like [PullProgressFunc].
 type CreateProgressFunc func(ProgressResponse) error
 
+// Create creates a model from a [Modelfile]. fn is a progress function that
+// behaves similarly to other methods (see [Client.Pull]).
+//
+// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
 func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
 		var resp ProgressResponse
@@ -325,6 +345,7 @@ func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgre
 	})
 }
 
+// List lists models that are available locally.
 func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 	var lr ListResponse
 	if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
@@ -333,6 +354,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 	return &lr, nil
 }
 
+// Copy copies a model - creating a model with another name from an existing
+// model.
 func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
 	if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
 		return err
@@ -340,6 +363,7 @@ func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
 	return nil
 }
 
+// Delete deletes a model and its data.
 func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
 	if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
 		return err
@@ -347,6 +371,7 @@ func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
 	return nil
 }
 
+// Show obtains model information, including details, modelfile, license, etc.
 func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
 	var resp ShowResponse
 	if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
@@ -355,12 +380,16 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err
 	return &resp, nil
 }
 
+// Heartbeat checks if the server has started and is responsive; if yes, it
+// returns nil, otherwise an error.
 func (c *Client) Heartbeat(ctx context.Context) error {
 	if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
 		return err
 	}
 	return nil
 }
+
+// Embeddings generates embeddings from a model.
 func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 	var resp EmbeddingResponse
 	if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
@@ -369,10 +398,13 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
 	return &resp, nil
 }
 
+// CreateBlob creates a blob from a file on the server. digest is the
+// expected SHA256 digest of the file, and r represents the file.
 func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
 	return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
 }
 
+// Version returns the Ollama server version as a string.
 func (c *Client) Version(ctx context.Context) (string, error) {
 	var version struct {
 		Version string `json:"version"`
diff --git a/api/types.go b/api/types.go
index 70caee87..5d0212e5 100644
--- a/api/types.go
+++ b/api/types.go
@@ -12,6 +12,7 @@ import (
 	"time"
 )
 
+// StatusError is an error with an HTTP status code.
 type StatusError struct {
 	StatusCode int
 	Status     string
@@ -32,6 +33,7 @@ func (e StatusError) Error() string {
 	}
 }
 
+// ImageData represents the raw binary data of an image file.
 type ImageData []byte
 
 // GenerateRequest describes a request sent by [Client.Generate]. While you
@@ -77,22 +79,39 @@ type GenerateRequest struct {
 	Options map[string]interface{} `json:"options"`
 }
 
+// ChatRequest describes a request sent by [Client.Chat].
 type ChatRequest struct {
-	Model    string    `json:"model"`
-	Messages []Message `json:"messages"`
-	Stream   *bool     `json:"stream,omitempty"`
-	Format   string    `json:"format"`
+	// Model is the model name, as in [GenerateRequest].
+	Model string `json:"model"`
+
+	// Messages is the messages of the chat - can be used to keep a chat memory.
+	Messages []Message `json:"messages"`
+
+	// Stream enables streaming of the returned response; true by default.
+	Stream *bool `json:"stream,omitempty"`
+
+	// Format is the format to return the response in (e.g. "json").
+	Format string `json:"format"`
+
+	// KeepAlive controls how long the model will stay loaded into memory
+	// following the request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
 
+// Message is a single message in a chat sequence. The message contains the
+// role ("system", "user", or "assistant"), the content, and an optional list
+// of images.
 type Message struct {
-	Role    string      `json:"role"` // one of ["system", "user", "assistant"]
+	Role    string      `json:"role"`
 	Content string      `json:"content"`
 	Images  []ImageData `json:"images,omitempty"`
 }
 
+// ChatResponse is the response returned by [Client.Chat]. Its fields are
+// similar to [GenerateResponse].
 type ChatResponse struct {
 	Model     string    `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
@@ -112,7 +131,8 @@ type Metrics struct {
 	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
 }
 
-// Options specified in GenerateRequest, if you add a new option here add it to the API docs also
+// Options specified in [GenerateRequest]. If you add a new option here, also
+// add it to the API docs.
 type Options struct {
 	Runner
 
@@ -158,18 +178,28 @@ type Runner struct {
 	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
 }
 
+// EmbeddingRequest is the request passed to [Client.Embeddings].
 type EmbeddingRequest struct {
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
+	// Model is the model name.
+	Model string `json:"model"`
+
+	// Prompt is the textual prompt to embed.
+	Prompt string `json:"prompt"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
 
+// EmbeddingResponse is the response from [Client.Embeddings].
 type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }
 
+// CreateRequest is the request passed to [Client.Create].
 type CreateRequest struct {
 	Model     string `json:"model"`
 	Path      string `json:"path"`
@@ -181,6 +211,7 @@ type CreateRequest struct {
 	Name string `json:"name"`
 }
 
+// DeleteRequest is the request passed to [Client.Delete].
 type DeleteRequest struct {
 	Model string `json:"model"`
 
@@ -188,6 +219,7 @@ type DeleteRequest struct {
 	Name string `json:"name"`
 }
 
+// ShowRequest is the request passed to [Client.Show].
 type ShowRequest struct {
 	Model  string `json:"model"`
 	System string `json:"system"`
@@ -199,6 +231,7 @@ type ShowRequest struct {
 	Name string `json:"name"`
 }
 
+// ShowResponse is the response returned from [Client.Show].
 type ShowResponse struct {
 	License   string         `json:"license,omitempty"`
 	Modelfile string         `json:"modelfile,omitempty"`
@@ -209,11 +242,13 @@ type ShowResponse struct {
 	Messages  []Message      `json:"messages,omitempty"`
 }
 
+// CopyRequest is the request passed to [Client.Copy].
 type CopyRequest struct {
 	Source      string `json:"source"`
 	Destination string `json:"destination"`
 }
 
+// PullRequest is the request passed to [Client.Pull].
 type PullRequest struct {
 	Model    string `json:"model"`
 	Insecure bool   `json:"insecure,omitempty"`
@@ -225,6 +260,8 @@ type PullRequest struct {
 	Name string `json:"name"`
 }
 
+// ProgressResponse is the response passed to progress functions like
+// [PullProgressFunc] and [PushProgressFunc].
 type ProgressResponse struct {
 	Status    string `json:"status"`
 	Digest    string `json:"digest,omitempty"`
@@ -232,6 +269,7 @@ type ProgressResponse struct {
 	Completed int64  `json:"completed,omitempty"`
 }
 
+// PushRequest is the request passed to [Client.Push].
 type PushRequest struct {
 	Model    string `json:"model"`
 	Insecure bool   `json:"insecure,omitempty"`
@@ -243,10 +281,12 @@ type PushRequest struct {
 	Name string `json:"name"`
 }
 
+// ListResponse is the response from [Client.List].
 type ListResponse struct {
 	Models []ModelResponse `json:"models"`
 }
 
+// ModelResponse is a single model description in [ListResponse].
 type ModelResponse struct {
 	Name   string `json:"name"`
 	Model  string `json:"model"`
@@ -260,17 +300,28 @@ type TokenResponse struct {
 	Token string `json:"token"`
 }
 
+// GenerateResponse is the response passed into [GenerateResponseFunc].
 type GenerateResponse struct {
-	Model     string    `json:"model"`
-	CreatedAt time.Time `json:"created_at"`
-	Response  string    `json:"response"`
+	// Model is the model name that generated the response.
+	Model string `json:"model"`
 
-	Done    bool  `json:"done"`
+	// CreatedAt is the timestamp of the response.
+	CreatedAt time.Time `json:"created_at"`
+
+	// Response is the textual response itself.
+	Response string `json:"response"`
+
+	// Done specifies if the response is complete.
+	Done bool `json:"done"`
+
+	// Context is an encoding of the conversation used in this response; this
+	// can be sent in the next request to keep a conversational memory.
 	Context []int `json:"context,omitempty"`
 
 	Metrics
 }
 
+// ModelDetails provides details about a model.
 type ModelDetails struct {
 	ParentModel       string `json:"parent_model"`
 	Format            string `json:"format"`
@@ -308,6 +359,7 @@ func (m *Metrics) Summary() {
 	}
 }
 
+// ErrInvalidOpts is returned when invalid options are passed to the client.
 var ErrInvalidOpts = errors.New("invalid options")
 var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
 
@@ -394,6 +446,8 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 	return nil
 }
 
+// DefaultOptions is the default set of options for [GenerateRequest]; these
+// values are used unless the user specifies other values explicitly.
 func DefaultOptions() Options {
 	return Options{
 		// options set on request to runner
diff --git a/cmd/interactive.go b/cmd/interactive.go
index 5673fda0..c294b7b5 100644
--- a/cmd/interactive.go
+++ b/cmd/interactive.go
@@ -162,7 +162,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		fmt.Fprintln(os.Stderr, "  /set parameter repeat_penalty   How strongly to penalize repetitions")
 		fmt.Fprintln(os.Stderr, "  /set parameter repeat_last_n    Set how far back to look for repetitions")
 		fmt.Fprintln(os.Stderr, "  /set parameter num_gpu          The number of layers to send to the GPU")
-		fmt.Fprintln(os.Stderr, "  /set parameter stop \"\", ...     Set the stop parameters")
+		fmt.Fprintln(os.Stderr, "  /set parameter stop ...         Set the stop parameters")
 		fmt.Fprintln(os.Stderr, "")
 	}
 
diff --git a/convert/convert.go b/convert/convert.go
index 42de080c..f4210e50 100644
--- a/convert/convert.go
+++ b/convert/convert.go
@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"encoding/json"
 	"fmt"
+	"io"
 	"log/slog"
 	"os"
 	"path/filepath"
@@ -47,7 +48,7 @@ type ByteOrder interface {
 type ModelArch interface {
 	GetTensors() error
 	LoadVocab() error
-	WriteGGUF() (string, error)
+	WriteGGUF(io.WriteSeeker) error
 }
 
 type ModelFormat interface {
diff --git a/convert/gemma.go b/convert/gemma.go
index 648a4ad9..88abe646 100644
--- a/convert/gemma.go
+++ b/convert/gemma.go
@@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error {
 	return nil
 }
 
-func (m *GemmaModel) WriteGGUF() (string, error) {
+func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture": "gemma",
 		"general.name":         m.Name,
@@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token":      false,
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
diff --git a/convert/llama.go b/convert/llama.go
index c7f7b290..fb576e2e 100644
--- a/convert/llama.go
+++ b/convert/llama.go
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
-	"os"
 	"regexp"
 	"strings"
 
@@ -132,7 +131,7 @@ func (m *LlamaModel) LoadVocab() error {
 	return nil
 }
 
-func (m *LlamaModel) WriteGGUF() (string, error) {
+func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture": "llama",
 		"general.name":         m.Name,
@@ -159,18 +158,5 @@ func (m *LlamaModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token":      false,
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
diff --git a/convert/mistral.go b/convert/mistral.go
index 70c92edd..f88de12b 100644
--- a/convert/mistral.go
+++ b/convert/mistral.go
@@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
 	return nil
 }
 
-func (m *MistralModel) WriteGGUF() (string, error) {
+func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture": "llama",
 		"general.name":         m.Name,
@@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.unknown_token_id": uint32(0),
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
diff --git a/convert/mixtral.go b/convert/mixtral.go
index e31e84af..940df55d 100644
--- a/convert/mixtral.go
+++ b/convert/mixtral.go
@@ -1,7 +1,7 @@
 package convert
 
 import (
-	"os"
+	"io"
 	"regexp"
 
 	"github.com/ollama/ollama/llm"
@@ -47,7 +47,7 @@ func (m *MixtralModel) LoadVocab() error {
 	return nil
 }
 
-func (m *MixtralModel) WriteGGUF() (string, error) {
+func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture": "llama",
 		"general.name":         m.Name,
@@ -81,16 +81,5 @@ func (m *MixtralModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token":      false,
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
diff --git a/docs/tutorials/langchainpy.md b/docs/tutorials/langchainpy.md
index b0235679..9a1bca0d 100644
--- a/docs/tutorials/langchainpy.md
+++ b/docs/tutorials/langchainpy.md
@@ -12,7 +12,7 @@ So let's figure out how we can use **LangChain** with Ollama to ask our question
 
 Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**. First, we need to install the **LangChain** package:
 
-`pip install langchain`
+`pip install langchain_community`
 
 Then we can create a model and ask the question:
 
diff --git a/examples/flyio/.gitignore b/examples/flyio/.gitignore
new file mode 100644
index 00000000..0501d092
--- /dev/null
+++ b/examples/flyio/.gitignore
@@ -0,0 +1 @@
+fly.toml
diff --git a/examples/flyio/README.md b/examples/flyio/README.md
new file mode 100644
index 00000000..09b90aad
--- /dev/null
+++ b/examples/flyio/README.md
@@ -0,0 +1,67 @@
+# Deploy Ollama to Fly.io
+
+> Note: this example exposes a public endpoint and does not configure authentication. Use with care.
+
+## Prerequisites
+
+- Ollama: https://ollama.com/download
+- Fly.io account. Sign up for a free account: https://fly.io/app/sign-up
+
+## Steps
+
+1. Login to Fly.io
+
+   ```bash
+   fly auth login
+   ```
+
+1. Create a new Fly app
+
+   ```bash
+   fly launch --name <name> --image ollama/ollama --internal-port 11434 --vm-size shared-cpu-8x --now
+   ```
+
+1. Pull and run `orca-mini:3b`
+
+   ```bash
+   OLLAMA_HOST=https://<name>.fly.dev ollama run orca-mini:3b
+   ```
+
+`shared-cpu-8x` is a free-tier eligible machine type. For better performance, switch to a `performance` or `dedicated` machine type, or attach a GPU for hardware acceleration (see below).
+
+## (Optional) Persistent Volume
+
+By default Fly Machines use ephemeral storage, which is problematic if you want to use the same model across restarts without pulling it again. Create and attach a persistent volume to store the downloaded models:
+
+1. Create the Fly Volume
+
+   ```bash
+   fly volume create ollama
+   ```
+
+1. Update `fly.toml` and add `[mounts]`
+
+   ```toml
+   [mounts]
+     source = "ollama"
+     destination = "/mnt/ollama/models"
+   ```
+
+1. Update `fly.toml` and add `[env]`
+
+   ```toml
+   [env]
+     OLLAMA_MODELS = "/mnt/ollama/models"
+   ```
+
+1. Deploy your app
+
+   ```bash
+   fly deploy
+   ```
+
+## (Optional) Hardware Acceleration
+
+Fly.io GPU support is currently waitlist-only. Sign up for the waitlist: https://fly.io/gpu
+
+Once you've been accepted, create the app with the additional flag `--vm-gpu-kind a100-pcie-40gb` or `--vm-gpu-kind a100-pcie-80gb`.
diff --git a/examples/kubernetes/README.md b/examples/kubernetes/README.md
index c522ba76..2e2444c7 100644
--- a/examples/kubernetes/README.md
+++ b/examples/kubernetes/README.md
@@ -7,12 +7,24 @@
 
 ## Steps
 
-1. Create the Ollama namespace, daemon set, and service
+1. Create the Ollama namespace, deployment, and service
 
    ```bash
   kubectl apply -f cpu.yaml
   ```
 
+## (Optional) Hardware Acceleration
+
+Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin), which is deployed in Kubernetes in the form of a daemonset. Follow the link for more details.
+
+Once configured, create a GPU-enabled Ollama deployment.
+
+```bash
+kubectl apply -f gpu.yaml
+```
+
+## Test
+
 1. Port forward the Ollama service to connect and use it locally
 
    ```bash
@@ -23,14 +35,4 @@
 
    ```bash
    ollama run orca-mini:3b
-   ```
-
-## (Optional) Hardware Acceleration
-
-Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin). Follow the link for more details.
-
-Once configured, create a GPU enabled Ollama deployment.
-
-```bash
-kubectl apply -f gpu.yaml
-```
+   ```
\ No newline at end of file
diff --git a/format/format.go b/format/format.go
index 8fd2defa..31059578 100644
--- a/format/format.go
+++ b/format/format.go
@@ -13,12 +13,20 @@ const (
 
 func HumanNumber(b uint64) string {
 	switch {
-	case b > Billion:
-		return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion))
-	case b > Million:
-		return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million))
-	case b > Thousand:
-		return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand))
+	case b >= Billion:
+		number := float64(b) / Billion
+		if number == math.Floor(number) {
+			return fmt.Sprintf("%.0fB", number) // no decimals if whole number
+		}
+		return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number
+	case b >= Million:
+		number := float64(b) / Million
+		if number == math.Floor(number) {
+			return fmt.Sprintf("%.0fM", number) // no decimals if whole number
+		}
+		return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number
+	case b >= Thousand:
+		return fmt.Sprintf("%.0fK", float64(b)/Thousand)
 	default:
 		return fmt.Sprintf("%d", b)
 	}
diff --git a/format/format_test.go b/format/format_test.go
new file mode 100644
index 00000000..1d73c80b
--- /dev/null
+++ b/format/format_test.go
@@ -0,0 +1,34 @@
+package format
+
+import (
+	"testing"
+)
+
+func TestHumanNumber(t *testing.T) {
+
+	type testCase struct {
+		input    uint64
+		expected string
+	}
+
+	testCases := []testCase{
+		{0, "0"},
+		{1000000, "1M"},
+		{125000000, "125M"},
+		{500500000, "500.50M"},
+		{500550000, "500.55M"},
+		{1000000000, "1B"},
+		{2800000000, "2.8B"},
+		{2850000000, "2.9B"},
+		{1000000000000, "1000B"},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.expected, func(t *testing.T) {
+			result := HumanNumber(tc.input)
+			if result != tc.expected {
+				t.Errorf("Expected %s, got %s", tc.expected, result)
+			}
+		})
+	}
+}
diff --git a/gpu/gpu.go b/gpu/gpu.go
index 21666c8d..f8bae9b0 100644
--- a/gpu/gpu.go
+++ b/gpu/gpu.go
@@ -31,8 +31,8 @@ type handles struct {
 }
 
 const (
-	cudaMinimumMemory = 457 * format.MebiByte
-	rocmMinimumMemory = 457 * format.MebiByte
+	cudaMinimumMemory = 256 * format.MebiByte
+	rocmMinimumMemory = 256 * format.MebiByte
 )
 
 var gpuMutex sync.Mutex
diff --git a/gpu/gpu_darwin.go b/gpu/gpu_darwin.go
index f8cc1adb..0ba02e1b 100644
--- a/gpu/gpu_darwin.go
+++ b/gpu/gpu_darwin.go
@@ -15,7 +15,7 @@ import (
 )
 
 const (
-	metalMinimumMemory = 512 * format.MebiByte
+	metalMinimumMemory = 384 * format.MebiByte
 )
 
 func GetGPUInfo() GpuInfoList {
diff --git a/integration/utils_test.go b/integration/utils_test.go
index 3e91187a..e133e76d 100644
--- a/integration/utils_test.go
+++ b/integration/utils_test.go
@@ -107,7 +107,9 @@ func startServer(ctx context.Context, ollamaHost string) error {
 
 	if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
 		slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
-		os.Setenv("OLLAMA_HOST", ollamaHost)
+		if err := os.Setenv("OLLAMA_HOST", ollamaHost); err != nil {
+			return err
+		}
 	}
 
 	slog.Info("starting server", "url", ollamaHost)
diff --git a/llm/filetype.go b/llm/filetype.go
new file mode 100644
index 00000000..e5e9410d
--- /dev/null
+++ b/llm/filetype.go
@@ -0,0 +1,140 @@
+package llm
+
+import "fmt"
+
+type fileType uint32
+
+const (
+	fileTypeF32 fileType = iota
+	fileTypeF16
+	fileTypeQ4_0
+	fileTypeQ4_1
+	fileTypeQ4_1_F16
+	fileTypeQ4_2 // unused
+	fileTypeQ4_3 // unused
+	fileTypeQ8_0
+	fileTypeQ5_0
+	fileTypeQ5_1
+	fileTypeQ2_K
+	fileTypeQ3_K_S
+	fileTypeQ3_K_M
+	fileTypeQ3_K_L
+	fileTypeQ4_K_S
+	fileTypeQ4_K_M
+	fileTypeQ5_K_S
+	fileTypeQ5_K_M
+	fileTypeQ6_K
+	fileTypeIQ2_XXS
+	fileTypeIQ2_XS
+	fileTypeQ2_K_S
+	fileTypeQ3_K_XS
+	fileTypeIQ3_XXS
+
+	fileTypeUnknown
+)
+
+func ParseFileType(s string) (fileType, error) {
+	switch s {
+	case "F32":
+		return fileTypeF32, nil
+	case "F16":
+		return fileTypeF16, nil
+	case "Q4_0":
+		return fileTypeQ4_0, nil
+	case "Q4_1":
+		return fileTypeQ4_1, nil
+	case "Q4_1_F16":
+		return fileTypeQ4_1_F16, nil
+	case "Q8_0":
+		return fileTypeQ8_0, nil
+	case "Q5_0":
+		return fileTypeQ5_0, nil
+	case "Q5_1":
+		return fileTypeQ5_1, nil
+	case "Q2_K":
+		return fileTypeQ2_K, nil
+	case "Q3_K_S":
+		return fileTypeQ3_K_S, nil
+	case "Q3_K_M":
+		return fileTypeQ3_K_M, nil
+	case "Q3_K_L":
+		return fileTypeQ3_K_L, nil
+	case "Q4_K_S":
+		return fileTypeQ4_K_S, nil
+	case "Q4_K_M":
+		return fileTypeQ4_K_M, nil
+	case "Q5_K_S":
+		return fileTypeQ5_K_S, nil
+	case "Q5_K_M":
+		return fileTypeQ5_K_M, nil
+	case "Q6_K":
+		return fileTypeQ6_K, nil
+	case "IQ2_XXS":
+		return fileTypeIQ2_XXS, nil
+	case "IQ2_XS":
+		return fileTypeIQ2_XS, nil
+	case "Q2_K_S":
+		return fileTypeQ2_K_S, nil
+	case "Q3_K_XS":
+		return fileTypeQ3_K_XS, nil
+	case "IQ3_XXS":
+		return fileTypeIQ3_XXS, nil
+	default:
+		return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
+	}
+}
+
+func (t fileType) String() string {
+	switch t {
+	case fileTypeF32:
+		return "F32"
+	case fileTypeF16:
+		return "F16"
+	case fileTypeQ4_0:
+		return "Q4_0"
+	case fileTypeQ4_1:
+		return "Q4_1"
+	case fileTypeQ4_1_F16:
+		return "Q4_1_F16"
+	case fileTypeQ8_0:
+		return "Q8_0"
+	case fileTypeQ5_0:
+		return "Q5_0"
+	case fileTypeQ5_1:
+		return "Q5_1"
+	case fileTypeQ2_K:
+		return "Q2_K"
+	case fileTypeQ3_K_S:
+		return "Q3_K_S"
+	case fileTypeQ3_K_M:
+		return "Q3_K_M"
+	case fileTypeQ3_K_L:
+		return "Q3_K_L"
+	case fileTypeQ4_K_S:
+		return "Q4_K_S"
+	case fileTypeQ4_K_M:
+		return "Q4_K_M"
+	case fileTypeQ5_K_S:
+		return "Q5_K_S"
+	case fileTypeQ5_K_M:
+		return "Q5_K_M"
+	case fileTypeQ6_K:
+		return "Q6_K"
+	case fileTypeIQ2_XXS:
+		return "IQ2_XXS"
+	case fileTypeIQ2_XS:
+		return "IQ2_XS"
+	case fileTypeQ2_K_S:
+		return "Q2_K_S"
+	case fileTypeQ3_K_XS:
+		return "Q3_K_XS"
+	case fileTypeIQ3_XXS:
+		return "IQ3_XXS"
+	default:
+		return "unknown"
+	}
+}
+
+func (t fileType) Value() uint32 {
+	return uint32(t)
+}
diff --git a/llm/ggml.go b/llm/ggml.go
index 1b094027..1c21bde0 100644
--- a/llm/ggml.go
+++ b/llm/ggml.go
@@ -13,82 +13,6 @@ type GGML struct {
 	model
 }
 
-const (
-	fileTypeF32 uint32 = iota
-	fileTypeF16
-	fileTypeQ4_0
-	fileTypeQ4_1
-	fileTypeQ4_1_F16
-	fileTypeQ8_0 uint32 = iota + 2
-	fileTypeQ5_0
-	fileTypeQ5_1
-	fileTypeQ2_K
-	fileTypeQ3_K_S
-	fileTypeQ3_K_M
-	fileTypeQ3_K_L
-	fileTypeQ4_K_S
-	fileTypeQ4_K_M
-	fileTypeQ5_K_S
-	fileTypeQ5_K_M
-	fileTypeQ6_K
-	fileTypeIQ2_XXS
-	fileTypeIQ2_XS
-	fileTypeQ2_K_S
-	fileTypeQ3_K_XS
-	fileTypeIQ3_XXS
-)
-
-func fileType(fileType uint32) string {
-	switch fileType {
-	case fileTypeF32:
-		return "F32"
-	case fileTypeF16:
-		return "F16"
-	case fileTypeQ4_0:
-		return "Q4_0"
-	case fileTypeQ4_1:
-		return "Q4_1"
-	case fileTypeQ4_1_F16:
-		return "Q4_1_F16"
-	case fileTypeQ8_0:
-		return "Q8_0"
-	case fileTypeQ5_0:
-		return "Q5_0"
-	case fileTypeQ5_1:
-		return "Q5_1"
-	case fileTypeQ2_K:
-		return "Q2_K"
-	case fileTypeQ3_K_S:
-		return "Q3_K_S"
-	case fileTypeQ3_K_M:
-		return "Q3_K_M"
-	case fileTypeQ3_K_L:
-		return "Q3_K_L"
-	case fileTypeQ4_K_S:
-		return "Q4_K_S"
-	case fileTypeQ4_K_M:
-		return "Q4_K_M"
-	case fileTypeQ5_K_S:
-		return "Q5_K_S"
-	case fileTypeQ5_K_M:
-		return "Q5_K_M"
-	case fileTypeQ6_K:
-		return "Q6_K"
-	case fileTypeIQ2_XXS:
-		return "IQ2_XXS"
-	case fileTypeIQ2_XS:
-		return "IQ2_XS"
-	case fileTypeQ2_K_S:
-		return "Q2_K_S"
-	case fileTypeQ3_K_XS:
-		return "Q3_K_XS"
-	case fileTypeIQ3_XXS:
-		return "IQ3_XXS"
-	default:
-		return "unknown"
-	}
-}
-
 type model interface {
 	KV() KV
 	Tensors() Tensors
@@ -121,12 +45,12 @@ func (kv KV) ParameterCount() uint64 {
 	return kv.u64("general.parameter_count")
 }
 
-func (kv KV) FileType() string {
+func (kv KV) FileType() fileType {
 	if u64 := kv.u64("general.file_type"); u64 > 0 {
 		return fileType(uint32(u64))
 	}
 
-	return "unknown"
+	return fileTypeUnknown
 }
 
 func (kv KV) BlockCount() uint64 {
@@ -286,6 +210,23 @@ const (
 
 var ErrUnsupportedFormat = errors.New("unsupported model format")
 
+func DetectGGMLType(b []byte) string {
+	switch binary.LittleEndian.Uint32(b[:4]) {
+	case FILE_MAGIC_GGML:
+		return "ggml"
+	case FILE_MAGIC_GGMF:
+		return "ggmf"
+	case FILE_MAGIC_GGJT:
+		return "ggjt"
+	case FILE_MAGIC_GGLA:
+		return "ggla"
+	case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
+		return "gguf"
+	default:
+		return ""
+	}
+}
+
 func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 	var magic uint32
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
diff --git a/llm/llm.go b/llm/llm.go
index c81e2edf..2a0c4b91 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -20,7 +20,7 @@ func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }
 
-func Quantize(infile, outfile, filetype string) error {
+func Quantize(infile, outfile string, ftype fileType) error {
 	cinfile := C.CString(infile)
 	defer C.free(unsafe.Pointer(cinfile))
 
@@ -29,58 +29,10 @@ func Quantize(infile, outfile string, ftype fileType) error {
 	params := C.llama_model_quantize_default_params()
 	params.nthread = -1
+	params.ftype = ftype.Value()
 
-	switch filetype {
-	case "F32":
-		params.ftype = fileTypeF32
-	case "F16":
-		params.ftype = fileTypeF16
-	case "Q4_0":
-		params.ftype = fileTypeQ4_0
-	case "Q4_1":
-		params.ftype = fileTypeQ4_1
-	case "Q4_1_F16":
-		params.ftype = fileTypeQ4_1_F16
-	case "Q8_0":
-		params.ftype = fileTypeQ8_0
-	case "Q5_0":
-		params.ftype = fileTypeQ5_0
-	case "Q5_1":
-		params.ftype = fileTypeQ5_1
-	case "Q2_K":
-		params.ftype = fileTypeQ2_K
-	case "Q3_K_S":
-		params.ftype = fileTypeQ3_K_S
-	case "Q3_K_M":
-		params.ftype = fileTypeQ3_K_M
-	case "Q3_K_L":
-		params.ftype = fileTypeQ3_K_L
-	case "Q4_K_S":
-		params.ftype = fileTypeQ4_K_S
-	case "Q4_K_M":
-		params.ftype = fileTypeQ4_K_M
-	case "Q5_K_S":
-		params.ftype = fileTypeQ5_K_S
-	case "Q5_K_M":
-		params.ftype = fileTypeQ5_K_M
-	case "Q6_K":
-		params.ftype = fileTypeQ6_K
-	case "IQ2_XXS":
-		params.ftype = fileTypeIQ2_XXS
-	case "IQ2_XS":
-		params.ftype = fileTypeIQ2_XS
-	case "Q2_K_S":
-		params.ftype = fileTypeQ2_K_S
-	case "Q3_K_XS":
-		params.ftype = fileTypeQ3_K_XS
-	case "IQ3_XXS":
-		params.ftype = fileTypeIQ3_XXS
-	default:
-		return fmt.Errorf("unknown filetype: %s", filetype)
-	}
-
-	if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
-		return fmt.Errorf("llama_model_quantize: %d", retval)
+	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
+		return fmt.Errorf("llama_model_quantize: %d", rc)
 	}
 
 	return nil
diff --git a/llm/memory.go b/llm/memory.go
index 661a0c50..005a15aa 100644
--- a/llm/memory.go
+++ b/llm/memory.go
@@ -85,19 +85,19 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 		graphPartialOffload = graphFullOffload
 	}
 
+	layers := ggml.Tensors().Layers()
+
 	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
-	memoryRequiredTotal := memoryMinimum + graphFullOffload
+	memoryRequiredTotal := memoryMinimum + graphFullOffload + layers["blk.0"].size()
 
 	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
-	memoryRequiredPartial := memoryMinimum + graphPartialOffload
+	memoryRequiredPartial := memoryMinimum + graphPartialOffload + layers["blk.0"].size()
 
 	if memoryRequiredPartial > memoryAvailable {
 		slog.Debug("insufficient VRAM to load any model layers")
 		return 0, 0
 	}
 
-	layers := ggml.Tensors().Layers()
-
 	var memoryLayerOutput uint64
 	if layer, ok := layers["output_norm"]; ok {
 		memoryLayerOutput += layer.size()
diff --git a/readline/readline.go b/readline/readline.go
index 8ba7d89c..6fa45391 100644
--- a/readline/readline.go
+++ b/readline/readline.go
@@ -218,7 +218,7 @@ func (i *Instance) Readline() (string, error) {
 		case CharCtrlZ:
 			fd := int(syscall.Stdin)
 			return handleCharCtrlZ(fd, i.Terminal.termios)
-		case CharEnter:
+		case CharEnter, CharCtrlJ:
 			output := buf.String()
 			if output != "" {
 				i.History.Add([]rune(output))
@@ -232,7 +232,7 @@ func (i *Instance) Readline() (string, error) {
 				metaDel = false
 				continue
 			}
-			if r >= CharSpace || r == CharEnter {
+			if r >= CharSpace || r == CharEnter || r == CharCtrlJ {
 				buf.Add(r)
 			}
 		}
diff --git a/server/images.go b/server/images.go
index 76205392..a96db8d1 100644
--- a/server/images.go
+++ b/server/images.go
@@ -1,8 +1,8 @@
 package server
 
 import (
-	"archive/zip"
 	"bytes"
+	"cmp"
 	"context"
 	"crypto/sha256"
 	"encoding/base64"
@@ -11,7 +11,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"io/fs"
 	"log"
 	"log/slog"
 	"net/http"
@@ -26,7 +25,6 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
-	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/server/envconfig"
@@ -54,7 +52,6 @@ type Model struct {
 	System    string
 	License   []string
 	Digest    string
-	Size      int64
 	Options   map[string]interface{}
 	Messages  []Message
 }
@@ -158,50 +155,11 @@ type ConfigV2 struct {
 	RootFS RootFS `json:"rootfs"`
 }
 
-func (c *ConfigV2) SetModelFormat(format string) {
-	if c.ModelFormat == "" {
-		c.ModelFormat = format
-	}
-}
-
-func (c *ConfigV2) SetModelFamily(families ...string) {
-	for _, family := range families {
-		if c.ModelFamily == "" {
-			c.ModelFamily = family
-		}
-
-		if !slices.Contains(c.ModelFamilies, family) {
-			c.ModelFamilies = append(c.ModelFamilies, family)
-		}
-	}
-}
-
-func (c *ConfigV2) SetModelType(modelType string) {
-	if c.ModelType == "" {
-		c.ModelType = modelType
-	}
-}
-
-func (c *ConfigV2) SetFileType(fileType string) {
-	if c.FileType == "" {
-		c.FileType = fileType
-	}
-}
-
 type RootFS struct {
 	Type    string   `json:"type"`
 	DiffIDs []string `json:"diff_ids"`
 }
 
-func (m *ManifestV2) GetTotalSize() (total int64) {
-	for _, layer := range m.Layers {
-		total += layer.Size
-	}
-
-	total += m.Config.Size
-	return total
-}
-
 func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
 	fp, err := mp.GetManifestPath()
 	if err != nil {
@@ -242,7 +200,6 @@ func GetModel(name string) (*Model, error) {
 		Digest:    digest,
 		Template:  "{{ .Prompt }}",
 		License:   []string{},
-		Size:      manifest.GetTotalSize(),
 	}
 
 	filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -332,7 +289,7 @@ func GetModel(name string) (*Model, error) {
 	return model, nil
 }
 
-func realpath(mfDir, from string) string {
+func realpath(rel, from string) string {
 	abspath, err := filepath.Abs(from)
 	if err != nil {
 		return from
@@ -349,22 +306,15 @@ func realpath(rel, from string) string {
 		return filepath.Join(home, from[2:])
 	}
 
-	if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil {
+	if _, err := os.Stat(filepath.Join(rel, from)); err == nil {
 		// this is a file relative to the Modelfile
-		return filepath.Join(mfDir, from)
+		return filepath.Join(rel, from)
 	}
 
 	return abspath
 }
 
-func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) error {
-	deleteMap := make(map[string]struct{})
-	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
-		for _, layer := range append(manifest.Layers, manifest.Config) {
-			deleteMap[layer.Digest] = struct{}{}
-		}
-	}
-
+func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
 	config := ConfigV2{
 		OS:           "linux",
 		Architecture: "amd64",
@@ -373,250 +323,181 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 		},
 	}
 
-	var layers Layers
-	messages := []string{}
-
-	params := make(map[string][]string)
-	fromParams := make(map[string]any)
+	var messages []*api.Message
+	parameters := make(map[string]any)
 
+	var layers []*Layer
 	for _, c := range modelfile.Commands {
 		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 
 		switch c.Name {
-		case "model":
-			if strings.HasPrefix(c.Args, "@") {
-				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
+		case "model", "adapter":
+			var baseLayers []*layerWithGGML
+			if name := model.ParseName(c.Args); name.IsValid() {
+				baseLayers, err = parseFromModel(ctx, name, fn)
+				if err != nil {
+					return err
+				}
+			} else if strings.HasPrefix(c.Args, "@") {
+				blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
 				if err != nil {
 					return err
 				}
 
-				c.Args = blobPath
-			}
-
-			pathName := realpath(modelFileDir, c.Args)
-
-			ggufName, err := convertModel(name, pathName, fn)
-			if err != nil {
-				var pathErr *fs.PathError
-				switch {
-				case errors.Is(err, zip.ErrFormat):
-					// it's not a safetensor archive
-				case errors.As(err, &pathErr):
-					// it's not a file on disk, could be a model reference
-				default:
+				blob, err := os.Open(blobpath)
+				if err != nil {
 					return err
 				}
+				defer blob.Close()
+
+				baseLayers, err = parseFromFile(ctx, blob, fn)
+				if err != nil {
+					return err
+				}
+			} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
+				defer file.Close()
+
+				baseLayers, err = parseFromFile(ctx, file, fn)
+				if err != nil {
+					return err
+				}
+			} else {
+				return fmt.Errorf("invalid model reference: %s", c.Args)
 			}
 
-			if ggufName != "" {
-				pathName = ggufName
-				defer os.RemoveAll(ggufName)
-
-				if quantization != "" {
-					quantization = strings.ToUpper(quantization)
-					fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
-					tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
+			for _, baseLayer := range baseLayers {
+				if quantization != "" &&
+					baseLayer.MediaType == "application/vnd.ollama.image.model" &&
+					baseLayer.GGML != nil &&
+					baseLayer.GGML.Name() == "gguf" {
+					want, err := llm.ParseFileType(quantization)
 					if err != nil {
 						return err
 					}
-					defer os.RemoveAll(tempfile.Name())
 
-					if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
-						return err
-					}
+					ft := baseLayer.GGML.KV().FileType()
+					if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
+						return errors.New("quantization is only supported for F16 and F32 models")
+					} else if want != ft {
+						fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})
 
-					if err := tempfile.Close(); err != nil {
-						return err
-					}
-
-					pathName = tempfile.Name()
-				}
-			}
-
-			bin, err := os.Open(pathName)
-			if err != nil {
-				// not a file on disk so must be a model reference
-				modelpath := ParseModelPath(c.Args)
-				manifest, _, err := GetManifest(modelpath)
-				switch {
-				case errors.Is(err, os.ErrNotExist):
-					fn(api.ProgressResponse{Status: "pulling model"})
-					if err := PullModel(ctx, c.Args, &registryOptions{}, fn); err != nil {
-						return err
-					}
-
-					manifest, _, err = GetManifest(modelpath)
-					if err != nil {
-						return err
-					}
-				case err != nil:
-					return err
-				}
-
-				fn(api.ProgressResponse{Status: "reading model metadata"})
-				fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
-				if err != nil {
-					return err
-				}
-
-				fromConfigFile, err := os.Open(fromConfigPath)
-				if err != nil {
-					return err
-				}
-				defer fromConfigFile.Close()
-
-				var fromConfig ConfigV2
-				if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
-					return err
-				}
-
-				// if the model is still not in gguf format, error out
-				if fromConfig.ModelFormat != "gguf" {
-					return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args)
-				}
-
-				config.SetModelFormat(fromConfig.ModelFormat)
-				config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
-				config.SetModelType(fromConfig.ModelType)
-				config.SetFileType(fromConfig.FileType)
-
-				for _, layer := range manifest.Layers {
-					deleteMap[layer.Digest] = struct{}{}
-					if layer.MediaType == "application/vnd.ollama.image.params" {
-						fromParamsPath, err := GetBlobsPath(layer.Digest)
+						blob, err := GetBlobsPath(baseLayer.Digest)
 						if err != nil {
 							return err
 						}
 
-						fromParamsFile, err := os.Open(fromParamsPath)
+						temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
 						if err != nil {
 							return err
 						}
-						defer fromParamsFile.Close()
+						defer temp.Close()
+						defer os.Remove(temp.Name())
 
-						if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
+						if err := llm.Quantize(blob, temp.Name(), want); err != nil {
+							return err
+						}
+
+						baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
+						if err != nil {
 							return err
 						}
 					}
-
-					layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
-					if err != nil {
-						return err
-					}
-
-					layers.Add(layer)
 				}
 
-				deleteMap[manifest.Config.Digest] = struct{}{}
-				continue
+				if baseLayer.GGML != nil {
+					config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
+					config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
+					config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
+					config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
+					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
+				}
+
+				layers = append(layers, baseLayer.Layer)
 			}
-			defer bin.Close()
-
-			var offset int64
-			for {
-				fn(api.ProgressResponse{Status: "creating model layer"})
-				if _, err := bin.Seek(offset, io.SeekStart); err != nil {
-					return err
-				}
-
-				ggml, size, err := llm.DecodeGGML(bin)
-				if errors.Is(err, io.EOF) {
-					break
-				} else if errors.Is(err, llm.ErrUnsupportedFormat) {
-					return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
-				} else if err != nil {
-					return err
-				}
-
-				config.SetModelFormat(ggml.Name())
-				config.SetModelFamily(ggml.KV().Architecture())
-				config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount()))
-				config.SetFileType(ggml.KV().FileType())
-
-				mediatype := mediatype
-				if ggml.KV().Architecture() == "clip" {
-					mediatype = "application/vnd.ollama.image.projector"
-				}
-
-				sr := io.NewSectionReader(bin, offset, size)
-				layer, err := NewLayer(sr, mediatype)
-				if err != nil {
-					return err
-				}
-
-				layers.Add(layer)
-
-				offset += size
-			}
-		case "adapter":
-			if strings.HasPrefix(c.Args, "@") {
-				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
-				if err != nil {
-					return err
-				}
-
-				c.Args = blobPath
-			}
-
-			fn(api.ProgressResponse{Status: "creating adapter layer"})
-			bin, err := os.Open(realpath(modelFileDir, c.Args))
-			if err != nil {
-				return err
-			}
-			defer bin.Close()
-
-			_, size, err := llm.DecodeGGML(bin)
+		case "license", "template", "system":
+			blob := strings.NewReader(c.Args)
+			layer, err := NewLayer(blob, mediatype)
 			if err != nil {
 				return err
 			}
 
-			sr := io.NewSectionReader(bin, 0, size)
-			layer, err := NewLayer(sr, mediatype)
-			if err != nil {
-				return err
+			if c.Name != "license" {
+				// replace
+				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+					return layer.MediaType == mediatype
+				})
 			}
 
-			layers.Add(layer)
-		case "license":
-			fn(api.ProgressResponse{Status: "creating license layer"})
-
-			bin := strings.NewReader(c.Args)
-			layer, err := NewLayer(bin, mediatype)
-			if err != nil {
-				return err
-			}
-
-			layers.Add(layer)
-		case "template", "system":
-			fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
-
-			bin := strings.NewReader(c.Args)
-			layer, err := NewLayer(bin, mediatype)
-			if err != nil {
-				return err
-			}
-
-			layers.Replace(layer)
+			layers = append(layers, layer)
 		case "message":
-			messages = append(messages, c.Args)
+			role, content, ok := strings.Cut(c.Args, ": ")
+			if !ok {
+				return fmt.Errorf("invalid message: %s", c.Args)
+			}
+
+			messages = append(messages, &api.Message{Role: role, Content: content})
 		default:
-			params[c.Name] = append(params[c.Name], c.Args)
+			ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
+			if err != nil {
+				return err
+			}
+
+			for k, v := range ps {
+				if ks, ok := parameters[k].([]string); ok {
+					parameters[k] = append(ks, v.([]string)...)
+				} else if vs, ok := v.([]string); ok {
+					parameters[k] = vs
+				} else {
+					parameters[k] = v
+				}
+			}
 		}
 	}
 
-	if len(messages) > 0 {
-		fn(api.ProgressResponse{Status: "creating parameters layer"})
+	var err2 error
+	layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+		switch layer.MediaType {
+		case "application/vnd.ollama.image.message":
+			// if there are new messages, remove the inherited ones
+			if len(messages) > 0 {
+				return true
+			}
 
-		msgs := make([]api.Message, 0)
+			return false
+		case "application/vnd.ollama.image.params":
+			// merge inherited parameters with new ones
+			r, err := layer.Open()
+			if err != nil {
+				err2 = err
+				return false
+			}
+			defer r.Close()
 
-		for _, m := range messages {
-			// todo: handle images
-			msg := strings.SplitN(m, ": ", 2)
-			msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
+			var ps map[string]any
+			if err := json.NewDecoder(r).Decode(&ps); err != nil {
+				err2 = err
+				return false
+			}
+
+			for k, v := range ps {
+				if _, ok := parameters[k]; !ok {
+					parameters[k] = v
+				}
+			}
+
+			return true
+		default:
+			return false
 		}
+	})
+	if err2 != nil {
+		return err2
+	}
 
+	if len(messages) > 0 {
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(msgs); err != nil {
+		if err := json.NewEncoder(&b).Encode(messages); err != nil {
 			return err
 		}
@@ -625,39 +506,25 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 			return err
 		}
 
-		layers.Replace(layer)
+		layers = append(layers, layer)
 	}
 
-	if len(params) > 0 {
-		fn(api.ProgressResponse{Status: "creating parameters layer"})
-
-		formattedParams, err := api.FormatParams(params)
-		if err != nil {
-			return err
-		}
-
-		for k, v := range fromParams {
-			if _, ok := formattedParams[k]; !ok {
-				formattedParams[k] = v
-			}
-		}
-
+	if len(parameters) > 0 {
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
+		if err := json.NewEncoder(&b).Encode(parameters); err != nil {
 			return err
 		}
 
-		fn(api.ProgressResponse{Status: "creating config layer"})
 		layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
 		if err != nil {
 			return err
 		}
 
-		layers.Replace(layer)
+		layers = append(layers, layer)
 	}
 
-	digests := make([]string, len(layers.items))
-	for i, layer := range layers.items {
+	digests := make([]string, len(layers))
+	for i, layer := range layers {
 		digests[i] = layer.Digest
 	}
@@ -668,36 +535,37 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 		return err
 	}
 
-	configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+	layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
 	if err != nil {
 		return err
 	}
 
-	delete(deleteMap, configLayer.Digest)
+	for _, layer := range append(layers, layer) {
+		if layer.status != "" {
+			fn(api.ProgressResponse{Status: layer.status})
+		}
+	}
 
-	for _, layer := range append(layers.items, configLayer) {
-		committed, err := layer.Commit()
-		if err != nil {
-			return err
+	unref := make(map[string]struct{})
+	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
+		for _, layer := range manifest.Layers {
+			if !slices.Contains(digests, layer.Digest) {
+				unref[layer.Digest] = struct{}{}
+			}
 		}
 
-		status := "writing layer"
-		if !committed {
-			status = "using already created layer"
+		if manifest.Config.Digest != layer.Digest {
+			unref[manifest.Config.Digest] = struct{}{}
 		}
-
-		fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
-
-		delete(deleteMap, layer.Digest)
 	}
 
 	fn(api.ProgressResponse{Status: "writing manifest"})
-	if err := WriteManifest(name, configLayer, layers.items); err != nil {
+	if err := WriteManifest(name, layer, layers); err != nil {
 		return err
 	}
 
 	if !envconfig.NoPrune {
-		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
+		if err := deleteUnusedLayers(nil, unref, false); err != nil {
 			return err
 		}
 	}
@@ -706,74 +574,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 	return nil
 }
 
-func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
-	r, err := zip.OpenReader(path)
-	if err != nil {
-		return "", err
-	}
-	defer r.Close()
-
-	tempDir, err := os.MkdirTemp("", "ollama-convert")
-	if err != nil {
-		return "", err
-	}
-	defer os.RemoveAll(tempDir)
-
-	fn(api.ProgressResponse{Status: "unpacking model metadata"})
-	for _, f := range r.File {
-		fpath := filepath.Join(tempDir, f.Name)
-		outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
-		if err != nil {
-			return "", err
-		}
-
-		rc, err := f.Open()
-		if err != nil {
-			return "", err
-		}
-
-		_, err = io.Copy(outFile, rc)
-		if err != nil {
-			return "", err
-		}
-
-		outFile.Close()
-		rc.Close()
-	}
-
-	mf, err := convert.GetModelFormat(tempDir)
-	if err != nil {
-		return "", err
-	}
-
-	params, err := mf.GetParams(tempDir)
-	if err != nil {
-		return "", err
-	}
-
-	mArch, err := mf.GetModelArch(name, tempDir, params)
-	if err != nil {
-		return "", err
-	}
-
-	fn(api.ProgressResponse{Status: "processing tensors"})
-	if err := mArch.GetTensors(); err != nil {
-		return "", err
-	}
-
-	if err := mArch.LoadVocab(); err != nil {
-		return "", err
-	}
-
-	fn(api.ProgressResponse{Status: "converting model"})
-	path, err = mArch.WriteGGUF()
-	if err != nil {
-		return "", err
-	}
-
-	return path, nil
-}
-
 func CopyModel(src, dst model.Name) error {
 	if !dst.IsFullyQualified() {
 		return model.Unqualified(dst)
diff --git a/server/layers.go b/server/layer.go
similarity index 53%
rename from server/layers.go
rename to server/layer.go
index 07787406..dcca3854 100644
--- a/server/layers.go
+++ b/server/layer.go
@@ -5,39 +5,14 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"strings"
-
-	"golang.org/x/exp/slices"
 )
 
-type Layers struct {
-	items []*Layer
-}
-
-func (ls *Layers) Add(layer *Layer) {
-	if layer.Size > 0 {
-		ls.items = append(ls.items, layer)
-	}
-}
-
-func (ls *Layers) Replace(layer *Layer) {
-	if layer.Size > 0 {
-		mediatype := layer.MediaType
-		layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
-			return l.MediaType == mediatype
-		})
-
-		ls.items = append(layers, layer)
-	}
-}
-
 type Layer struct {
 	MediaType string `json:"mediaType"`
 	Digest    string `json:"digest"`
 	Size      int64  `json:"size"`
 	From      string `json:"from,omitempty"`
-
-	tempFileName string
+	status    string
 }
 
 func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
@@ -46,14 +21,12 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 		return nil, err
 	}
 
-	const delimiter = "-"
-
-	pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
-	temp, err := os.CreateTemp(blobs, pattern)
+	temp, err := os.CreateTemp(blobs, "sha256-")
 	if err != nil {
 		return nil, err
 	}
 	defer temp.Close()
+	defer os.Remove(temp.Name())
 
 	sha256sum := sha256.New()
 	n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
@@ -61,11 +34,29 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 		return nil, err
 	}
 
+	if err := temp.Close(); err != nil {
+		return nil, err
+	}
+
+	digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
+	blob, err := GetBlobsPath(digest)
+	if err != nil {
+		return nil, err
+	}
+
+	status := "using existing layer"
+	if _, err := os.Stat(blob); err != nil {
+		status = "creating new layer"
+		if err := os.Rename(temp.Name(), blob); err != nil {
+			return nil, err
+		}
+	}
+
 	return &Layer{
-		MediaType:    mediatype,
-		Digest:       fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
-		Size:         n,
-		tempFileName: temp.Name(),
+		MediaType: mediatype,
+		Digest:    digest,
+		Size:      n,
+		status:    fmt.Sprintf("%s %s", status, digest),
 	}, nil
 }
 
@@ -85,21 +76,15 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
 		Digest:    digest,
 		Size:      fi.Size(),
 		From:      from,
+		status:    fmt.Sprintf("using existing layer %s", digest),
 	}, nil
 }
 
-func (l *Layer) Commit() (bool, error) {
-	// always remove temp
-	defer os.Remove(l.tempFileName)
-
+func (l *Layer) Open() (io.ReadCloser, error) {
 	blob, err := GetBlobsPath(l.Digest)
 	if err != nil {
-		return false, err
+		return nil, err
 	}
 
-	if _, err := os.Stat(blob); err != nil {
-		return true, os.Rename(l.tempFileName, blob)
-	}
-
-	return false, nil
+	return os.Open(blob)
 }
diff --git a/server/manifest.go b/server/manifest.go
new file mode 100644
index 00000000..8a17700e
--- /dev/null
+++ b/server/manifest.go
@@ -0,0 +1,79 @@
+package server
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"encoding/json"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+
+	"github.com/ollama/ollama/types/model"
+)
+
+type Manifest struct {
+	ManifestV2
+	Digest string `json:"-"`
+}
+
+func (m *Manifest) Size() (size int64) {
+	for _, layer := range append(m.Layers, m.Config) {
+		size += layer.Size
+	}
+
+	return
+}
+
+func ParseNamedManifest(name model.Name) (*Manifest, error) {
+	if !name.IsFullyQualified() {
+		return nil, model.Unqualified(name)
+	}
+
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return nil, err
+	}
+
+	var manifest ManifestV2
+	manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
+	if err != nil {
+		return nil, err
+	}
+
+	sha256sum := sha256.New()
+	if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
+		return nil, err
+	}
+
+	return &Manifest{
+		ManifestV2: manifest,
+		Digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
+	}, nil
+}
+
+func WriteManifest(name string, config *Layer, layers []*Layer) error {
+	manifest := ManifestV2{
+		SchemaVersion: 2,
+		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
+		Config:        config,
+		Layers:        layers,
+	}
+
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(manifest); err != nil {
+		return err
+	}
+
+	modelpath := ParseModelPath(name)
+	manifestPath, err := modelpath.GetManifestPath()
+	if err != nil {
+		return err
+	}
+
+	if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
return err + } + + return os.WriteFile(manifestPath, b.Bytes(), 0o644) +} diff --git a/server/manifests.go b/server/manifests.go deleted file mode 100644 index 2b39db65..00000000 --- a/server/manifests.go +++ /dev/null @@ -1,34 +0,0 @@ -package server - -import ( - "bytes" - "encoding/json" - "os" - "path/filepath" -) - -func WriteManifest(name string, config *Layer, layers []*Layer) error { - manifest := ManifestV2{ - SchemaVersion: 2, - MediaType: "application/vnd.docker.distribution.manifest.v2+json", - Config: config, - Layers: layers, - } - - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(manifest); err != nil { - return err - } - - modelpath := ParseModelPath(name) - manifestPath, err := modelpath.GetManifestPath() - if err != nil { - return err - } - - if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil { - return err - } - - return os.WriteFile(manifestPath, b.Bytes(), 0o644) -} diff --git a/server/model.go b/server/model.go new file mode 100644 index 00000000..eea5d13a --- /dev/null +++ b/server/model.go @@ -0,0 +1,261 @@ +package server + +import ( + "archive/zip" + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/convert" + "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/types/model" +) + +type layerWithGGML struct { + *Layer + *llm.GGML +} + +func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { + modelpath := ParseModelPath(name.String()) + manifest, _, err := GetManifest(modelpath) + switch { + case errors.Is(err, os.ErrNotExist): + if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil { + return nil, err + } + + modelpath = ParseModelPath(name.String()) + manifest, _, err = GetManifest(modelpath) + if err != nil { + return nil, err + } + case err != nil: + return nil, err + } + + for _, layer := range manifest.Layers { + layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) + if err != nil { + return nil, err + } + + switch layer.MediaType { + case "application/vnd.ollama.image.model", + "application/vnd.ollama.image.projector", + "application/vnd.ollama.image.adapter": + blobpath, err := GetBlobsPath(layer.Digest) + if err != nil { + return nil, err + } + + blob, err := os.Open(blobpath) + if err != nil { + return nil, err + } + defer blob.Close() + + ggml, _, err := llm.DecodeGGML(blob) + if err != nil { + return nil, err + } + + layers = append(layers, &layerWithGGML{layer, ggml}) + default: + layers = append(layers, &layerWithGGML{layer, nil}) + } + + } + + return layers, nil +} + +func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { + stat, err := file.Stat() + if err != nil { + return nil, err + } + + r, err := zip.NewReader(file, stat.Size()) + if err != nil { + return nil, err + } + + tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "") + if err != nil { + return nil, err + } + defer os.RemoveAll(tempdir) + + fn(api.ProgressResponse{Status: "unpacking model metadata"}) + for _, f := range r.File { + // TODO(mxyng): this should not write out all files to disk + outfile, err := os.Create(filepath.Join(tempdir, f.Name)) + if err != nil { + return nil, err + } + defer outfile.Close() + + infile, err := f.Open() + if err != nil { + return nil, err + } + defer infile.Close() + + if _, err = io.Copy(outfile, infile); err != nil { + 
return nil, err
+		}
+
+		if err := outfile.Close(); err != nil {
+			return nil, err
+		}
+
+		if err := infile.Close(); err != nil {
+			return nil, err
+		}
+	}
+
+	mf, err := convert.GetModelFormat(tempdir)
+	if err != nil {
+		return nil, err
+	}
+
+	params, err := mf.GetParams(tempdir)
+	if err != nil {
+		return nil, err
+	}
+
+	mArch, err := mf.GetModelArch("", tempdir, params)
+	if err != nil {
+		return nil, err
+	}
+
+	fn(api.ProgressResponse{Status: "processing tensors"})
+	if err := mArch.GetTensors(); err != nil {
+		return nil, err
+	}
+
+	if err := mArch.LoadVocab(); err != nil {
+		return nil, err
+	}
+
+	fn(api.ProgressResponse{Status: "converting model"})
+
+	// TODO(mxyng): this should write directly into a layer
+	// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
+	temp, err := os.CreateTemp(tempdir, "fp16")
+	if err != nil {
+		return nil, err
+	}
+	defer temp.Close()
+	defer os.Remove(temp.Name())
+
+	if err = mArch.WriteGGUF(temp); err != nil {
+		return nil, err
+	}
+
+	if _, err := temp.Seek(0, io.SeekStart); err != nil {
+		return nil, err
+	}
+
+	layer, err := NewLayer(temp, "application/vnd.ollama.image.model")
+	if err != nil {
+		return nil, fmt.Errorf("creating layer from converted model: %w", err)
+	}
+
+	blobpath, err := GetBlobsPath(layer.Digest)
+	if err != nil {
+		return nil, err
+	}
+
+	bin, err := os.Open(blobpath)
+	if err != nil {
+		return nil, err
+	}
+	defer bin.Close()
+
+	ggml, _, err := llm.DecodeGGML(bin)
+	if err != nil {
+		return nil, err
+	}
+
+	layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
+	if err != nil {
+		return nil, err
+	}
+
+	layers = append(layers, &layerWithGGML{layer, ggml})
+	return layers, nil
+}
+
+func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
+	sr := io.NewSectionReader(file, 0, 512)
+	contentType, err := detectContentType(sr)
+	if err != nil {
+		return nil, err
+	}
+
+	switch contentType {
+	case "gguf", "ggla":
+		// noop
+	case "application/zip":
+		return parseFromZipFile(ctx, file, fn)
+	default:
+		return nil, fmt.Errorf("unsupported content type: %s", contentType)
+	}
+
+	stat, err := file.Stat()
+	if err != nil {
+		return nil, err
+	}
+
+	var offset int64
+	for offset < stat.Size() {
+		ggml, n, err := llm.DecodeGGML(file)
+		if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
+			return nil, err
+		}
+
+		mediatype := "application/vnd.ollama.image.model"
+		if ggml.Name() == "ggla" {
+			mediatype = "application/vnd.ollama.image.adapter"
+		} else if ggml.KV().Architecture() == "clip" {
+			mediatype = "application/vnd.ollama.image.projector"
+		}
+
+		layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
+		if err != nil {
+			return nil, err
+		}
+
+		layers = append(layers, &layerWithGGML{layer, ggml})
+		offset = n
+	}
+
+	return layers, nil
+}
+
+func detectContentType(r io.Reader) (string, error) {
+	var b bytes.Buffer
+	if _, err := io.Copy(&b, r); err != nil {
+		return "", err
+	}
+
+	if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" {
+		return contentType, nil
+	}
+
+	if contentType := http.DetectContentType(b.Bytes()); contentType != "application/octet-stream" {
+		return contentType, nil
+	}
+
+	return "unknown", nil
+}
diff --git a/server/routes.go b/server/routes.go
index da51fbbe..e0459271 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -560,7 +560,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()
 
-		if err := CreateModel(ctx, 
name.String(), filepath.Dir(req.Path), req.Quantization, modelfile, fn); err != nil { + if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(req.Quantization), modelfile, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -719,62 +719,71 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } func (s *Server) ListModelsHandler(c *gin.Context) { - models := make([]api.ModelResponse, 0) - manifestsPath, err := GetManifestPath() + manifests, err := GetManifestPath() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - modelResponse := func(modelName string) (api.ModelResponse, error) { - model, err := GetModel(modelName) - if err != nil { - return api.ModelResponse{}, err - } - - modelDetails := api.ModelDetails{ - Format: model.Config.ModelFormat, - Family: model.Config.ModelFamily, - Families: model.Config.ModelFamilies, - ParameterSize: model.Config.ModelType, - QuantizationLevel: model.Config.FileType, - } - - return api.ModelResponse{ - Model: model.ShortName, - Name: model.ShortName, - Size: model.Size, - Digest: model.Digest, - Details: modelDetails, - }, nil - } - - walkFunc := func(path string, info os.FileInfo, _ error) error { + var models []api.ModelResponse + if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error { if !info.IsDir() { - path, tag := filepath.Split(path) - model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator)) - modelPath := strings.Join([]string{model, tag}, ":") - canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/") - - resp, err := modelResponse(canonicalModelPath) + rel, err := filepath.Rel(manifests, path) if err != nil { - slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath)) - // nolint: nilerr + return err + } + + if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil { + return err + } else if hidden { return nil } - resp.ModifiedAt = info.ModTime() - models = append(models, resp) + n := model.ParseNameFromFilepath(rel) + m, err := ParseNamedManifest(n) + if err != nil { + return err + } + + f, err := m.Config.Open() + if err != nil { + return err + } + defer f.Close() + + var c ConfigV2 + if err := json.NewDecoder(f).Decode(&c); err != nil { + return err + } + + // tag should never be masked + models = append(models, api.ModelResponse{ + Model: n.DisplayShortest(), + Name: n.DisplayShortest(), + Size: m.Size(), + Digest: m.Digest, + ModifiedAt: info.ModTime(), + Details: api.ModelDetails{ + Format: c.ModelFormat, + Family: c.ModelFamily, + Families: c.ModelFamilies, + ParameterSize: c.ModelType, + QuantizationLevel: c.FileType, + }, + }) } return nil - } - - if err := filepath.Walk(manifestsPath, walkFunc); err != nil { + }); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + slices.SortStableFunc(models, func(i, j api.ModelResponse) int { + // most recently modified first + return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix()) + }) + c.JSON(http.StatusOK, api.ListResponse{Models: models}) } @@ -796,7 +805,7 @@ func (s *Server) CopyModelHandler(c *gin.Context) { dst := model.ParseName(r.Destination) if !dst.IsValid() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Source)}) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)}) return } @@ -852,11 +861,6 @@ func (s 
*Server) CreateBlobHandler(c *gin.Context) { return } - if _, err := layer.Commit(); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.Status(http.StatusCreated) } diff --git a/server/routes_test.go b/server/routes_test.go index 27e53cbd..896dc27b 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -124,14 +124,12 @@ func Test_Routes(t *testing.T) { Method: http.MethodPost, Path: "/api/create", Setup: func(t *testing.T, req *http.Request) { - f, err := os.CreateTemp(t.TempDir(), "ollama-model") - assert.Nil(t, err) - defer f.Close() + fname := createTestFile(t, "ollama-model") stream := false createReq := api.CreateRequest{ Name: "t-bone", - Modelfile: fmt.Sprintf("FROM %s", f.Name()), + Modelfile: fmt.Sprintf("FROM %s", fname), Stream: &stream, } jsonData, err := json.Marshal(createReq) @@ -216,27 +214,25 @@ func Test_Routes(t *testing.T) { httpSrv := httptest.NewServer(router) t.Cleanup(httpSrv.Close) - workDir, err := os.MkdirTemp("", "ollama-test") - assert.Nil(t, err) - defer os.RemoveAll(workDir) - os.Setenv("OLLAMA_MODELS", workDir) + t.Setenv("OLLAMA_MODELS", t.TempDir()) for _, tc := range testCases { - t.Logf("Running Test: [%s]", tc.Name) - u := httpSrv.URL + tc.Path - req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) - assert.Nil(t, err) + t.Run(tc.Name, func(t *testing.T) { + u := httpSrv.URL + tc.Path + req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) + assert.Nil(t, err) - if tc.Setup != nil { - tc.Setup(t, req) - } + if tc.Setup != nil { + tc.Setup(t, req) + } - resp, err := httpSrv.Client().Do(req) - assert.Nil(t, err) - defer resp.Body.Close() + resp, err := httpSrv.Client().Do(req) + assert.Nil(t, err) + defer resp.Body.Close() - if tc.Expected != nil { - tc.Expected(t, resp) - } + if tc.Expected != nil { + tc.Expected(t, resp) + } + }) } } diff --git a/types/model/file.go b/types/model/file.go index c614fd32..ee398309 100644 --- a/types/model/file.go +++ b/types/model/file.go @@ -249,10 +249,6 @@ func quote(s string) string { } func unquote(s string) (string, bool) { - if len(s) == 0 { - return "", false - } - // TODO: single quotes if len(s) >= 3 && s[:3] == `"""` { if len(s) >= 6 && s[len(s)-3:] == `"""` { diff --git a/types/model/file_test.go b/types/model/file_test.go index d51c8d70..8e71760c 100644 --- a/types/model/file_test.go +++ b/types/model/file_test.go @@ -489,6 +489,10 @@ You are a store greeter. Always responsed with "Hello!". """ MESSAGE user Hey there! MESSAGE assistant Hello, I want to parse all the things! +`, + ` +FROM foo +SYSTEM "" `, } diff --git a/types/model/name.go b/types/model/name.go index fbb30fd4..6d2a187b 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -35,6 +35,12 @@ func Unqualified(n Name) error { // spot in logs. const MissingPart = "!MISSING!" +const ( + defaultHost = "registry.ollama.ai" + defaultNamespace = "library" + defaultTag = "latest" +) + // DefaultName returns a name with the default values for the host, namespace, // and tag parts. The model and digest parts are empty. // @@ -43,9 +49,9 @@ const MissingPart = "!MISSING!" // - The default tag is ("latest") func DefaultName() Name { return Name{ - Host: "registry.ollama.ai", - Namespace: "library", - Tag: "latest", + Host: defaultHost, + Namespace: defaultNamespace, + Tag: defaultTag, } } @@ -169,6 +175,27 @@ func ParseNameBare(s string) Name { return n } +// ParseNameFromFilepath parses a 4-part filepath as a Name. 
The parts are
+// expected to be in the form:
+//
+//	{ host } "/" { namespace } "/" { model } "/" { tag }
+func ParseNameFromFilepath(s string) (n Name) {
+	parts := strings.Split(s, string(filepath.Separator))
+	if len(parts) != 4 {
+		return Name{}
+	}
+
+	n.Host = parts[0]
+	n.Namespace = parts[1]
+	n.Model = parts[2]
+	n.Tag = parts[3]
+	if !n.IsFullyQualified() {
+		return Name{}
+	}
+
+	return n
+}
+
 // Merge merges the host, namespace, and tag parts of the two names,
 // preferring the non-empty parts of a.
 func Merge(a, b Name) Name {
@@ -203,6 +230,27 @@ func (n Name) String() string {
 	return b.String()
 }
 
+// DisplayShortest returns the shortest string version of the name, eliding the default host and namespace when possible.
+func (n Name) DisplayShortest() string {
+	var sb strings.Builder
+
+	if n.Host != defaultHost {
+		sb.WriteString(n.Host)
+		sb.WriteByte('/')
+		sb.WriteString(n.Namespace)
+		sb.WriteByte('/')
+	} else if n.Namespace != defaultNamespace {
+		sb.WriteString(n.Namespace)
+		sb.WriteByte('/')
+	}
+
+	// always include model and tag
+	sb.WriteString(n.Model)
+	sb.WriteByte(':')
+	sb.WriteString(n.Tag)
+	return sb.String()
+}
+
 // IsValid reports whether all parts of the name are present and valid. The
 // digest is a special case, and is checked for validity only if present.
 func (n Name) IsValid() bool {
diff --git a/types/model/name_test.go b/types/model/name_test.go
index 47263c20..19bc2e2d 100644
--- a/types/model/name_test.go
+++ b/types/model/name_test.go
@@ -309,6 +309,49 @@ func TestParseDigest(t *testing.T) {
 	}
 }
 
+func TestParseNameFromFilepath(t *testing.T) {
+	cases := map[string]Name{
+		filepath.Join("host", "namespace", "model", "tag"):      {Host: "host", Namespace: "namespace", Model: "model", Tag: "tag"},
+		filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"},
+		filepath.Join("namespace", "model", "tag"):              {},
+		filepath.Join("model", "tag"):                           {},
+		filepath.Join("model"):                                  {},
+		filepath.Join("..", "..", "model", "tag"):               {},
+		filepath.Join("", "namespace", ".", "tag"):              {},
+		filepath.Join(".", ".", ".", "."):                       {},
+		filepath.Join("/", "path", "to", "random", "file"):      {},
+	}
+
+	for in, want := range cases {
+		t.Run(in, func(t *testing.T) {
+			got := ParseNameFromFilepath(in)
+
+			if !reflect.DeepEqual(got, want) {
+				t.Errorf("ParseNameFromFilepath(%q) = %v; want %v", in, got, want)
+			}
+		})
+	}
+}
+
+func TestDisplayShortest(t *testing.T) {
+	cases := map[string]string{
+		"registry.ollama.ai/library/model:latest": "model:latest",
+		"registry.ollama.ai/library/model:tag":    "model:tag",
+		"registry.ollama.ai/namespace/model:tag":  "namespace/model:tag",
+		"host/namespace/model:tag":                "host/namespace/model:tag",
+		"host/library/model:tag":                  "host/library/model:tag",
+	}
+
+	for in, want := range cases {
+		t.Run(in, func(t *testing.T) {
+			got := ParseNameBare(in).DisplayShortest()
+			if got != want {
+				t.Errorf("ParseNameBare(%q).DisplayShortest() = %q; want %q", in, got, want)
+			}
+		})
+	}
+}
+
 func FuzzName(f *testing.F) {
 	for s := range testCases {
 		f.Add(s)
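

A minimal usage sketch of the two new types/model helpers, appended for reference and mirroring the TestDisplayShortest and TestParseNameFromFilepath cases above; the import path is assumed from the module paths that appear elsewhere in this diff.

package main

import (
	"fmt"
	"path/filepath"

	"github.com/ollama/ollama/types/model"
)

func main() {
	// DisplayShortest elides the default host ("registry.ollama.ai"), and
	// elides the default namespace ("library") only when the host is also
	// the default; a custom host keeps the namespace, as the tests show.
	fmt.Println(model.ParseNameBare("registry.ollama.ai/library/model:latest").DisplayShortest()) // model:latest
	fmt.Println(model.ParseNameBare("registry.ollama.ai/namespace/model:tag").DisplayShortest())  // namespace/model:tag
	fmt.Println(model.ParseNameBare("host/library/model:tag").DisplayShortest())                  // host/library/model:tag

	// ParseNameFromFilepath accepts exactly four path elements,
	// host/namespace/model/tag, and returns the zero Name for anything else.
	good := model.ParseNameFromFilepath(filepath.Join("host", "namespace", "model", "tag"))
	fmt.Println(good.IsValid()) // true
	bad := model.ParseNameFromFilepath(filepath.Join("model", "tag"))
	fmt.Println(bad.IsValid()) // false
}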
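
Similarly, a hedged sketch of the name-mapping step the rewritten ListModelsHandler builds on: each manifest file's path, taken relative to the manifests root, is a 4-part path that ParseNameFromFilepath turns back into a fully qualified Name. The nameForManifest helper and the root directory below are hypothetical stand-ins; the real handler uses GetManifestPath and filepath.Walk as in the routes.go hunk above.

package main

import (
	"fmt"
	"path/filepath"

	"github.com/ollama/ollama/types/model"
)

// nameForManifest maps a manifest file path back to a model name, the way
// the walk in ListModelsHandler does (hypothetical helper).
func nameForManifest(manifestsRoot, path string) (model.Name, error) {
	rel, err := filepath.Rel(manifestsRoot, path)
	if err != nil {
		return model.Name{}, err
	}

	n := model.ParseNameFromFilepath(rel)
	if !n.IsValid() {
		return model.Name{}, fmt.Errorf("not a manifest path: %s", rel)
	}
	return n, nil
}

func main() {
	root := "/models/manifests" // stand-in for the GetManifestPath result
	n, err := nameForManifest(root, filepath.Join(root, "registry.ollama.ai", "library", "llama2", "latest"))
	if err != nil {
		panic(err)
	}
	fmt.Println(n.DisplayShortest()) // llama2:latest
}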