소스 검색

Always use admin user for billing (#4557)

Mauricio Araujo 2 년 전
부모
커밋
1904f348c8

+ 2 - 2
api/server/handlers/billing/create.go

@@ -37,7 +37,7 @@ func NewCreateBillingHandler(
 }
 
 func (c *CreateBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "create-billing-endpoint")
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-create-billing-method")
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
@@ -68,7 +68,7 @@ func NewSetDefaultBillingHandler(
 }
 
 func (c *SetDefaultBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "set-default-billing-endpoint")
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-set-default-billing-method")
 	defer span.End()
 
 	user, _ := r.Context().Value(types.UserScope).(*models.User)

+ 1 - 1
api/server/handlers/billing/delete.go

@@ -31,7 +31,7 @@ func NewDeleteBillingHandler(
 }
 
 func (c *DeleteBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "delete-billing-endpoint")
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-delete-billing-method")
 	defer span.End()
 
 	user, _ := r.Context().Value(types.UserScope).(*models.User)

+ 1 - 1
api/server/handlers/billing/key.go

@@ -28,7 +28,7 @@ func NewGetPublishableKeyHandler(
 }
 
 func (c *GetPublishableKeyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "get-publishable-key-endpoint")
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-get-publishable-key")
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)

+ 31 - 5
api/server/handlers/billing/list.go

@@ -35,7 +35,7 @@ func NewListBillingHandler(
 }
 
 func (c *ListBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "list-payment-endpoint")
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-list-payment-methods")
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
@@ -61,16 +61,42 @@ func NewCheckPaymentEnabledHandler(
 }
 
 func (c *CheckPaymentEnabledHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "check-payment-endpoint")
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-check-payment-enabled")
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
-	user, _ := ctx.Value(types.UserScope).(*models.User)
+
+	// Get project roles
+	roles, err := c.Repo().Project().ListProjectRolesOrdered(proj.ID)
+	if err != nil {
+		err = telemetry.Error(ctx, span, err, "error listing project roles")
+		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		return
+	}
+
+	// Get the project admin user
+	var adminUser *models.User
+	for _, role := range roles {
+		if role.Kind == types.RoleAdmin {
+			adminUser, err = c.Repo().User().ReadUser(role.UserID)
+			if err != nil {
+				err = telemetry.Error(ctx, span, err, "error reading user")
+				c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+				return
+			}
+			break
+		}
+	}
+
+	telemetry.WithAttributes(span,
+		telemetry.AttributeKV{Key: "admin-user-id", Value: adminUser.ID},
+		telemetry.AttributeKV{Key: "admin-user-email", Value: adminUser.Email},
+	)
 
 	// Create billing customer for project and set the billing ID if it doesn't exist
 	var shouldUpdate bool
 	if proj.BillingID == "" {
-		billingID, err := c.Config().BillingManager.StripeClient.CreateCustomer(ctx, user.Email, proj.ID, proj.Name)
+		billingID, err := c.Config().BillingManager.StripeClient.CreateCustomer(ctx, adminUser.Email, proj.ID, proj.Name)
 		if err != nil {
 			err = telemetry.Error(ctx, span, err, "error creating billing customer")
 			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
@@ -85,7 +111,7 @@ func (c *CheckPaymentEnabledHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
 	}
 
 	if c.Config().BillingManager.MetronomeEnabled && proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) && proj.UsageID == uuid.Nil {
-		customerID, customerPlanID, err := c.Config().BillingManager.MetronomeClient.CreateCustomerWithPlan(ctx, user.Email, proj.Name, proj.ID, proj.BillingID, proj.EnableSandbox)
+		customerID, customerPlanID, err := c.Config().BillingManager.MetronomeClient.CreateCustomerWithPlan(ctx, adminUser.Email, proj.Name, proj.ID, proj.BillingID, proj.EnableSandbox)
 		if err != nil {
 			err = telemetry.Error(ctx, span, err, "error creating Metronome customer")
 			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))

+ 12 - 1
internal/repository/gorm/project.go

@@ -104,7 +104,18 @@ func (repo *ProjectRepository) ListProjectsByUserID(userID uint) ([]*models.Proj
 	return projects, nil
 }
 
-// ReadProject gets a projects specified by a unique id
+// ListProjectRolesOrdered returns a list of roles for a project ordered by creation date
+func (repo *ProjectRepository) ListProjectRolesOrdered(projID uint) ([]models.Role, error) {
+	project := &models.Project{}
+
+	if err := repo.db.Preload("Roles").Where("id = ?", projID).Order("created_at").First(&project).Error; err != nil {
+		return nil, err
+	}
+
+	return project.Roles, nil
+}
+
+// ListProjectRoles returns a list of roles for the project
 func (repo *ProjectRepository) ListProjectRoles(projID uint) ([]models.Role, error) {
 	project := &models.Project{}
 

+ 1 - 0
internal/repository/project.go

@@ -16,6 +16,7 @@ type ProjectRepository interface {
 	ReadProject(id uint) (*models.Project, error)
 	ReadProjectRole(projID, userID uint) (*models.Role, error)
 	ListProjectRoles(projID uint) ([]models.Role, error)
+	ListProjectRolesOrdered(projID uint) ([]models.Role, error)
 	ListProjectsByUserID(userID uint) ([]*models.Project, error)
 	DeleteProject(project *models.Project) (*models.Project, error)
 	DeleteProjectRole(projID, userID uint) (*models.Role, error)

+ 16 - 0
internal/repository/test/project.go

@@ -171,6 +171,22 @@ func (repo *ProjectRepository) ListProjectRoles(projID uint) ([]models.Role, err
 	return repo.projects[index].Roles, nil
 }
 
+// ListProjectRoles returns a list of roles for the project
+func (repo *ProjectRepository) ListProjectRolesOrdered(projID uint) ([]models.Role, error) {
+	if !repo.canQuery {
+		return nil, errors.New("Cannot read from database")
+	}
+
+	if int(projID-1) >= len(repo.projects) || repo.projects[projID-1] == nil {
+		return nil, gorm.ErrRecordNotFound
+	}
+
+	index := int(projID - 1)
+	repo.projects[index] = nil
+
+	return repo.projects[index].Roles, nil
+}
+
 // DeleteProject removes a project
 func (repo *ProjectRepository) DeleteProject(project *models.Project) (*models.Project, error) {
 	if !repo.canQuery {