Explorar o código

add internal billing support

Alexander Belanger %!s(int64=4) %!d(string=hai) anos
pai
achega
648e8d7b91

+ 12 - 0
api/server/handlers/billing/billing_ce.go

@@ -34,3 +34,15 @@ func NewBillingWebhookHandler(
 ) http.Handler {
 	return handlers.NewUnavailable(config, "billing_webhook")
 }
+
+type BillingAddProjectHandler struct {
+	handlers.PorterHandlerReader
+	handlers.Unavailable
+}
+
+func NewBillingAddProjectHandler(
+	config *config.Config,
+	decoderValidator shared.RequestDecoderValidator,
+) http.Handler {
+	return handlers.NewUnavailable(config, "billing_add_project")
+}

+ 6 - 0
api/server/handlers/billing/billing_ee.go

@@ -22,7 +22,13 @@ var NewBillingWebhookHandler func(
 	decoderValidator shared.RequestDecoderValidator,
 ) http.Handler
 
+var NewBillingAddProjectHandler func(
+	config *config.Config,
+	decoderValidator shared.RequestDecoderValidator,
+) http.Handler
+
 func init() {
 	NewBillingGetTokenHandler = billing.NewBillingGetTokenHandler
 	NewBillingWebhookHandler = billing.NewBillingWebhookHandler
+	NewBillingAddProjectHandler = billing.NewBillingAddProjectHandler
 }

+ 25 - 0
api/server/router/base.go

@@ -2,6 +2,7 @@ package router
 
 import (
 	"github.com/go-chi/chi"
+	"github.com/porter-dev/porter/api/server/handlers/billing"
 	"github.com/porter-dev/porter/api/server/handlers/credentials"
 	"github.com/porter-dev/porter/api/server/handlers/gitinstallation"
 	"github.com/porter-dev/porter/api/server/handlers/healthcheck"
@@ -511,5 +512,29 @@ func GetBaseRoutes(
 		Router:   r,
 	})
 
+	// POST /api/internal/billing -> billing.NewBillingAddProjectHandler
+	addProjectBillingEndpoint := factory.NewAPIEndpoint(
+		&types.APIRequestMetadata{
+			Verb:   types.APIVerbCreate,
+			Method: types.HTTPVerbPost,
+			Path: &types.Path{
+				Parent:       basePath,
+				RelativePath: "/internal/billing",
+			},
+			Scopes: []types.PermissionScope{},
+		},
+	)
+
+	addProjectBillingHandler := billing.NewBillingAddProjectHandler(
+		config,
+		factory.GetDecoderValidator(),
+	)
+
+	routes = append(routes, &Route{
+		Endpoint: addProjectBillingEndpoint,
+		Handler:  addProjectBillingHandler,
+		Router:   r,
+	})
+
 	return routes
 }

+ 3 - 0
api/server/shared/config/env/envconfs.go

@@ -81,6 +81,9 @@ type ServerConf struct {
 	SelfKubeconfig     string `env:"SELF_KUBECONFIG"`
 
 	WelcomeFormWebhook string `env:"WELCOME_FORM_WEBHOOK"`
+
+	// Token for internal retool to authenticate to internal API endpoints
+	RetoolToken string `env:"RETOOL_TOKEN"`
 }
 
 // DBConf is the database configuration: if generated from environment variables,

+ 15 - 0
api/types/billing.go

@@ -0,0 +1,15 @@
+package types
+
+type AddProjectBillingRequest struct {
+	ProjectID uint `json:"project_id" form:"required"`
+
+	// Monthly price, in cents
+	Price uint `json:"price" form:"required"`
+
+	Users    uint `json:"users"`
+	Clusters uint `json:"clusters"`
+	CPU      uint `json:"cpu"`
+	Memory   uint `json:"memory"`
+
+	ExistingPlanName string `json:"existing_plan_name"`
+}

+ 157 - 0
ee/api/server/handlers/billing/add_project.go

