diff --git a/api/client.go b/api/client.go index dc099e95..fccbc9ad 100644 --- a/api/client.go +++ b/api/client.go @@ -23,11 +23,9 @@ import ( "net" "net/http" "net/url" - "os" "runtime" - "strconv" - "strings" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/version" ) @@ -65,10 +63,7 @@ func checkError(resp *http.Response, body []byte) error { // If the variable is not specified, a default ollama host and port will be // used. func ClientFromEnvironment() (*Client, error) { - ollamaHost, err := GetOllamaHost() - if err != nil { - return nil, err - } + ollamaHost := envconfig.Host return &Client{ base: &url.URL{ @@ -79,52 +74,6 @@ func ClientFromEnvironment() (*Client, error) { }, nil } -type OllamaHost struct { - Scheme string - Host string - Port string -} - -func GetOllamaHost() (OllamaHost, error) { - defaultPort := "11434" - - hostVar := os.Getenv("OLLAMA_HOST") - hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'")) - - scheme, hostport, ok := strings.Cut(hostVar, "://") - switch { - case !ok: - scheme, hostport = "http", hostVar - case scheme == "http": - defaultPort = "80" - case scheme == "https": - defaultPort = "443" - } - - // trim trailing slashes - hostport = strings.TrimRight(hostport, "/") - - host, port, err := net.SplitHostPort(hostport) - if err != nil { - host, port = "127.0.0.1", defaultPort - if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { - host = ip.String() - } else if hostport != "" { - host = hostport - } - } - - if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 { - return OllamaHost{}, ErrInvalidHostPort - } - - return OllamaHost{ - Scheme: scheme, - Host: host, - Port: port, - }, nil -} - func NewClient(base *url.URL, http *http.Client) *Client { return &Client{ base: base, diff --git a/api/client_test.go b/api/client_test.go index b2c51d00..fe9fd74f 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -1,11 +1,9 @@ package api import ( - "fmt" - "net" "testing" - "github.com/stretchr/testify/assert" + "github.com/ollama/ollama/envconfig" ) func TestClientFromEnvironment(t *testing.T) { @@ -35,6 +33,7 @@ func TestClientFromEnvironment(t *testing.T) { for k, v := range testCases { t.Run(k, func(t *testing.T) { t.Setenv("OLLAMA_HOST", v.value) + envconfig.LoadConfig() client, err := ClientFromEnvironment() if err != v.err { @@ -46,40 +45,4 @@ func TestClientFromEnvironment(t *testing.T) { } }) } - - hostTestCases := map[string]*testCase{ - "empty": {value: "", expect: "127.0.0.1:11434"}, - "only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"}, - "only port": {value: ":1234", expect: ":1234"}, - "address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"}, - "hostname": {value: "example.com", expect: "example.com:11434"}, - "hostname and port": {value: "example.com:1234", expect: "example.com:1234"}, - "zero port": {value: ":0", expect: ":0"}, - "too large port": {value: ":66000", err: ErrInvalidHostPort}, - "too small port": {value: ":-1", err: ErrInvalidHostPort}, - "ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"}, - "ipv6 world open": {value: "[::]", expect: "[::]:11434"}, - "ipv6 no brackets": {value: "::1", expect: "[::1]:11434"}, - "ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"}, - "extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"}, - "extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"}, - "extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"}, - "extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"}, - } - - for k, v := range hostTestCases { - t.Run(k, func(t *testing.T) { - t.Setenv("OLLAMA_HOST", v.value) - - oh, err := GetOllamaHost() - if err != v.err { - t.Fatalf("expected %s, got %s", v.err, err) - } - - if err == nil { - host := net.JoinHostPort(oh.Host, oh.Port) - assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host)) - } - }) - } } diff --git a/api/types.go b/api/types.go index caf2ad70..d99cf3bc 100644 --- a/api/types.go +++ b/api/types.go @@ -2,7 +2,6 @@ package api import ( "encoding/json" - "errors" "fmt" "log/slog" "math" @@ -377,8 +376,6 @@ func (m *Metrics) Summary() { } } -var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") - func (opts *Options) FromMap(m map[string]interface{}) error { valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct diff --git a/cmd/cmd.go b/cmd/cmd.go index b5747543..ae7c8da8 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -960,17 +960,11 @@ func generate(cmd *cobra.Command, opts runOptions) error { } func RunServer(cmd *cobra.Command, _ []string) error { - // retrieve the OLLAMA_HOST environment variable - ollamaHost, err := api.GetOllamaHost() - if err != nil { - return err - } - if err := initializeKeypair(); err != nil { return err } - ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port)) + ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port)) if err != nil { return err } diff --git a/envconfig/config.go b/envconfig/config.go index ae4e9939..2c3b6f77 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -1,6 +1,7 @@ package envconfig import ( + "errors" "fmt" "log/slog" "net" @@ -11,6 +12,18 @@ import ( "strings" ) +type OllamaHost struct { + Scheme string + Host string + Port string +} + +func (o OllamaHost) String() string { + return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port) +} + +var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") + var ( // Set via OLLAMA_ORIGINS in the environment AllowOrigins []string @@ -34,6 +47,8 @@ var ( NoPrune bool // Set via OLLAMA_NUM_PARALLEL in the environment NumParallel int + // Set via OLLAMA_HOST in the environment + Host *OllamaHost // Set via OLLAMA_RUNNERS_DIR in the environment RunnersDir string // Set via OLLAMA_TMPDIR in the environment @@ -50,7 +65,7 @@ func AsMap() map[string]EnvVar { return map[string]EnvVar{ "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"}, - "OLLAMA_HOST": {"OLLAMA_HOST", "", "IP Address for the ollama server (default 127.0.0.1:11434)"}, + "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"}, @@ -216,4 +231,54 @@ func LoadConfig() { } KeepAlive = clean("OLLAMA_KEEP_ALIVE") + + var err error + Host, err = getOllamaHost() + if err != nil { + slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port) + } +} + +func getOllamaHost() (*OllamaHost, error) { + defaultPort := "11434" + + hostVar := os.Getenv("OLLAMA_HOST") + hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'")) + + scheme, hostport, ok := strings.Cut(hostVar, "://") + switch { + case !ok: + scheme, hostport = "http", hostVar + case scheme == "http": + defaultPort = "80" + case scheme == "https": + defaultPort = "443" + } + + // trim trailing slashes + hostport = strings.TrimRight(hostport, "/") + + host, port, err := net.SplitHostPort(hostport) + if err != nil { + host, port = "127.0.0.1", defaultPort + if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { + host = ip.String() + } else if hostport != "" { + host = hostport + } + } + + if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 { + return &OllamaHost{ + Scheme: scheme, + Host: host, + Port: defaultPort, + }, ErrInvalidHostPort + } + + return &OllamaHost{ + Scheme: scheme, + Host: host, + Port: port, + }, nil } diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 429434ae..7d923d62 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -1,8 +1,11 @@ package envconfig import ( + "fmt" + "net" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -21,3 +24,48 @@ func TestConfig(t *testing.T) { LoadConfig() require.True(t, FlashAttention) } + +func TestClientFromEnvironment(t *testing.T) { + type testCase struct { + value string + expect string + err error + } + + hostTestCases := map[string]*testCase{ + "empty": {value: "", expect: "127.0.0.1:11434"}, + "only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"}, + "only port": {value: ":1234", expect: ":1234"}, + "address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"}, + "hostname": {value: "example.com", expect: "example.com:11434"}, + "hostname and port": {value: "example.com:1234", expect: "example.com:1234"}, + "zero port": {value: ":0", expect: ":0"}, + "too large port": {value: ":66000", err: ErrInvalidHostPort}, + "too small port": {value: ":-1", err: ErrInvalidHostPort}, + "ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"}, + "ipv6 world open": {value: "[::]", expect: "[::]:11434"}, + "ipv6 no brackets": {value: "::1", expect: "[::1]:11434"}, + "ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"}, + "extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"}, + "extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"}, + "extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"}, + "extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"}, + } + + for k, v := range hostTestCases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", v.value) + LoadConfig() + + oh, err := getOllamaHost() + if err != v.err { + t.Fatalf("expected %s, got %s", v.err, err) + } + + if err == nil { + host := net.JoinHostPort(oh.Host, oh.Port) + assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host)) + } + }) + } +} diff --git a/server/images.go b/server/images.go index 683057b8..5fd762ae 100644 --- a/server/images.go +++ b/server/images.go @@ -28,7 +28,6 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/parser" - "github.com/ollama/ollama/templates" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -333,7 +332,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio switch c.Name { case "model", "adapter": - var baseLayers []*layerWithGGML + var baseLayers []*layerGGML if name := model.ParseName(c.Args); name.IsValid() { baseLayers, err = parseFromModel(ctx, name, fn) if err != nil { @@ -435,20 +434,6 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount())) config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String()) config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture()) - - if s := baseLayer.GGML.KV().ChatTemplate(); s != "" { - if t, err := templates.NamedTemplate(s); err != nil { - slog.Debug("template detection", "error", err) - } else { - layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template") - if err != nil { - return err - } - - layer.status = fmt.Sprintf("using autodetected template %s", t.Name) - layers = append(layers, layer) - } - } } layers = append(layers, baseLayer.Layer) diff --git a/server/model.go b/server/model.go index ee2ae080..b262ea38 100644 --- a/server/model.go +++ b/server/model.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http" "os" "path/filepath" @@ -14,17 +15,18 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/templates" "github.com/ollama/ollama/types/model" ) var intermediateBlobs map[string]string = make(map[string]string) -type layerWithGGML struct { +type layerGGML struct { *Layer *llm.GGML } -func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { +func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) { m, err := ParseNamedManifest(name) switch { case errors.Is(err, os.ErrNotExist): @@ -66,16 +68,16 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return nil, err } - layers = append(layers, &layerWithGGML{layer, ggml}) + layers = append(layers, &layerGGML{layer, ggml}) default: - layers = append(layers, &layerWithGGML{layer, nil}) + layers = append(layers, &layerGGML{layer, nil}) } } return layers, nil } -func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { +func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) { stat, err := file.Stat() if err != nil { return nil, err @@ -179,13 +181,13 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a return nil, err } - layers = append(layers, &layerWithGGML{layer, ggml}) + layers = append(layers, &layerGGML{layer, ggml}) intermediateBlobs[digest] = layer.Digest - return layers, nil + return detectChatTemplate(layers) } -func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { +func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) { sr := io.NewSectionReader(file, 0, 512) contentType, err := detectContentType(sr) if err != nil { @@ -227,10 +229,30 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap return nil, err } - layers = append(layers, &layerWithGGML{layer, ggml}) + layers = append(layers, &layerGGML{layer, ggml}) offset = n } + return detectChatTemplate(layers) +} + +func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) { + for _, layer := range layers { + if s := layer.GGML.KV().ChatTemplate(); s != "" { + if t, err := templates.NamedTemplate(s); err != nil { + slog.Debug("template detection", "error", err) + } else { + tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template") + if err != nil { + return nil, err + } + + tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name) + layers = append(layers, &layerGGML{tmpl, nil}) + } + } + } + return layers, nil } diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 0fc76b96..a61a618f 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -535,7 +535,7 @@ func TestCreateDetectTemplate(t *testing.T) { } checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ - filepath.Join(p, "blobs", "sha256-06cd2687a518d624073f125f1db1c5c727f77c75e84a138fe745186dbbbb4cd7"), + filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"), filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"), filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"), })