mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
cli: add device signin flow when doing ollama push (#12405)
This commit is contained in:
20
cmd/cmd.go
20
cmd/cmd.go
@@ -540,6 +540,25 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
n := model.ParseName(args[0])
|
||||
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
|
||||
_, err := client.Whoami(cmd.Context())
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Println("You need to be signed in to push models to ollama.com.")
|
||||
fmt.Println()
|
||||
|
||||
if aErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
@@ -576,7 +595,6 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
|
||||
n := model.ParseName(args[0])
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
|
||||
@@ -491,9 +491,35 @@ func TestPushHandler(t *testing.T) {
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
},
|
||||
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
|
||||
},
|
||||
{
|
||||
name: "not signed in push",
|
||||
modelName: "notsignedin-model",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://somethingsomething",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "You need to be signed in to push",
|
||||
},
|
||||
{
|
||||
name: "unauthorized push",
|
||||
modelName: "unauthorized-model",
|
||||
@@ -508,6 +534,11 @@ func TestPushHandler(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
||||
},
|
||||
@@ -564,7 +595,7 @@ func TestPushHandler(t *testing.T) {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if tt.expectedOutput != "" {
|
||||
if got := string(stdout); got != tt.expectedOutput {
|
||||
if got := string(stdout); !strings.Contains(got, tt.expectedOutput) {
|
||||
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user