2
0

digitalocean.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package oauth_callback
import (
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"github.com/porter-dev/porter/api/server/handlers"
	"github.com/porter-dev/porter/api/server/shared"
	"github.com/porter-dev/porter/api/server/shared/apierrors"
	"github.com/porter-dev/porter/api/server/shared/config"
	"github.com/porter-dev/porter/api/types"
	"github.com/porter-dev/porter/internal/models/integrations"
	"golang.org/x/oauth2"
)
  14. type OAuthCallbackDOHandler struct {
  15. handlers.PorterHandlerReadWriter
  16. }
  17. func NewOAuthCallbackDOHandler(
  18. config *config.Config,
  19. decoderValidator shared.RequestDecoderValidator,
  20. writer shared.ResultWriter,
  21. ) *OAuthCallbackDOHandler {
  22. return &OAuthCallbackDOHandler{
  23. PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer),
  24. }
  25. }
  26. func (p *OAuthCallbackDOHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  27. session, err := p.Config().Store.Get(r, p.Config().ServerConf.CookieName)
  28. if err != nil {
  29. p.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  30. return
  31. }
  32. if _, ok := session.Values["state"]; !ok {
  33. p.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  34. return
  35. }
  36. if r.URL.Query().Get("state") != session.Values["state"] {
  37. p.HandleAPIError(w, r, apierrors.NewErrForbidden(err))
  38. return
  39. }
  40. token, err := p.Config().DOConf.Exchange(oauth2.NoContext, r.URL.Query().Get("code"))
  41. if err != nil {
  42. p.HandleAPIError(w, r, apierrors.NewErrForbidden(err))
  43. return
  44. }
  45. if !token.Valid() {
  46. p.HandleAPIError(w, r, apierrors.NewErrForbidden(fmt.Errorf("invalid token")))
  47. return
  48. }
  49. userID, _ := session.Values["user_id"].(uint)
  50. projID, _ := session.Values["project_id"].(uint)
  51. oauthInt := &integrations.OAuthIntegration{
  52. SharedOAuthModel: integrations.SharedOAuthModel{
  53. AccessToken: []byte(token.AccessToken),
  54. RefreshToken: []byte(token.RefreshToken),
  55. },
  56. Client: types.OAuthDigitalOcean,
  57. UserID: userID,
  58. ProjectID: projID,
  59. }
  60. oauthInt.PopulateTargetMetadata()
  61. // create the oauth integration first
  62. oauthInt, err = p.Repo().OAuthIntegration().CreateOAuthIntegration(oauthInt)
  63. if err != nil {
  64. p.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  65. return
  66. }
  67. if redirectStr, ok := session.Values["redirect_uri"].(string); ok && redirectStr != "" {
  68. // attempt to parse the redirect uri, if it fails just redirect to dashboard
  69. redirectURI, err := url.Parse(redirectStr)
  70. if err != nil {
  71. http.Redirect(w, r, "/dashboard", 302)
  72. }
  73. http.Redirect(w, r, fmt.Sprintf("%s?%s", redirectURI.Path, redirectURI.RawQuery), 302)
  74. } else {
  75. http.Redirect(w, r, "/dashboard", 302)
  76. }
  77. }