Просмотр исходного кода

Change billing clients to non-pointers, move creation logic to check payment endpoint

Mauricio Araujo 2 лет назад
Родитель
Сommit
7d04a0ea99

+ 0 - 20
api/server/handlers/billing/key.go

@@ -5,7 +5,6 @@ import (
 
 	"github.com/porter-dev/porter/api/server/handlers"
 	"github.com/porter-dev/porter/api/server/shared"
-	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/models"
@@ -33,25 +32,6 @@ func (c *GetPublishableKeyHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
-	user, _ := ctx.Value(types.UserScope).(*models.User)
-
-	// Create billing customer for project and set the billing ID if it doesn't exist
-	if proj.BillingID == "" {
-		billingID, err := c.Config().BillingManager.StripeClient.CreateCustomer(ctx, user.Email, proj.ID, proj.Name)
-		if err != nil {
-			err = telemetry.Error(ctx, span, err, "error creating billing customer")
-			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-			return
-		}
-		proj.BillingID = billingID
-
-		_, err = c.Repo().Project().UpdateProject(proj)
-		if err != nil {
-			err := telemetry.Error(ctx, span, err, "error updating project")
-			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-			return
-		}
-	}
 
 	// There is no easy way to pass environment variables to the frontend,
 	// so for now pass via the backend. This is acceptable because the key is

+ 44 - 19
api/server/handlers/billing/list.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"net/http"
 
+	"github.com/google/uuid"
 	"github.com/porter-dev/porter/api/server/handlers"
 	"github.com/porter-dev/porter/api/server/shared"
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
@@ -38,25 +39,6 @@ func (c *ListBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
-	user, _ := ctx.Value(types.UserScope).(*models.User)
-
-	// Create billing customer for project and set the billing ID if it doesn't exist
-	if proj.BillingID == "" {
-		billingID, err := c.Config().BillingManager.StripeClient.CreateCustomer(ctx, user.Email, proj.ID, proj.Name)
-		if err != nil {
-			err = telemetry.Error(ctx, span, err, "error creating billing customer")
-			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-			return
-		}
-		proj.BillingID = billingID
-
-		_, err = c.Repo().Project().UpdateProject(proj)
-		if err != nil {
-			err := telemetry.Error(ctx, span, err, "error updating project")
-			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-			return
-		}
-	}
 
 	paymentMethods, err := c.Config().BillingManager.StripeClient.ListPaymentMethod(ctx, proj.BillingID)
 	if err != nil {
@@ -83,6 +65,49 @@ func (c *CheckPaymentEnabledHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
+	user, _ := ctx.Value(types.UserScope).(*models.User)
+
+	// Create billing customer for project and set the billing ID if it doesn't exist
+	var shouldUpdate bool
+	if proj.BillingID == "" {
+		billingID, err := c.Config().BillingManager.StripeClient.CreateCustomer(ctx, user.Email, proj.ID, proj.Name)
+		if err != nil {
+			err = telemetry.Error(ctx, span, err, "error creating billing customer")
+			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+			return
+		}
+		proj.BillingID = billingID
+		shouldUpdate = true
+
+		telemetry.WithAttributes(span,
+			telemetry.AttributeKV{Key: "billing-id", Value: proj.BillingID},
+		)
+	}
+
+	if proj.UsageID == uuid.Nil {
+		customerID, customerPlanID, err := c.Config().BillingManager.MetronomeClient.CreateCustomerWithPlan(user.CompanyName, proj.Name, proj.ID, proj.BillingID, c.Config().ServerConf.PorterCloudPlanID)
+		if err != nil {
+			err = telemetry.Error(ctx, span, err, "error creating Metronome customer")
+			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		}
+		proj.UsageID = customerID
+		proj.UsagePlanID = customerPlanID
+		shouldUpdate = true
+
+		telemetry.WithAttributes(span,
+			telemetry.AttributeKV{Key: "usage-id", Value: proj.UsageID},
+			telemetry.AttributeKV{Key: "usage-plan-id", Value: proj.UsagePlanID},
+		)
+	}
+
+	if shouldUpdate {
+		_, err := c.Repo().Project().UpdateProject(proj)
+		if err != nil {
+			err := telemetry.Error(ctx, span, err, "error updating project")
+			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+			return
+		}
+	}
 
 	paymentEnabled, err := c.Config().BillingManager.StripeClient.CheckPaymentEnabled(ctx, proj.BillingID)
 	if err != nil {

+ 2 - 1
api/server/handlers/project/create.go

@@ -99,11 +99,12 @@ func (p *ProjectCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 
 	// Create Metronome customer and add to starter plan
 	if p.Config().ServerConf.MetronomeAPIKey != "" && p.Config().ServerConf.PorterCloudPlanID != "" && proj.GetFeatureFlag(models.MetronomeEnabled, p.Config().LaunchDarklyClient) {
-		customerPlanID, err := p.Config().BillingManager.MetronomeClient.CreateCustomerWithPlan(user.CompanyName, proj.Name, proj.ID, proj.BillingID, p.Config().ServerConf.PorterCloudPlanID)
+		customerID, customerPlanID, err := p.Config().BillingManager.MetronomeClient.CreateCustomerWithPlan(user.CompanyName, proj.Name, proj.ID, proj.BillingID, p.Config().ServerConf.PorterCloudPlanID)
 		if err != nil {
 			err = telemetry.Error(ctx, span, err, "error creating Metronome customer")
 			p.HandleAPIError(w, r, apierrors.NewErrInternal(err))
 		}
+		proj.UsageID = customerID
 		proj.UsagePlanID = customerPlanID
 		telemetry.WithAttributes(span,
 			telemetry.AttributeKV{Key: "usage-id", Value: proj.UsageID},

+ 19 - 16
api/server/shared/config/loader/loader.go

@@ -40,10 +40,8 @@ import (
 )
 
 var (
-	// InstanceBillingManager manages Stripe and Metronome clients
-	InstanceBillingManager billing.Manager
-	InstanceEnvConf        *envloader.EnvConf
-	InstanceDB             *pgorm.DB
+	InstanceEnvConf *envloader.EnvConf
+	InstanceDB      *pgorm.DB
 )
 
 type EnvConfigLoader struct {
@@ -92,7 +90,6 @@ func (e *EnvConfigLoader) LoadConfig() (res *config.Config, err error) {
 		ServerConf:        sc,
 		DBConf:            envConf.DBConf,
 		RedisConf:         envConf.RedisConf,
-		BillingManager:    InstanceBillingManager,
 		CredentialBackend: instanceCredentialBackend,
 	}
 	res.Logger.Info().Msg("Loading MetadataFromConf")
@@ -248,14 +245,6 @@ func (e *EnvConfigLoader) LoadConfig() (res *config.Config, err error) {
 	}
 	res.LaunchDarklyClient = launchDarklyClient
 
-	if sc.StripeSecretKey == "" {
-		res.Logger.Info().Msg("STRIPE_SECRET_KEY not set, all Stripe functionality will be disabled")
-	}
-
-	if sc.MetronomeAPIKey == "" {
-		res.Logger.Info().Msg("METRONOME_API_KEY not set, all Metronome functionality will be disabled")
-	}
-
 	if sc.SlackClientID != "" && sc.SlackClientSecret != "" {
 		res.Logger.Info().Msg("Creating Slack client")
 		res.SlackConf = oauth.NewSlackClient(&oauth.Config{
@@ -342,10 +331,24 @@ func (e *EnvConfigLoader) LoadConfig() (res *config.Config, err error) {
 		CollectorURL: sc.TelemetryCollectorURL,
 	}
 
+	var (
+		stripeClient    billing.StripeClient
+		metronomeClient billing.MetronomeClient
+	)
+	if sc.StripeSecretKey != "" {
+		stripeClient = billing.NewStripeClient(InstanceEnvConf.ServerConf.StripeSecretKey, InstanceEnvConf.ServerConf.StripePublishableKey)
+	} else {
+		res.Logger.Info().Msg("STRIPE_SECRET_KEY not set, all Stripe functionality will be disabled")
+	}
+
+	if sc.MetronomeAPIKey != "" {
+		metronomeClient = billing.NewMetronomeClient(InstanceEnvConf.ServerConf.MetronomeAPIKey)
+	} else {
+		res.Logger.Info().Msg("METRONOME_API_KEY not set, all Metronome functionality will be disabled")
+	}
+
 	res.Logger.Info().Msg("Creating billing manager")
-	stripeClient := billing.NewStripeClient(InstanceEnvConf.ServerConf.StripeSecretKey, InstanceEnvConf.ServerConf.StripePublishableKey)
-	metronomeClient := billing.NewMetronomeClient(InstanceEnvConf.ServerConf.MetronomeAPIKey)
-	InstanceBillingManager = billing.Manager{
+	res.BillingManager = billing.Manager{
 		StripeClient:    stripeClient,
 		MetronomeClient: metronomeClient,
 	}

+ 11 - 9
internal/billing/metronome.go

@@ -34,7 +34,7 @@ func NewMetronomeClient(metronomeApiKey string) MetronomeClient {
 }
 
 // createCustomer will create the customer in Metronome
-func (m *MetronomeClient) createCustomer(orgName string, projectName string, projectID uint, billingID string) (customerID uuid.UUID, err error) {
+func (m MetronomeClient) createCustomer(orgName string, projectName string, projectID uint, billingID string) (customerID uuid.UUID, err error) {
 	path := "customers"
 	projIDStr := strconv.FormatUint(uint64(projectID), 10)
 
@@ -62,7 +62,7 @@ func (m *MetronomeClient) createCustomer(orgName string, projectName string, pro
 }
 
 // addCustomerPlan will start the customer on the given plan
-func (m *MetronomeClient) addCustomerPlan(customerID uuid.UUID, planID uuid.UUID) (customerPlanID uuid.UUID, err error) {
+func (m MetronomeClient) addCustomerPlan(customerID uuid.UUID, planID uuid.UUID) (customerPlanID uuid.UUID, err error) {
 	if customerID == uuid.Nil || planID == uuid.Nil {
 		return customerPlanID, fmt.Errorf("customer or plan id empty")
 	}
@@ -94,22 +94,24 @@ func (m *MetronomeClient) addCustomerPlan(customerID uuid.UUID, planID uuid.UUID
 }
 
 // CreateCustomerWithPlan will create the customer in Metronome and immediately add it to the plan
-func (m *MetronomeClient) CreateCustomerWithPlan(orgName string, projectName string, projectID uint, billingID string, planID string) (customerPlanID uuid.UUID, err error) {
+func (m MetronomeClient) CreateCustomerWithPlan(orgName string, projectName string, projectID uint, billingID string, planID string) (customerID uuid.UUID, customerPlanID uuid.UUID, err error) {
 	porterCloudPlanID, err := uuid.Parse(planID)
 	if err != nil {
-		return customerPlanID, fmt.Errorf("error parsing starter plan id: %w", err)
+		return customerID, customerPlanID, fmt.Errorf("error parsing starter plan id: %w", err)
 	}
 
-	customerID, err := m.createCustomer(orgName, projectName, projectID, billingID)
+	customerID, err = m.createCustomer(orgName, projectName, projectID, billingID)
 	if err != nil {
-		return customerPlanID, err
+		return customerID, customerPlanID, err
 	}
 
-	return m.addCustomerPlan(customerID, porterCloudPlanID)
+	customerPlanID, err = m.addCustomerPlan(customerID, porterCloudPlanID)
+
+	return customerID, customerPlanID, err
 }
 
 // EndCustomerPlan will immediately end the plan for the given customer
-func (m *MetronomeClient) EndCustomerPlan(customerID uuid.UUID, customerPlanID uuid.UUID) (err error) {
+func (m MetronomeClient) EndCustomerPlan(customerID uuid.UUID, customerPlanID uuid.UUID) (err error) {
 	if customerID == uuid.Nil || customerPlanID == uuid.Nil {
 		return fmt.Errorf("customer or customer plan id empty")
 	}
@@ -134,7 +136,7 @@ func (m *MetronomeClient) EndCustomerPlan(customerID uuid.UUID, customerPlanID u
 }
 
 // GetCustomerCredits will return the first credit grant for the customer
-func (m *MetronomeClient) GetCustomerCredits(customerID uuid.UUID) (credits int64, err error) {
+func (m MetronomeClient) GetCustomerCredits(customerID uuid.UUID) (credits int64, err error) {
 	if customerID == uuid.Nil {
 		return credits, fmt.Errorf("customer id empty")
 	}

+ 9 - 9
internal/billing/stripe.go

@@ -29,7 +29,7 @@ func NewStripeClient(secretKey string, publishableKey string) StripeClient {
 }
 
 // CreateCustomer will create a customer in Stripe only if the project doesn't have a BillingID
-func (s *StripeClient) CreateCustomer(ctx context.Context, userEmail string, projectID uint, projectName string) (customerID string, err error) {
+func (s StripeClient) CreateCustomer(ctx context.Context, userEmail string, projectID uint, projectName string) (customerID string, err error) {
 	ctx, span := telemetry.NewSpan(ctx, "create-stripe-customer")
 	defer span.End()
 
@@ -68,7 +68,7 @@ func (s *StripeClient) CreateCustomer(ctx context.Context, userEmail string, pro
 }
 
 // DeleteCustomer will delete the customer from the billing provider
-func (s *StripeClient) DeleteCustomer(ctx context.Context, customerID string) (err error) {
+func (s StripeClient) DeleteCustomer(ctx context.Context, customerID string) (err error) {
 	ctx, span := telemetry.NewSpan(ctx, "delete-stripe-customer")
 	defer span.End()
 
@@ -92,7 +92,7 @@ func (s *StripeClient) DeleteCustomer(ctx context.Context, customerID string) (e
 }
 
 // CheckPaymentEnabled will return true if the project has a payment method added, false otherwise
-func (s *StripeClient) CheckPaymentEnabled(ctx context.Context, customerID string) (paymentEnabled bool, err error) {
+func (s StripeClient) CheckPaymentEnabled(ctx context.Context, customerID string) (paymentEnabled bool, err error) {
 	_, span := telemetry.NewSpan(ctx, "check-stripe-payment-enabled")
 	defer span.End()
 
@@ -112,7 +112,7 @@ func (s *StripeClient) CheckPaymentEnabled(ctx context.Context, customerID strin
 }
 
 // ListPaymentMethod will return all payment methods for the project
-func (s *StripeClient) ListPaymentMethod(ctx context.Context, customerID string) (paymentMethods []types.PaymentMethod, err error) {
+func (s StripeClient) ListPaymentMethod(ctx context.Context, customerID string) (paymentMethods []types.PaymentMethod, err error) {
 	ctx, span := telemetry.NewSpan(ctx, "list-stripe-payment-method")
 	defer span.End()
 
@@ -165,7 +165,7 @@ func (s *StripeClient) ListPaymentMethod(ctx context.Context, customerID string)
 }
 
 // CreatePaymentMethod will add a new payment method to the project in Stripe
-func (s *StripeClient) CreatePaymentMethod(ctx context.Context, customerID string) (clientSecret string, err error) {
+func (s StripeClient) CreatePaymentMethod(ctx context.Context, customerID string) (clientSecret string, err error) {
 	ctx, span := telemetry.NewSpan(ctx, "create-stripe-payment-method")
 	defer span.End()
 
@@ -192,7 +192,7 @@ func (s *StripeClient) CreatePaymentMethod(ctx context.Context, customerID strin
 }
 
 // SetDefaultPaymentMethod will add a new payment method to the project in Stripe
-func (s *StripeClient) SetDefaultPaymentMethod(ctx context.Context, paymentMethodID string, customerID string) (err error) {
+func (s StripeClient) SetDefaultPaymentMethod(ctx context.Context, paymentMethodID string, customerID string) (err error) {
 	ctx, span := telemetry.NewSpan(ctx, "set-default-stripe-payment-method")
 	defer span.End()
 
@@ -217,7 +217,7 @@ func (s *StripeClient) SetDefaultPaymentMethod(ctx context.Context, paymentMetho
 }
 
 // DeletePaymentMethod will remove a payment method for the project in Stripe
-func (s *StripeClient) DeletePaymentMethod(ctx context.Context, paymentMethodID string) (err error) {
+func (s StripeClient) DeletePaymentMethod(ctx context.Context, paymentMethodID string) (err error) {
 	ctx, span := telemetry.NewSpan(ctx, "delete-stripe-payment-method")
 	defer span.End()
 
@@ -236,14 +236,14 @@ func (s *StripeClient) DeletePaymentMethod(ctx context.Context, paymentMethodID
 }
 
 // GetPublishableKey returns the Stripe publishable key
-func (s *StripeClient) GetPublishableKey(ctx context.Context) (key string) {
+func (s StripeClient) GetPublishableKey(ctx context.Context) (key string) {
 	_, span := telemetry.NewSpan(ctx, "get-stripe-publishable-key")
 	defer span.End()
 
 	return s.PublishableKey
 }
 
-func (s *StripeClient) checkDefaultPaymentMethod(customerID string) (defaultPaymentExists bool, defaultPaymentID string, err error) {
+func (s StripeClient) checkDefaultPaymentMethod(customerID string) (defaultPaymentExists bool, defaultPaymentID string, err error) {
 	// Get customer to check default payment method
 	customer, err := customer.Get(customerID, nil)
 	if err != nil {