Browse Source

relocate cookie query and consolidate types

Alexander Belanger 3 years ago
parent
commit
c3cd4586ba

+ 2 - 83
api/server/handlers/billing/redirect_billing.go

@@ -1,13 +1,9 @@
 package billing
 
 import (
-	"encoding/json"
-	"fmt"
 	"net/http"
 	"net/url"
-	"strings"
 
-	"github.com/gorilla/schema"
 	"github.com/porter-dev/porter/api/server/handlers"
 	"github.com/porter-dev/porter/api/server/shared"
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
@@ -29,23 +25,6 @@ func NewRedirectBillingHandler(
 	}
 }
 
-type CreateBillingCookieRequest struct {
-	Email       string `json:"email" form:"required"`
-	UserID      uint   `json:"user_id" form:"required"`
-	ProjectID   uint   `json:"project_id" form:"required"`
-	ProjectName string `json:"project_name" form:"required"`
-}
-
-type CreateBillingCookieResponse struct {
-	Token   string `json:"token"`
-	TokenID string `json:"token_id"`
-}
-
-type VerifyUserRequest struct {
-	TokenID string `schema:"token_id" form:"required"`
-	Token   string `schema:"token" form:"required"`
-}
-
 func (c *RedirectBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	user, _ := r.Context().Value(types.UserScope).(*models.User)
 	proj, _ := r.Context().Value(types.ProjectScope).(*models.Project)
@@ -75,72 +54,12 @@ func (c *RedirectBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques
 		return
 	}
 
-	// get an internal cookie
-	data := &CreateBillingCookieRequest{
-		ProjectName: proj.Name,
-		ProjectID:   proj.ID,
-		UserID:      user.ID,
-		Email:       user.Email,
-	}
-
-	var strData []byte
-	var err error
-
-	if data != nil {
-		strData, err = json.Marshal(data)
-
-		if err != nil {
-			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-			return
-		}
-	}
-
-	req, err := http.NewRequest(
-		"POST",
-		fmt.Sprintf("%s/api/v1/private/cookie", c.Config().ServerConf.BillingPrivateServerURL),
-		strings.NewReader(string(strData)),
-	)
+	redirectURI, err := c.Config().BillingManager.GetRedirectURI(user, proj)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
 		return
 	}
 
-	req.Header.Set("Content-Type", "application/json; charset=utf-8")
-	req.Header.Set("Accept", "application/json; charset=utf-8")
-	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.Config().ServerConf.BillingPrivateKey))
-
-	httpClient := http.Client{}
-
-	res, err := httpClient.Do(req)
-
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
-	defer res.Body.Close()
-
-	dst := &CreateBillingCookieResponse{}
-
-	err = json.NewDecoder(res.Body).Decode(dst)
-
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
-	redirectData := &VerifyUserRequest{
-		TokenID: dst.TokenID,
-		Token:   dst.Token,
-	}
-
-	vals := make(map[string][]string)
-	err = schema.NewEncoder().Encode(redirectData, vals)
-
-	urlVals := url.Values(vals)
-	encodedURLVals := urlVals.Encode()
-
-	reqURL := fmt.Sprintf("%s/api/v1/verify?%s", c.Config().ServerConf.BillingPublicServerURL, encodedURLVals)
-	http.Redirect(w, r, reqURL, 302)
+	http.Redirect(w, r, redirectURI, 302)
 }

+ 3 - 2
api/server/shared/config/loader/init_ee.go

@@ -24,12 +24,13 @@ func init() {
 		key[i] = b
 	}
 
