digitalocean.go 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. package oauth_callback
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/url"
  6. "github.com/porter-dev/porter/api/server/handlers"
  7. "github.com/porter-dev/porter/api/server/shared"
  8. "github.com/porter-dev/porter/api/server/shared/apierrors"
  9. "github.com/porter-dev/porter/api/server/shared/config"
  10. "github.com/porter-dev/porter/api/types"
  11. "github.com/porter-dev/porter/internal/models/integrations"
  12. "golang.org/x/oauth2"
  13. )
  14. type OAuthCallbackDOHandler struct {
  15. handlers.PorterHandlerReadWriter
  16. }
  17. func NewOAuthCallbackDOHandler(
  18. config *config.Config,
  19. decoderValidator shared.RequestDecoderValidator,
  20. writer shared.ResultWriter,
  21. ) *OAuthCallbackDOHandler {
  22. return &OAuthCallbackDOHandler{
  23. PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer),
  24. }
  25. }
  26. func (p *OAuthCallbackDOHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  27. session, err := p.Config().Store.Get(r, p.Config().ServerConf.CookieName)
  28. if err != nil {
  29. p.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  30. return
  31. }
  32. if _, ok := session.Values["state"]; !ok {
  33. p.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  34. return
  35. }
  36. if r.URL.Query().Get("state") != session.Values["state"] {
  37. p.HandleAPIError(w, r, apierrors.NewErrForbidden(err))
  38. return
  39. }
  40. token, err := p.Config().DOConf.Exchange(oauth2.NoContext, r.URL.Query().Get("code"))
  41. if err != nil {
  42. p.HandleAPIError(w, r, apierrors.NewErrForbidden(err))
  43. return
  44. }
  45. if !token.Valid() {
  46. p.HandleAPIError(w, r, apierrors.NewErrForbidden(fmt.Errorf("invalid token")))
  47. return
  48. }
  49. userID, _ := session.Values["user_id"].(uint)
  50. projID, _ := session.Values["project_id"].(uint)
  51. oauthInt := &integrations.OAuthIntegration{
  52. SharedOAuthModel: integrations.SharedOAuthModel{
  53. AccessToken: []byte(token.AccessToken),
  54. RefreshToken: []byte(token.RefreshToken),
  55. },
  56. Client: types.OAuthDigitalOcean,
  57. UserID: userID,
  58. ProjectID: projID,
  59. }
  60. // create the oauth integration first
  61. oauthInt, err = p.Repo().OAuthIntegration().CreateOAuthIntegration(oauthInt)
  62. if err != nil {
  63. p.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  64. return
  65. }
  66. if redirectStr, ok := session.Values["redirect_uri"].(string); ok && redirectStr != "" {
  67. // attempt to parse the redirect uri, if it fails just redirect to dashboard
  68. redirectURI, err := url.Parse(redirectStr)
  69. if err != nil {
  70. http.Redirect(w, r, "/dashboard", 302)
  71. }
  72. http.Redirect(w, r, fmt.Sprintf("%s?%s", redirectURI.Path, redirectURI.RawQuery), 302)
  73. } else {
  74. http.Redirect(w, r, "/dashboard", 302)
  75. }
  76. }