Jelajahi Sumber

digitalocean oauth completed

Alexander Belanger 5 tahun lalu
induk
melakukan
35894a3220

+ 1 - 1
server/api/api.go

@@ -126,7 +126,7 @@ func New(conf *AppConfig) (*App, error) {
 	}
 
 	if sc := conf.ServerConf; sc.DOClientID != "" && sc.DOClientSecret != "" {
-		app.DOConf = oauth.NewGithubClient(&oauth.Config{
+		app.DOConf = oauth.NewDigitalOceanClient(&oauth.Config{
 			ClientID:     sc.DOClientID,
 			ClientSecret: sc.DOClientSecret,
 			Scopes:       []string{"read", "write"},

+ 99 - 0
server/api/oauth_do_handler.go

@@ -0,0 +1,99 @@
+package api
+
+import (
+	"fmt"
+	"net/http"
+
+	"github.com/porter-dev/porter/internal/oauth"
+	"golang.org/x/oauth2"
+
+	"github.com/porter-dev/porter/internal/models/integrations"
+)
+
+// HandleDOOAuthStartProject starts the oauth2 flow for a project digitalocean request.
+// In this handler, the project id gets written to the session (along with the oauth
+// state param), so that the correct project id can be identified in the callback.
+func (app *App) HandleDOOAuthStartProject(w http.ResponseWriter, r *http.Request) {
+	state := oauth.CreateRandomState()
+
+	err := app.populateOAuthSession(w, r, state, true)
+
+	if err != nil {
+		app.handleErrorDataRead(err, w)
+		return
+	}
+
+	// specify access type offline to get a refresh token
+	url := app.DOConf.AuthCodeURL(state, oauth2.AccessTypeOffline)
+
+	http.Redirect(w, r, url, 302)
+}
+
+// HandleDOOAuthCallback verifies the callback request by checking that the
+// state parameter has not been modified, and validates the token.
+func (app *App) HandleDOOAuthCallback(w http.ResponseWriter, r *http.Request) {
+	session, err := app.Store.Get(r, app.ServerConf.CookieName)
+
+	if err != nil {
+		app.handleErrorDataRead(err, w)
+		return
+	}
+
+	if _, ok := session.Values["state"]; !ok {
+		app.sendExternalError(
+			err,
+			http.StatusForbidden,
+			HTTPError{
+				Code: http.StatusForbidden,
+				Errors: []string{
+					"Could not read cookie: are cookies enabled?",
+				},
+			},
+			w,
+		)
+
+		return
+	}
+
+	if r.URL.Query().Get("state") != session.Values["state"] {
+		http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
+		return
+	}
+
+	token, err := app.DOConf.Exchange(oauth2.NoContext, r.URL.Query().Get("code"))
+
+	if err != nil {
+		http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
+		return
+	}
+
+	if !token.Valid() {
+		http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
+		return
+	}
+
+	userID, _ := session.Values["user_id"].(uint)
+	projID, _ := session.Values["project_id"].(uint)
+
+	oauthInt := &integrations.OAuthIntegration{
+		Client:       integrations.OAuthDigitalOcean,
+		UserID:       userID,
+		ProjectID:    projID,
+		AccessToken:  []byte(token.AccessToken),
+		RefreshToken: []byte(token.RefreshToken),
+	}
+
+	// create the oauth integration first
+	oauthInt, err = app.Repo.OAuthIntegration.CreateOAuthIntegration(oauthInt)
+
+	if err != nil {
+		app.handleErrorDataWrite(err, w)
+		return
+	}
+
+	if session.Values["query_params"] != "" {
+		http.Redirect(w, r, fmt.Sprintf("/dashboard?%s", session.Values["query_params"]), 302)
+	} else {
+		http.Redirect(w, r, "/dashboard", 302)
+	}
+}

+ 1 - 1
server/api/oauth_github_handler.go

@@ -139,7 +139,7 @@ func (app *App) populateOAuthSession(w http.ResponseWriter, r *http.Request, sta
 			return fmt.Errorf("could not read project id")
 		}
 
-		session.Values["project_id"] = projID
+		session.Values["project_id"] = uint(projID)
 		session.Values["query_params"] = r.URL.RawQuery
 	}
 

+ 16 - 0
server/router/router.go

@@ -148,6 +148,22 @@ func New(a *api.App) *chi.Mux {
 			requestlog.NewHandler(a.HandleGithubOAuthCallback, l),
 		)
 
+		r.Method(
+			"GET",
+			"/oauth/projects/{project_id}/digitalocean",
+			auth.DoesUserHaveProjectAccess(
+				requestlog.NewHandler(a.HandleDOOAuthStartProject, l),
+				mw.URLParam,
+				mw.WriteAccess,
+			),
+		)
+
+		r.Method(
+			"GET",
+			"/oauth/digitalocean/callback",
+			requestlog.NewHandler(a.HandleDOOAuthCallback, l),
+		)
+
 		// /api/projects routes
 		r.Method(
 			"GET",