@@ -0,0 +1,157 @@
+package billing
+
+import (
+	"errors"
+	"fmt"
+	"net/http"
+	"strings"
+
+	"github.com/porter-dev/porter/api/server/authz"
+	"github.com/porter-dev/porter/api/server/handlers"
+	"github.com/porter-dev/porter/api/server/shared"
+	"github.com/porter-dev/porter/api/server/shared/apierrors"
+	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/types"
+	"github.com/porter-dev/porter/internal/models"
+	"gorm.io/gorm"
+)
+
+type BillingAddProjectHandler struct {
+	handlers.PorterHandlerReadWriter
+	authz.KubernetesAgentGetter
+}
+
+func NewBillingAddProjectHandler(
+	config *config.Config,
+	decoderValidator shared.RequestDecoderValidator,
+) http.Handler {
+	return &BillingAddProjectHandler{
+		PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, nil),
+	}
+}
+
+// Adds a project to a billing team in IronPlans. Takes the following steps:
+// 1. Looks for project billing data for the given project.
+// 2. Checks for project billing data. If the project already has billing data, move to step 3b, otherwise 3a.
+// 3a. Creates a new team in IronPlans, and creates a custom plan in IronPlans. Subscribes the team to the plan.
+// 3b. Finds the relevant team in IronPlans, creates a custom plan, and updates the subscription for the team.
+// 4. If team was created, creates ProjectBilling object.
+// 5. If team was created, finds all roles in the team. Adds all roles as a team member to the project billing. Updates UserBilling models.
+func (c *BillingAddProjectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	// validation for internal token
+	reqToken := r.Header.Get("Authorization")
+	splitToken := strings.Split(reqToken, "Bearer")
+
+	if len(splitToken) != 2 {
+		c.HandleAPIError(w, r, apierrors.NewErrForbidden(fmt.Errorf("no token found")))
+		return
+	}
+
+	reqToken = strings.TrimSpace(splitToken[1])
+
+	if reqToken != c.Config().ServerConf.RetoolToken {
+		c.HandleAPIError(w, r, apierrors.NewErrForbidden(fmt.Errorf("passed retool token does not match env")))
+		return
+	}
+
+	request := &types.AddProjectBillingRequest{}
+
+	if ok := c.DecodeAndValidate(w, r, request); !ok {
+		return
+	}
+
+	// make sure the project exists; if it does not exist, throw forbidden error
+	proj, err := c.Repo().Project().ReadProject(request.ProjectID)
+
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			c.HandleAPIError(w, r, apierrors.NewErrForbidden(err))
+			return
+		}
+
+		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		return
+	}
+
+	// look for project billing data for the given project
+	teamID, err := c.Config().BillingManager.GetTeamID(proj)
+	isNotFound := err != nil && errors.Is(err, gorm.ErrRecordNotFound)
+
+	// if the error is not nil and is not "ErrRecordNotFound", throw error
+	if err != nil && !isNotFound {
+		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		return
+	}
+
+	// if the team is not found, create a new team
+	if isNotFound {
+		teamID, err = c.Config().BillingManager.CreateTeam(proj)
+
+		if err != nil {
+			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+			return
+		}
+	}
+
+	// determine whether to place the team on a custom plan or an existing plan
+	if request.ExistingPlanName != "" {
+		err = addToExistingPlan(c.Config(), request.ExistingPlanName, teamID)
+	} else {
+		err = addToCustomPlan(c.Config(), teamID, proj, request)
+	}
+
+	if err != nil {
+		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		return
+	}
+
+	// add users in project to the plan
+	projRoles, err := c.Repo().Project().ListProjectRoles(proj.ID)
+
+	if err != nil {
+		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		return
+	}
+
+	for _, role := range projRoles {
+		user, err := c.Repo().User().ReadUser(role.UserID)
+
+		if err != nil {
+			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+			return
+		}
+
+		err = c.Config().BillingManager.AddUserToTeam(teamID, user, &role)
+
+		if err != nil {
+			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+			return
+		}
+	}
+
+	w.WriteHeader(http.StatusOK)
+}
+
+func addToCustomPlan(c *config.Config, teamID string, proj *models.Project, req *types.AddProjectBillingRequest) error {
+	// create a new plan in IronPlans
+	planID, err := c.BillingManager.CreatePlan(teamID, proj, req)
+
+	if err != nil {
+		return err
+	}
+
+	// create a new subscription to this plan in IronPlans
+	return c.BillingManager.CreateOrUpdateSubscription(teamID, planID)
+}
+
+func addToExistingPlan(c *config.Config, existingPlanName, teamID string) error {
+	// look for existing plans in IronPlans
+	planID, err := c.BillingManager.GetExistingPublicPlan(existingPlanName)
+
+	if err != nil {
+		return err
+	}
+
+	// create a new subscription to this plan in IronPlans
+	return c.BillingManager.CreateOrUpdateSubscription(teamID, planID)
+}

+ 211 - 18
ee/billing/ironplans.go

@@ -32,7 +32,8 @@ type Client struct {
 
 	httpClient *http.Client
 
-	defaultPlan *Plan
+	defaultPlanID string
+	customPlanID  string
 }
 
 // NewClient creates a new billing API client
