Просмотр исходного кода

Stripe integaration fixes (#4444)

Co-authored-by: jusrhee <justin@porter.run>
Mauricio Araujo 2 лет назад
Родитель
Сommit
254ae7d401

+ 39 - 0
api/server/handlers/billing/create.go

@@ -8,6 +8,7 @@ import (
 	"github.com/porter-dev/porter/api/server/shared"
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/server/shared/requestutils"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/models"
 	"github.com/porter-dev/porter/internal/telemetry"
@@ -18,6 +19,11 @@ type CreateBillingHandler struct {
 	handlers.PorterHandlerWriter
 }
 
+// SetDefaultBillingHandler is a handler for setting default payment method
+type SetDefaultBillingHandler struct {
+	handlers.PorterHandlerWriter
+}
+
 // NewCreateBillingHandler will create a new CreateBillingHandler
 func NewCreateBillingHandler(
 	config *config.Config,
@@ -44,3 +50,36 @@ func (c *CreateBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 
 	c.WriteResult(w, r, clientSecret)
 }
+
+// NewSetDefaultBillingHandler will create a new CreateBillingHandler
+func NewSetDefaultBillingHandler(
+	config *config.Config,
+	writer shared.ResultWriter,
+) *SetDefaultBillingHandler {
+	return &SetDefaultBillingHandler{
+		PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
+	}
+}
+
+func (c *SetDefaultBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ctx, span := telemetry.NewSpan(r.Context(), "set-default-billing-endpoint")
+	defer span.End()
+
+	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
+
+	paymentMethodID, reqErr := requestutils.GetURLParamString(r, types.URLParamPaymentMethodID)
+	if reqErr != nil {
+		err := telemetry.Error(ctx, span, reqErr, "error setting default payment method")
+		c.HandleAPIError(w, r, apierrors.NewErrInternal(fmt.Errorf("error setting default payment method: %w", err)))
+		return
+	}
+
+	err := c.Config().BillingManager.SetDefaultPaymentMethod(paymentMethodID, proj)
+	if err != nil {
+		err := telemetry.Error(ctx, span, err, "error setting default payment method")
+		c.HandleAPIError(w, r, apierrors.NewErrInternal(fmt.Errorf("error setting default payment method: %w", err)))
+		return
+	}
+
+	c.WriteResult(w, r, "")
+}

+ 27 - 0
api/server/router/project.go

@@ -367,6 +367,33 @@ func getProjectRoutes(
 		Router:   r,
 	})
 
+	// PUT /api/projects/{project_id}/billing/payment_method/{payment_method_id}/default -> project.NewSetDefaultBillingHandler
+	setDefaultBillingEndpoint := factory.NewAPIEndpoint(
+		&types.APIRequestMetadata{
+			Verb:   types.APIVerbUpdate,
+			Method: types.HTTPVerbPut,
+			Path: &types.Path{
+				Parent:       basePath,
+				RelativePath: fmt.Sprintf("%s/billing/payment_method/{%s}/default", relPath, types.URLParamPaymentMethodID),
+			},
+			Scopes: []types.PermissionScope{
+				types.UserScope,
+				types.ProjectScope,
+			},
+		},
+	)
+
+	setDefaultBillingHandler := billing.NewSetDefaultBillingHandler(
+		config,
+		factory.GetResultWriter(),
+	)
+
+	routes = append(routes, &router.Route{
+		Endpoint: setDefaultBillingEndpoint,
+		Handler:  setDefaultBillingHandler,
+		Router:   r,
+	})
+
 	// DELETE /api/projects/{project_id}/billing/payment_method/{payment_method_id} -> project.NewDeleteBillingHandler
 	deleteBillingEndpoint := factory.NewAPIEndpoint(
 		&types.APIRequestMetadata{

+ 2 - 1
api/types/billing.go

@@ -6,11 +6,12 @@ type CreateBillingCustomerRequest struct {
 }
 
 // PaymentMethod is a subset of the Stripe PaymentMethod type,
-// with only the fields used on the dashboard
+// with only the fields used in the dashboard
 type PaymentMethod = struct {
 	ID           string `json:"id"`
 	DisplayBrand string `json:"display_brand"`
 	Last4        string `json:"last4"`
 	ExpMonth     int64  `json:"exp_month"`
 	ExpYear      int64  `json:"exp_year"`
+	Default      bool   `json:"is_default"`
 }

+ 1 - 0
dashboard/src/lib/billing/types.tsx

@@ -10,6 +10,7 @@ export const PaymentMethodValidator = z.object({
   last4: z.string(),
   exp_month: z.number(),
   exp_year: z.number(),
+  is_default: z.boolean(),
 });
 
 export const ClientSecretResponse = z.string();

+ 41 - 6
dashboard/src/lib/hooks/useStripe.tsx

@@ -22,6 +22,10 @@ type TCreatePaymentMethod = {
   createPaymentMethod: () => Promise<string>;
 };
 
+type TSetDefaultPaymentMethod = {
+  setDefaultPaymentMethod: (paymentMethodId: string) => Promise<void>;
+};
+
 type TCheckHasPaymentEnabled = {
   hasPaymentEnabled: boolean;
   refetchPaymentEnabled: any;
@@ -32,7 +36,7 @@ type TGetPublishableKey = {
 };
 
 export const usePaymentMethods = (): TUsePaymentMethod => {
-  const { user, currentProject } = useContext(Context);
+  const { currentProject } = useContext(Context);
 
   // State has be shared so that payment methods can be removed
   // from the Billing page once they are deleted
@@ -48,11 +52,6 @@ export const usePaymentMethods = (): TUsePaymentMethod => {
       if (!currentProject?.id || currentProject.id === -1) {
         return;
       }
-      await api.checkBillingCustomerExists(
-        "<token>",
-        { user_email: user?.email },
-        { project_id: currentProject?.id }
-      );
       const listResponse = await api.listPaymentMethod(
         "<token>",
         {},
@@ -169,3 +168,39 @@ export const usePublishableKey = (): TGetPublishableKey => {
     publishableKey: keyReq.data,
   };
 };
+
+export const checkBillingCustomerExists = async (): Promise<void> => {
+  const { user, currentProject } = useContext(Context);
+  const res = await api.checkBillingCustomerExists(
+    "<token>",
+    { user_email: user?.email },
+    {
+      project_id: currentProject?.id,
+    }
+  );
+
+  if (res.status !== 200) {
+    throw Error("failed to check if billing customer exists");
+  }
+};
+
+export const useSetDefaultPaymentMethod = (): TSetDefaultPaymentMethod => {
+  const { currentProject } = useContext(Context);
+
+  const setDefaultPaymentMethod = async (paymentMethodId: string) => {
+    // Set payment method as default
+    const res = await api.setDefaultPaymentMethod(
+      "<token>",
+      {},
+      { project_id: currentProject?.id, payment_method_id: paymentMethodId }
+    );
+
+    if (res.status !== 200) {
+      throw Error("failed to set payment method as default");
+    }
+  };
+
+  return {
+    setDefaultPaymentMethod,
+  };
+};

+ 12 - 39
dashboard/src/main/home/modals/BillingModal.tsx

@@ -1,22 +1,28 @@
 import React from "react";
 import { Elements } from "@stripe/react-stripe-js";
 import { loadStripe } from "@stripe/stripe-js";
-import styled from "styled-components";
 
-import Heading from "components/form-components/Heading";
 import Link from "components/porter/Link";
 import Modal from "components/porter/Modal";
 import Spacer from "components/porter/Spacer";
 import Text from "components/porter/Text";
-import { usePublishableKey } from "lib/hooks/useStripe";
-
-import backArrow from "assets/back_arrow.png";
+import {
+  checkBillingCustomerExists,
+  usePublishableKey,
+} from "lib/hooks/useStripe";
 
 import PaymentSetupForm from "./PaymentSetupForm";
 
-const BillingModal = ({ back, onCreate }) => {
+const BillingModal = ({
+  back,
+  onCreate,
+}: {
+  back: (value: React.SetStateAction<boolean>) => void;
+  onCreate: () => Promise<void>;
+}) => {
   const { publishableKey } = usePublishableKey();
   const stripePromise = loadStripe(publishableKey);
+  checkBillingCustomerExists();
 
   const appearance = {
     variables: {
@@ -65,36 +71,3 @@ const BillingModal = ({ back, onCreate }) => {
 };
 
 export default BillingModal;
-
-const ControlRow = styled.div`
-  width: 100%;
-  display: flex;
-  margin-left: auto;
-  justify-content: space-between;
-  align-items: center;
-  margin-bottom: 35px;
-`;
-
-const BackButton = styled.div`
-  display: flex;
-  width: 36px;
-  cursor: pointer;
-  height: 36px;
-  align-items: center;
-  justify-content: center;
-  border: 1px solid #ffffff55;
-  border-radius: 100px;
-  background: #ffffff11;
-
-  :hover {
-    background: #ffffff22;
-    > img {
-      opacity: 1;
-    }
-  }
-`;
-
-const BackButtonImg = styled.img`
-  width: 16px;
-  opacity: 0.75;
-`;

+ 12 - 3
dashboard/src/main/home/modals/PaymentSetupForm.tsx

@@ -10,15 +10,19 @@ import Button from "components/porter/Button";
 import Error from "components/porter/Error";
 import Spacer from "components/porter/Spacer";
 import SaveButton from "components/SaveButton";
-import { useCreatePaymentMethod } from "lib/hooks/useStripe";
+import {
+  useCreatePaymentMethod,
+  useSetDefaultPaymentMethod,
+} from "lib/hooks/useStripe";
 
-const PaymentSetupForm = ({ onCreate }: { onCreate: () => void }) => {
+const PaymentSetupForm = ({ onCreate }: { onCreate: () => Promise<void> }) => {
   const stripe = useStripe();
   const elements = useElements();
 
   const [errorMessage, setErrorMessage] = useState(null);
   const [loading, setLoading] = useState(false);
   const { createPaymentMethod } = useCreatePaymentMethod();
+  const { setDefaultPaymentMethod } = useSetDefaultPaymentMethod();
 
   const handleSubmit = async () => {
     if (!stripe || !elements) {
@@ -38,7 +42,7 @@ const PaymentSetupForm = ({ onCreate }: { onCreate: () => void }) => {
     const clientSecret = await createPaymentMethod();
 
     // Finally, confirm with Stripe so the payment method is saved
-    const { error } = await stripe.confirmSetup({
+    const { error, setupIntent } = await stripe.confirmSetup({
       elements,
       clientSecret,
       redirect: "if_required",
@@ -48,6 +52,11 @@ const PaymentSetupForm = ({ onCreate }: { onCreate: () => void }) => {
       setErrorMessage(error.message);
     }
 
+    // Confirm the setup and set as default
+    if (setupIntent?.payment_method !== null) {
+      await setDefaultPaymentMethod(setupIntent?.payment_method as string);
+    }
+
     onCreate();
   };
 

+ 35 - 22
dashboard/src/main/home/project-settings/BillingPage.tsx

@@ -9,8 +9,10 @@ import Icon from "components/porter/Icon";
 import Spacer from "components/porter/Spacer";
 import Text from "components/porter/Text";
 import {
+  checkBillingCustomerExists,
   checkIfProjectHasPayment,
   usePaymentMethods,
+  useSetDefaultPaymentMethod,
 } from "lib/hooks/useStripe";
 
 import { Context } from "shared/Context";
@@ -22,23 +24,21 @@ import BillingModal from "../modals/BillingModal";
 function BillingPage(): JSX.Element {
   const { setCurrentOverlay } = useContext(Context);
   const [shouldCreate, setShouldCreate] = useState(false);
+  checkBillingCustomerExists();
+
   const {
     paymentMethodList,
     refetchPaymentMethods,
     deletePaymentMethod,
     isDeleting,
   } = usePaymentMethods();
+  const { setDefaultPaymentMethod } = useSetDefaultPaymentMethod();
 
   const { refetchPaymentEnabled } = checkIfProjectHasPayment();
 
   const onCreate = async () => {
+    await refetchPaymentMethods();
     setShouldCreate(false);
-    refetchPaymentMethods();
-    refetchPaymentEnabled();
-  };
-
-  const onDelete = async (paymentMethodId: string) => {
-    deletePaymentMethod(paymentMethodId);
     refetchPaymentEnabled();
   };
 
@@ -76,23 +76,36 @@ function BillingPage(): JSX.Element {
                 <DeleteButtonContainer>
                   {isDeleting ? (
                     <Loading />
+                  ) : !paymentMethod.is_default ? (
+                    <Container row={true}>
+                      <DeleteButton
+                        onClick={() => {
+                          setCurrentOverlay({
+                            message: `Are you sure you want to remove this payment method?`,
+                            onYes: () => {
+                              deletePaymentMethod(paymentMethod.id);
+                              setCurrentOverlay(null);
+                            },
+                            onNo: () => {
+                              setCurrentOverlay(null);
+                            },
+                          });
+                        }}
+                      >
+                        <Icon src={trashIcon} height={"18px"} />
+                      </DeleteButton>
+                      <Spacer inline x={1} />
+                      <Button
+                        onClick={() => {
+                          setDefaultPaymentMethod(paymentMethod.id);
+                          refetchPaymentMethods();
+                        }}
+                      >
+                        Set as default
+                      </Button>
+                    </Container>
                   ) : (
-                    <DeleteButton
-                      onClick={() => {
-                        setCurrentOverlay({
-                          message: `Are you sure you want to remove this payment method?`,
-                          onYes: () => {
-                            deletePaymentMethod(paymentMethod.id);
-                            setCurrentOverlay(null);
-                          },
-                          onNo: () => {
-                            setCurrentOverlay(null);
-                          },
-                        });
-                      }}
-                    >
-                      <Icon src={trashIcon} height={"18px"} />
-                    </DeleteButton>
+                    <Text>Default</Text>
                   )}
                 </DeleteButtonContainer>
               </Container>

+ 4 - 4
dashboard/src/shared/api.tsx

@@ -3461,16 +3461,16 @@ const addPaymentMethod = baseApi<
   ({ project_id }) => `/api/projects/${project_id}/billing/payment_method`
 );
 
-const updatePaymentMethod = baseApi<
+const setDefaultPaymentMethod = baseApi<
   {},
   {
     project_id?: number;
     payment_method_id: string;
   }
 >(
-  "GET",
+  "PUT",
   ({ project_id, payment_method_id }) =>
-    `/api/projects/${project_id}/billing/payment_method/${payment_method_id}`
+    `/api/projects/${project_id}/billing/payment_method/${payment_method_id}/default`
 );
 
 const deletePaymentMethod = baseApi<
@@ -3833,7 +3833,7 @@ export default {
   checkBillingCustomerExists,
   listPaymentMethod,
   addPaymentMethod,
-  updatePaymentMethod,
+  setDefaultPaymentMethod,
   deletePaymentMethod,
 
   // STATUS

+ 8 - 0
internal/billing/billing.go

@@ -23,6 +23,9 @@ type BillingManager interface {
 	// CreatePaymentMethod will add a new payment method to the project in Stripe
 	CreatePaymentMethod(proj *models.Project) (clientSecret string, err error)
 
+	// SetDefaultPaymentMethod will set the payment method as default in the customer invoice settings
+	SetDefaultPaymentMethod(paymentMethodID string, proj *models.Project) (err error)
+
 	// DeletePaymentMethod will remove a payment method for the project in Stripe
 	DeletePaymentMethod(paymentMethodID string) (err error)
 
@@ -58,6 +61,11 @@ func (s *NoopBillingManager) CreatePaymentMethod(proj *models.Project) (clientSe
 	return "", nil
 }
 
+// SetDefaultPaymentMethod is a no-op
+func (s *NoopBillingManager) SetDefaultPaymentMethod(paymentMethodID string, proj *models.Project) (err error) {
+	return nil
+}
+
 // DeletePaymentMethod is a no-op
 func (s *NoopBillingManager) DeletePaymentMethod(paymentMethodID string) (err error) {
 	return nil

+ 54 - 0
internal/billing/stripe.go

@@ -74,24 +74,45 @@ func (s *StripeBillingManager) CheckPaymentEnabled(proj *models.Project) (paymen
 func (s *StripeBillingManager) ListPaymentMethod(proj *models.Project) (paymentMethods []types.PaymentMethod, err error) {
 	stripe.Key = s.StripeSecretKey
 
+	// Get configured payment methods
 	params := &stripe.PaymentMethodListParams{
 		Customer: stripe.String(proj.BillingID),
 		Type:     stripe.String(string(stripe.PaymentMethodTypeCard)),
 	}
 	result := paymentmethod.List(params)
 
+	defaultPaymentExists, defaultPaymentID, err := s.checkDefaultPaymentMethod(proj.BillingID)
+	if err != nil {
+		return paymentMethods, err
+	}
+
 	for result.Next() {
 		stripePaymentMethod := result.PaymentMethod()
 
+		var isDefaultPaymentMethod bool
+		if stripePaymentMethod.ID == defaultPaymentID {
+			isDefaultPaymentMethod = true
+		}
+
 		paymentMethods = append(paymentMethods, types.PaymentMethod{
 			ID:           stripePaymentMethod.ID,
 			DisplayBrand: stripePaymentMethod.Card.DisplayBrand,
 			Last4:        stripePaymentMethod.Card.Last4,
 			ExpMonth:     stripePaymentMethod.Card.ExpMonth,
 			ExpYear:      stripePaymentMethod.Card.ExpYear,
+			Default:      isDefaultPaymentMethod,
 		})
 	}
 
+	// Set default payment method when project has payment methods enabled but
+	// no default setup
+	if len(paymentMethods) > 0 && !defaultPaymentExists {
+		err = s.SetDefaultPaymentMethod(paymentMethods[len(paymentMethods)-1].ID, proj)
+		if err != nil {
+			return paymentMethods, err
+		}
+	}
+
 	return paymentMethods, nil
 }
 
@@ -115,6 +136,24 @@ func (s *StripeBillingManager) CreatePaymentMethod(proj *models.Project) (client
 	return intent.ClientSecret, nil
 }
 
+// SetDefaultPaymentMethod will add a new payment method to the project in Stripe
+func (s *StripeBillingManager) SetDefaultPaymentMethod(paymentMethodID string, proj *models.Project) (err error) {
+	stripe.Key = s.StripeSecretKey
+
+	params := &stripe.CustomerParams{
+		InvoiceSettings: &stripe.CustomerInvoiceSettingsParams{
+			DefaultPaymentMethod: stripe.String(paymentMethodID),
+		},
+	}
+
+	_, err = customer.Update(proj.BillingID, params)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 // DeletePaymentMethod will remove a payment method for the project in Stripe
 func (s *StripeBillingManager) DeletePaymentMethod(paymentMethodID string) (err error) {
 	stripe.Key = s.StripeSecretKey
@@ -131,3 +170,18 @@ func (s *StripeBillingManager) DeletePaymentMethod(paymentMethodID string) (err
 func (s *StripeBillingManager) GetPublishableKey() (key string) {
 	return s.StripePublishableKey
 }
+
+func (s *StripeBillingManager) checkDefaultPaymentMethod(customerID string) (defaultPaymentExists bool, defaultPaymentID string, err error) {
+	// Get customer to check default payment method
+	customer, err := customer.Get(customerID, nil)
+	if err != nil {
+		return defaultPaymentExists, defaultPaymentID, err
+	}
+
+	if customer.InvoiceSettings != nil && customer.InvoiceSettings.DefaultPaymentMethod != nil {
+		defaultPaymentExists = true
+		defaultPaymentID = customer.InvoiceSettings.DefaultPaymentMethod.ID
+	}
+
+	return defaultPaymentExists, defaultPaymentID, err
+}