accept.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. //go:build ee
  2. // +build ee
  3. package invite
  4. import (
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "net/url"
  9. "github.com/porter-dev/porter/api/server/handlers"
  10. "github.com/porter-dev/porter/api/server/shared/apierrors"
  11. "github.com/porter-dev/porter/api/server/shared/config"
  12. "github.com/porter-dev/porter/api/server/shared/requestutils"
  13. "github.com/porter-dev/porter/api/types"
  14. "github.com/porter-dev/porter/internal/models"
  15. "github.com/porter-dev/porter/internal/telemetry"
  16. "gorm.io/gorm"
  17. )
  18. type InviteAcceptHandler struct {
  19. handlers.PorterHandler
  20. }
  21. func NewInviteAcceptHandler(
  22. config *config.Config,
  23. ) http.Handler {
  24. return &InviteAcceptHandler{
  25. PorterHandler: handlers.NewDefaultPorterHandler(config, nil, nil),
  26. }
  27. }
  28. func (c *InviteAcceptHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  29. ctx, span := telemetry.NewSpan(r.Context(), "serve-invite-accept")
  30. defer span.End()
  31. user, _ := ctx.Value(types.UserScope).(*models.User)
  32. projectID, _ := requestutils.GetURLParamUint(r, types.URLParamProjectID)
  33. token, _ := requestutils.GetURLParamString(r, types.URLParamInviteToken)
  34. proj, err := c.Repo().Project().ReadProject(projectID)
  35. if err != nil {
  36. vals := url.Values{}
  37. if errors.Is(err, gorm.ErrRecordNotFound) {
  38. vals.Add("error", "Invalid invite token")
  39. } else {
  40. vals.Add("error", "Unknown error")
  41. }
  42. http.Redirect(w, r, fmt.Sprintf("/dashboard?%s", vals.Encode()), 302)
  43. return
  44. }
  45. invite, err := c.Repo().Invite().ReadInviteByToken(token)
  46. if err != nil || invite.ProjectID != proj.ID {
  47. vals := url.Values{}
  48. vals.Add("error", "Invalid invite token")
  49. http.Redirect(w, r, fmt.Sprintf("/dashboard?%s", vals.Encode()), 302)
  50. return
  51. }
  52. // check that the invite has not expired and has not been accepted
  53. if invite.IsExpired() || invite.IsAccepted() {
  54. vals := url.Values{}
  55. vals.Add("error", "Invite has expired")
  56. http.Redirect(w, r, fmt.Sprintf("/dashboard?%s", vals.Encode()), 302)
  57. return
  58. }
  59. // check that the invite email matches the user's email
  60. if user.Email != invite.Email {
  61. vals := url.Values{}
  62. vals.Add("error", "Wrong email for invite")
  63. http.Redirect(w, r, fmt.Sprintf("/dashboard?%s", vals.Encode()), 302)
  64. return
  65. }
  66. kind := invite.Kind
  67. if kind == "" {
  68. kind = models.RoleDeveloper
  69. }
  70. role := &models.Role{
  71. Role: types.Role{
  72. UserID: user.ID,
  73. ProjectID: proj.ID,
  74. Kind: types.RoleKind(kind),
  75. },
  76. }
  77. if role, err = c.Repo().Project().CreateProjectRole(proj, role); err != nil {
  78. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  79. return
  80. }
  81. // update the invite
  82. invite.UserID = user.ID
  83. if _, err = c.Repo().Invite().UpdateInvite(invite); err != nil {
  84. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  85. return
  86. }
  87. http.Redirect(w, r, "/dashboard", 302)
  88. }