@@ -41,23 +42,24 @@ func NewClient(serverURL, apiKey string, repo repository.EERepository) (*Client,
 		Timeout: time.Minute,
 	}
 
-	client := &Client{apiKey, serverURL, repo, httpClient, nil}
+	client := &Client{apiKey, serverURL, repo, httpClient, "", ""}
 
 	// get the default plans from the IronPlans API server
-	listResp := &ListPlansResponse{}
-	err := client.getRequest("/plans/v1", listResp)
+	defPlanID, err := client.GetExistingPublicPlan("Free")
 
 	if err != nil {
 		return nil, err
 	}
 
-	for _, plan := range listResp.Results {
-		if plan.Name == "Free" {
-			copyPlan := plan
-			client.defaultPlan = &copyPlan
-		}
+	customPlanID, err := client.GetExistingPublicPlan("Enterprise")
+
+	if err != nil {
+		return nil, err
 	}
 
+	client.defaultPlanID = defPlanID
+	client.customPlanID = customPlanID
+
 	return client, nil
 }
 
@@ -72,13 +74,8 @@ func (c *Client) CreateTeam(proj *cemodels.Project) (string, error) {
 	}
 
 	// put the user on the free plan, as the default behavior, if there is a default plan
-	if c.defaultPlan != nil {
-		err := c.postRequest("/subscriptions/v1", &CreateSubscriptionRequest{
-			PlanID:     c.defaultPlan.ID,
-			NextPlanID: c.defaultPlan.ID,
-			TeamID:     resp.ID,
-			IsPaused:   false,
-		}, nil)
+	if c.defaultPlanID != "" {
+		err = c.CreateOrUpdateSubscription(resp.ID, c.defaultPlanID)
 
 		if err != nil {
 			return "", fmt.Errorf("subscription creation failed: %s", err)
@@ -117,7 +114,194 @@ func (c *Client) GetTeamID(proj *cemodels.Project) (teamID string, err error) {
 	return projBilling.BillingTeamID, nil
 }
 
+func (c *Client) CreatePlan(teamID string, proj *cemodels.Project, planSpec *types.AddProjectBillingRequest) (string, error) {
+	// construct basic plan object
+	planFeatures := make([]*CreatePlanFeature, 0)
+
+	userDisplay := fmt.Sprintf("Up to %d users", planSpec.Users)
+
+	if planSpec.Users == 0 {
+		userDisplay = fmt.Sprintf("Unlimited users")
+	}
+
+	clusterDisplay := fmt.Sprintf("Up to %d clusters", planSpec.Clusters)
+
+	if planSpec.Clusters == 0 {
+		clusterDisplay = fmt.Sprintf("Unlimited clusters")
+	}
+
+	cpuDisplay := fmt.Sprintf("Up to %d CPUs", planSpec.CPU)
+
+	if planSpec.CPU == 0 {
+		cpuDisplay = fmt.Sprintf("Unlimited CPU")
+	}
+
+	ramDisplay := fmt.Sprintf("Up to %d GB RAM", planSpec.Memory)
+
+	if planSpec.Memory == 0 {
+		ramDisplay = fmt.Sprintf("Unlimited RAM")
+	}
+
+	planFeatures = append(planFeatures, &CreatePlanFeature{
+		Display: userDisplay,
+	})
+	planFeatures = append(planFeatures, &CreatePlanFeature{
+		Display: clusterDisplay,
+	})
+	planFeatures = append(planFeatures, &CreatePlanFeature{
+		Display: cpuDisplay,
+	})
+	planFeatures = append(planFeatures, &CreatePlanFeature{
+		Display: ramDisplay,
+	})
+
+	var customPlanID *string
+
+	if c.customPlanID != "" {
+		customPlanID = &c.customPlanID
+	}
+
+	createPlanReq := &CreatePlanRequest{
+		Name:               proj.Name,
+		IsActive:           true,
+		IsPublic:           false,
+		IsTrialAllowed:     true,
+		ReplacePlanID:      customPlanID,
+		PerMonthPriceCents: planSpec.Price,
+		PerYearPriceCents:  12 * planSpec.Price,
+		Features:           planFeatures,
+		TeamsAccess: []*CreatePlanTeamsAccess{
+			{
+				TeamID: teamID,
+				Revoke: false,
+			},
+		},
+	}
+
+	// find all relevant feature IDs
+	listResp := &ListFeaturesResponse{}
+	err := c.getRequest("/features/v1", listResp)
+
+	if err != nil {
+		return "", err
+	}
+
+	// create a feature spec per feature ID, and add to features array for plan
+	for _, feature := range listResp.Results {
+		featureSpec := &CreateFeatureSpecRequest{
+			Name:         "unnamed",
+			RecordPeriod: "monthly",
+			Aggregation:  "sum",
+			UnitPrice:    0,
+		}
+
+		switch feature.Slug {
+		case FeatureSlugUsers:
+			featureSpec.MaxLimit = planSpec.Users
+			featureSpec.UnitsIncluded = planSpec.Users
+		case FeatureSlugClusters:
+			featureSpec.MaxLimit = planSpec.Clusters
+			featureSpec.UnitsIncluded = planSpec.Clusters
+		case FeatureSlugCPU:
+			featureSpec.MaxLimit = planSpec.CPU
+			featureSpec.UnitsIncluded = planSpec.CPU
+		case FeatureSlugMemory:
+			featureSpec.MaxLimit = planSpec.Memory
+			featureSpec.UnitsIncluded = planSpec.Memory
+		// continue on default behavior so that feature spec is not created for
+		// features that don't match a slug
+		default:
+			continue
+		}
+
+		// create the feature spec
+		resp := &CreateFeaturespecResponse{}
+		err = c.postRequest("/featurespecs/v1/", featureSpec, resp)
+
+		if err != nil {
+			return "", err
+		}
+
+		var index int
+		switch feature.Slug {
+		case FeatureSlugUsers:
+			index = 0
+		case FeatureSlugClusters:
+			index = 1
+		case FeatureSlugCPU:
+			index = 2
+		case FeatureSlugMemory:
+			index = 3
+		}
+
+		createPlanReq.Features[index].FeatureID = feature.ID
+		createPlanReq.Features[index].SpecID = resp.ID
+	}
+
+	// create the plan and return the plan ID
+	planResp := &Plan{}
+
+	err = c.postRequest("/plans/v1/", createPlanReq, planResp)
+
+	if err != nil {
+		return "", err
+	}
+
+	return planResp.ID, nil
+}
+
+func (c *Client) CreateOrUpdateSubscription(teamID, planID string) error {
+	// determine if subscription already exists by reading the team ID and seeing if the subscription
+	// field has an ID attached
+	teamResp := &Team{}
+	err := c.getRequest(fmt.Sprintf("/teams/v1/%s", teamID), teamResp)
+
+	if err != nil {
+		return err
+	}
+
+	subReq := &CreateSubscriptionRequest{
+		PlanID:     planID,
+		NextPlanID: c.defaultPlanID,
+		TeamID:     teamID,
+		IsPaused:   false,
+	}
+
+	// if subscription ID is not empty, perform a PUT request to update the subscription
+	if teamResp.Subscription.ID != "" {
+		err = c.putRequest(fmt.Sprintf("/subscriptions/v1/%s", teamResp.Subscription.ID), subReq, nil)
+	} else {
+		err = c.postRequest("/subscriptions/v1", subReq, nil)
+	}
+
+	return err
+}
+
+func (c *Client) GetExistingPublicPlan(planName string) (string, error) {
+	listResp := &ListPlansResponse{}
+	err := c.getRequest("/plans/v1/", listResp, map[string]string{"is_public": "true"})
+
+	if err != nil {
+		return "", err
+	}
+
+	for _, plan := range listResp.Results {
+		if plan.Name == planName {
+			return plan.ID, nil
+		}
+	}
+
+	return "", fmt.Errorf("plan not found")
+}
+
 func (c *Client) AddUserToTeam(teamID string, user *cemodels.User, role *cemodels.Role) error {
+	// determine if user is already in team/has user billing
+	userBilling, err := c.repo.UserBilling().ReadUserBilling(role.ProjectID, user.ID)
+
+	if userBilling != nil {
+		return nil
+	}
+
 	roleEnum := RoleEnumMember
 
 	// if user's role is admin, add them to the team as an owner
@@ -134,7 +318,7 @@ func (c *Client) AddUserToTeam(teamID string, user *cemodels.User, role *cemodel
 
 	resp := &Teammate{}
 
-	err := c.postRequest("/team_memberships/v1", req, resp)
+	err = c.postRequest("/team_memberships/v1", req, resp)
 
 	if err != nil {
 		return err
@@ -292,7 +476,7 @@ func (c *Client) deleteRequest(path string, data interface{}, dst interface{}) e
 	return c.writeRequest("DELETE", path, data, dst)
 }
 
-func (c *Client) getRequest(path string, dst interface{}) error {
+func (c *Client) getRequest(path string, dst interface{}, query ...map[string]string) error {
 	reqURL, err := url.Parse(c.serverURL)
 
 	if err != nil {
@@ -301,6 +485,15 @@ func (c *Client) getRequest(path string, dst interface{}) error {
 
 	reqURL.Path = path
 
+	q := reqURL.Query()
+	for _, queryGroup := range query {
+		for key, val := range queryGroup {
+			q.Add(key, val)
+		}
+	}
+
+	reqURL.RawQuery = q.Encode()
+
 	req, err := http.NewRequest(
 		"GET",
 		reqURL.String(),

+ 44 - 0
ee/billing/types.go

@@ -38,6 +38,49 @@ type Plan struct {
 	Features   []PlanFeature `json:"features"`
 }
 
+type CreatePlanRequest struct {
+	Name               string                   `json:"name"`
+	IsActive           bool                     `json:"is_active"`
+	IsPublic           bool                     `json:"is_public"`
+	IsTrialAllowed     bool                     `json:"is_trial_allowed"`
+	PerMonthPriceCents uint                     `json:"per_month_price_cents"`
+	PerYearPriceCents  uint                     `json:"per_year_price_cents"`
+	ReplacePlanID      *string                  `json:"replace_plan_id"`
+	Features           []*CreatePlanFeature     `json:"features"`
+	TeamsAccess        []*CreatePlanTeamsAccess `json:"teams_access"`
+}
+
+type CreatePlanFeature struct {
+	FeatureID string `json:"feature_id"`
+	SpecID    string `json:"spec_id"`
+	Display   string `json:"display"`
+	Sort      uint   `json:"sort"`
+	IsActive  bool   `json:"is_active"`
+}
+
+type CreatePlanTeamsAccess struct {
+	TeamID string `json:"team_id"`
+	Revoke bool   `json:"revoke"`
+}
+
+type CreateFeatureSpecRequest struct {
+	Name          string `json:"name"`
+	RecordPeriod  string `json:"record_period"`
+	Aggregation   string `json:"aggregation"`
+	MaxLimit      uint   `json:"max_limit"`
+	UnitPrice     uint   `json:"unit_price"`
+	UnitsIncluded uint   `json:"units_included"`
+}
+
+type CreateFeaturespecResponse struct {
+	*CreateFeatureSpecRequest
+	ID string `json:"id"`
+}
+
+type ListFeaturesResponse struct {
+	Results []Feature `json:"results"`
+}
+
 type ListPlansResponse struct {
 	Results []Plan `json:"results"`
 }
@@ -50,6 +93,7 @@ type PlanFeature struct {
 }
 
 type Feature struct {
+	ID   string `json:"id"`
 	Slug string `json:"slug"`
 }
 

+ 22 - 0
internal/billing/billing.go

@@ -3,6 +3,7 @@ package billing
 import (
 	"fmt"
 
+	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/models"
 )
 
@@ -19,6 +20,15 @@ type BillingManager interface {
 	// GetTeamID gets the billing team id for a project
 	GetTeamID(proj *models.Project) (teamID string, err error)
 
+	// CreatePlan creates a new plan based on the requested limits
+	CreatePlan(teamID string, proj *models.Project, planSpec *types.AddProjectBillingRequest) (string, error)
+
+	// CreateOrUpdateSubscription creates or updates a new subscription to a plan, based on a team and plan ID
+	CreateOrUpdateSubscription(teamID, planID string) error
+
+	// GetExistingPublicPlan returns an existing public plan based on a name
+	GetExistingPublicPlan(planName string) (string, error)
+
 	// AddUserToTeam adds a user to a team, and cases on whether the user can view
 	// billing based on the role.
 	AddUserToTeam(teamID string, user *models.User, role *models.Role) error
@@ -57,6 +67,18 @@ func (n *NoopBillingManager) GetTeamID(proj *models.Project) (teamID string, err
 	return fmt.Sprintf("%d", proj.ID), nil
 }
 
+func (n *NoopBillingManager) CreatePlan(teamID string, proj *models.Project, planSpec *types.AddProjectBillingRequest) (string, error) {
+	return "", nil
+}
+
+func (n *NoopBillingManager) CreateOrUpdateSubscription(teamID, planID string) error {
+	return nil
+}
+
+func (n *NoopBillingManager) GetExistingPublicPlan(planName string) (string, error) {
+	return "", nil
+}
+
 func (n *NoopBillingManager) AddUserToTeam(teamID string, user *models.User, role *models.Role) error {
 	return nil
 }