diff --git a/cmd/cmd.go b/cmd/cmd.go index e8cfa134..3b41c71e 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -540,6 +540,25 @@ func PushHandler(cmd *cobra.Command, args []string) error { return err } + n := model.ParseName(args[0]) + if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") { + _, err := client.Whoami(cmd.Context()) + if err != nil { + var aErr api.AuthorizationError + if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized { + fmt.Println("You need to be signed in to push models to ollama.com.") + fmt.Println() + + if aErr.SigninURL != "" { + fmt.Printf(ConnectInstructions, aErr.SigninURL) + } + return nil + } + + return err + } + } + p := progress.NewProgress(os.Stderr) defer p.Stop() @@ -576,7 +595,6 @@ func PushHandler(cmd *cobra.Command, args []string) error { request := api.PushRequest{Name: args[0], Insecure: insecure} - n := model.ParseName(args[0]) if err := client.Push(cmd.Context(), &request, fn); err != nil { if spinner != nil { spinner.Stop() diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 24d28705..fb3b039e 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -491,9 +491,35 @@ func TestPushHandler(t *testing.T) { w.(http.Flusher).Flush() } }, + "/api/me": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + }, }, expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n", }, + { + name: "not signed in push", + modelName: "notsignedin-model", + serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ + "/api/me": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + err := json.NewEncoder(w).Encode(map[string]string{ + "error": "unauthorized", + "signin_url": "https://somethingsomething", + }) + if err != nil { + t.Fatal(err) + } + }, + }, + expectedOutput: "You need to be signed in to push", + }, { name: "unauthorized push", modelName: "unauthorized-model", @@ -508,6 +534,11 @@ func TestPushHandler(t *testing.T) { t.Fatal(err) } }, + "/api/me": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + }, }, expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own", }, @@ -564,7 +595,7 @@ func TestPushHandler(t *testing.T) { t.Errorf("expected no error, got %v", err) } if tt.expectedOutput != "" { - if got := string(stdout); got != tt.expectedOutput { + if got := string(stdout); !strings.Contains(got, tt.expectedOutput) { t.Errorf("expected output %q, got %q", tt.expectedOutput, got) } }