-	if InstanceEnvConf.ServerConf.BillingPrivateServerURL != "" && InstanceEnvConf.ServerConf.BillingPrivateKey != "" {
+	if InstanceEnvConf.ServerConf.BillingPrivateServerURL != "" && InstanceEnvConf.ServerConf.BillingPrivateKey != "" && InstanceEnvConf.ServerConf.BillingPublicServerURL != "" {
 		serverURL := InstanceEnvConf.ServerConf.BillingPrivateServerURL
+		publicServerURL := InstanceEnvConf.ServerConf.BillingPublicServerURL
 		apiKey := InstanceEnvConf.ServerConf.BillingPrivateKey
 		var err error
 
-		InstanceBillingManager, err = eeBilling.NewClient(serverURL, apiKey)
+		InstanceBillingManager, err = eeBilling.NewClient(serverURL, publicServerURL, apiKey)
 
 		if err != nil {
 			panic(err)

+ 52 - 5
ee/billing/client.go

@@ -15,24 +15,26 @@ import (
 	"strings"
 	"time"
 
+	"github.com/gorilla/schema"
 	"github.com/porter-dev/porter/api/types"
 	cemodels "github.com/porter-dev/porter/internal/models"
 )
 
 // Client contains an API client for the internal billing engine
 type Client struct {
-	apiKey     string
-	serverURL  string
-	httpClient *http.Client
+	apiKey          string
+	serverURL       string
+	publicServerURL string
+	httpClient      *http.Client
 }
 
 // NewClient creates a new billing API client
-func NewClient(serverURL, apiKey string) (*Client, error) {
+func NewClient(serverURL, publicServerURL, apiKey string) (*Client, error) {
 	httpClient := &http.Client{
 		Timeout: time.Minute,
 	}
 
-	client := &Client{apiKey, serverURL, httpClient}
+	client := &Client{apiKey, serverURL, publicServerURL, httpClient}
 
 	return client, nil
 }
@@ -65,6 +67,51 @@ func (c *Client) DeleteTeam(user *cemodels.User, proj *cemodels.Project) error {
 	return c.deleteRequest("/api/v1/private/customer", reqData, nil)
 }
 
+func (c *Client) GetRedirectURI(user *cemodels.User, proj *cemodels.Project) (string, error) {
+	// get an internal cookie
+	reqData := &CreateBillingCookieRequest{
+		ProjectName: proj.Name,
+		ProjectID:   proj.ID,
+		UserID:      user.ID,
+		Email:       user.Email,
+	}
+
+	createCookieVals := make(map[string][]string)
+	err := schema.NewEncoder().Encode(reqData, createCookieVals)
+
+	if err != nil {
+		return "", err
+	}
+
+	urlVals := url.Values(createCookieVals)
+	encodedURLVals := urlVals.Encode()
+
+	dst := &CreateBillingCookieResponse{}
+
+	err = c.postRequest("/api/v1/private/cookie", reqData, dst)
+
+	if err != nil {
+		return "", err
+	}
+
+	redirectData := &VerifyUserRequest{
+		TokenID: dst.TokenID,
+		Token:   dst.Token,
+	}
+
+	vals := make(map[string][]string)
+	err = schema.NewEncoder().Encode(redirectData, vals)
+
+	if err != nil {
+		return "", err
+	}
+
+	urlVals = url.Values(vals)
+	encodedURLVals = urlVals.Encode()
+
+	return fmt.Sprintf("%s/api/v1/verify?%s", c.publicServerURL, encodedURLVals), nil
+}
+
 // VerifySignature verifies a webhook signature based on hmac protocol
 func (c *Client) VerifySignature(signature string, body []byte) bool {
 	if len(signature) != 71 || !strings.HasPrefix(signature, "sha256=") {

+ 17 - 0
ee/billing/types.go

@@ -28,3 +28,20 @@ type APIWebhookRequest struct {
 	StacksEnabled              string `json:"stacks_enabled,omitempty"`
 	ManagedDatabasesEnabled    string `json:"managed_databases_enabled,omitempty"`
 }
+
+type CreateBillingCookieRequest struct {
+	Email       string `json:"email" form:"required"`
+	UserID      uint   `json:"user_id" form:"required"`
+	ProjectID   uint   `json:"project_id" form:"required"`
+	ProjectName string `json:"project_name" form:"required"`
+}
+
+type CreateBillingCookieResponse struct {
+	Token   string `json:"token"`
+	TokenID string `json:"token_id"`
+}
+
+type VerifyUserRequest struct {
+	TokenID string `schema:"token_id" form:"required"`
+	Token   string `schema:"token" form:"required"`
+}

+ 7 - 0
internal/billing/billing.go

@@ -17,6 +17,9 @@ type BillingManager interface {
 	// DeleteTeam deletes a billing team.
 	DeleteTeam(user *models.User, proj *models.Project) (err error)
 
+	// GetRedirectURI gets the redirect URI to send the user to the billing portal
+	GetRedirectURI(user *models.User, proj *models.Project) (url string, err error)
+
 	// ParseProjectUsageFromWebhook parses the project usage from a webhook payload sent
 	// from a billing agent
 	ParseProjectUsageFromWebhook(payload []byte) (*models.ProjectUsage, *types.FeatureFlags, error)
@@ -36,6 +39,10 @@ func (n *NoopBillingManager) DeleteTeam(user *models.User, proj *models.Project)
 	return nil
 }
 
+func (n *NoopBillingManager) GetRedirectURI(user *models.User, proj *models.Project) (url string, err error) {
+	return "", nil
+}
+
 func (n *NoopBillingManager) ParseProjectUsageFromWebhook(payload []byte) (*models.ProjectUsage, *types.FeatureFlags, error) {
 	return nil, nil, nil
 }