Ver código fonte

Add check for payment enabled

Mauricio Araujo 2 anos atrás
pai
commit
7182d2857d

+ 1 - 1
api/server/handlers/billing/create.go

@@ -30,7 +30,7 @@ func NewCreateBillingHandler(
 }
 
 func (c *CreateBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "auth-endpoint-api-token")
+	ctx, span := telemetry.NewSpan(r.Context(), "create-billing-endpoint")
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)

+ 1 - 1
api/server/handlers/billing/customer.go

@@ -31,7 +31,7 @@ func NewCreateBillingCustomerIfNotExists(
 }
 
 func (c *CreateBillingCustomerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "auth-endpoint-api-token")
+	ctx, span := telemetry.NewSpan(r.Context(), "create-billing-customer-endpoint")
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)

+ 1 - 1
api/server/handlers/billing/delete.go

@@ -29,7 +29,7 @@ func NewDeleteBillingHandler(
 }
 
 func (c *DeleteBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "auth-endpoint-api-token")
+	ctx, span := telemetry.NewSpan(r.Context(), "delete-billing-endpoint")
 	defer span.End()
 
 	paymentMethodID, reqErr := requestutils.GetURLParamString(r, types.URLParamPaymentMethodID)

+ 31 - 1
api/server/handlers/billing/list.go

@@ -18,6 +18,10 @@ type ListBillingHandler struct {
 	handlers.PorterHandlerWriter
 }
 
+type CheckPaymentEnabledHandler struct {
+	handlers.PorterHandlerWriter
+}
+
 // NewListBillingHandler will create a new ListBillingHandler
 func NewListBillingHandler(
 	config *config.Config,
@@ -29,7 +33,7 @@ func NewListBillingHandler(
 }
 
 func (c *ListBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "auth-endpoint-api-token")
+	ctx, span := telemetry.NewSpan(r.Context(), "list-payment-endpoint")
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
@@ -43,3 +47,29 @@ func (c *ListBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 	c.WriteResult(w, r, paymentMethods)
 }
+
+// NewCheckPaymentEnabledHandler will create a new CheckPaymentEnabledHandler
+func NewCheckPaymentEnabledHandler(
+	config *config.Config,
+	writer shared.ResultWriter,
+) *CheckPaymentEnabledHandler {
+	return &CheckPaymentEnabledHandler{
+		PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
+	}
+}
+
+func (c *CheckPaymentEnabledHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ctx, span := telemetry.NewSpan(r.Context(), "check-payment-endpoint")
+	defer span.End()
+
+	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
+
+	paymentEnabled, err := c.Config().BillingManager.CheckPaymentEnabled(proj)
+	if err != nil {
+		err := telemetry.Error(ctx, span, err, "error checking if payment enabled")
+		c.HandleAPIError(w, r, apierrors.NewErrInternal(fmt.Errorf("error checking if payment enabled: %w", err)))
+		return
+	}
+
+	c.WriteResult(w, r, paymentEnabled)
+}

+ 27 - 0
api/server/router/project.go

@@ -285,6 +285,33 @@ func getProjectRoutes(
 		Router:   r,
 	})
 
+	// GET /api/projects/{project_id}/billing -> project.NewCheckPaymentEnabledHandler
+	checkPaymentEndpoint := factory.NewAPIEndpoint(
+		&types.APIRequestMetadata{
+			Verb:   types.APIVerbGet,
+			Method: types.HTTPVerbGet,
+			Path: &types.Path{
+				Parent:       basePath,
+				RelativePath: relPath + "/billing",
+			},
+			Scopes: []types.PermissionScope{
+				types.UserScope,
+				types.ProjectScope,
+			},
+		},
+	)
+
+	checkPaymentHandler := billing.NewCheckPaymentEnabledHandler(
+		config,
+		factory.GetResultWriter(),
+	)
+
+	routes = append(routes, &router.Route{
+		Endpoint: checkPaymentEndpoint,
+		Handler:  checkPaymentHandler,
+		Router:   r,
+	})
+
 	// GET /api/projects/{project_id}/billing/payment_method -> project.NewListBillingHandler
 	listBillingEndpoint := factory.NewAPIEndpoint(
 		&types.APIRequestMetadata{

+ 1 - 1
api/types/project.go

@@ -66,7 +66,7 @@ type Project struct {
 // retrieve feature flags from the `GET /projects/{project_id}` response instead
 type FeatureFlags struct {
 	AzureEnabled                    bool   `json:"azure_enabled,omitempty"`
-	BillingEnabled 									bool   `json:"billing_enabled,omitempty"`
+	BillingEnabled                  bool   `json:"billing_enabled,omitempty"`
 	CapiProvisionerEnabled          string `json:"capi_provisioner_enabled,omitempty"`
 	EnableReprovision               bool   `json:"enable_reprovision,omitempty"`
 	FullAddOns                      bool   `json:"full_add_ons,omitempty"`

+ 33 - 0
dashboard/src/lib/hooks/useStripe.tsx

@@ -22,6 +22,11 @@ type TCreatePaymentMethod = {
   createPaymentMethod: () => Promise<string>;
 };
 
+type TCheckHasPaymentEnabled = {
+  hasPaymentEnabled: boolean;
+  refetchPaymentEnabled: any;
+};
+
 export const usePaymentMethods = (): TUsePaymentMethod => {
   const { user, currentProject } = useContext(Context);
 
@@ -111,3 +116,31 @@ export const useCreatePaymentMethod = (): TCreatePaymentMethod => {
     createPaymentMethod,
   };
 };
+
+export const checkIfProjectHasPayment = (): TCheckHasPaymentEnabled => {
+  const { currentProject } = useContext(Context);
+
+  if (!currentProject?.id) {
+    throw new Error("Project ID is missing");
+  }
+
+  // Fetch list of payment methods
+  const paymentEnabledReq = useQuery(
+    ["checkPaymentEnabled", currentProject?.id],
+    async () => {
+      const res = await api.getHasBilling(
+        "<token>",
+        {},
+        { project_id: currentProject.id }
+      );
+
+      const data = z.boolean().parse(res.data);
+      return data;
+    }
+  );
+
+  return {
+    hasPaymentEnabled: paymentEnabledReq.data ?? false,
+    refetchPaymentEnabled: paymentEnabledReq.refetch,
+  };
+};

+ 1 - 38
dashboard/src/main/home/Home.tsx

@@ -1,4 +1,5 @@
 import React, { useContext, useEffect, useRef, useState } from "react";
+import { useStripe } from "@stripe/react-stripe-js";
 import { createPortal } from "react-dom";
 import {
   Route,
@@ -108,7 +109,6 @@ const Home: React.FC<Props> = (props) => {
     setHasFinishedOnboarding,
     setCurrentError,
     setCurrentModal,
-    setHasBillingEnabled,
     setUsage,
     setShouldRefreshClusters,
   } = useContext(Context);
@@ -262,46 +262,9 @@ const Home: React.FC<Props> = (props) => {
     }
   }, [shouldRefreshClusters]);
 
-  const checkIfProjectHasBilling = async (projectId: number) => {
-    if (!projectId) {
-      return false;
-    }
-    try {
-      const res = await api.getHasBilling(
-        "<token>",
-        {},
-        { project_id: projectId }
-      );
-      setHasBillingEnabled(res.data?.has_billing);
-      return res?.data?.has_billing;
-    } catch (error) {
-      console.log(error);
-    }
-  };
-
   useEffect(() => {
     getMetadata();
     checkOnboarding();
-    if (!process.env.DISABLE_BILLING) {
-      checkIfProjectHasBilling(currentProject?.id)
-        .then((isBillingEnabled) => {
-          if (isBillingEnabled) {
-            api
-              .getUsage("<token>", {}, { project_id: currentProject?.id })
-              .then((res) => {
-                const usage = res.data;
-                setUsage(usage);
-                /*
-                if (usage.exceeded) {
-                  setCurrentModal("UsageWarningModal", { usage });
-                }
-                */
-              })
-              .catch(console.log);
-          }
-        })
-        .catch(console.log);
-    }
   }, [props.currentProject?.id]);
 
   useEffect(() => {

+ 3 - 2
dashboard/src/main/home/app-dashboard/apps/Apps.tsx

@@ -21,6 +21,7 @@ import DeleteEnvModal from "main/home/cluster-dashboard/preview-environments/v2/
 import BillingModal from "main/home/modals/BillingModal";
 import { clientAddonFromProto, type ClientAddon } from "lib/addons";
 import { useAppAnalytics } from "lib/hooks/useAppAnalytics";
+import { checkIfProjectHasPayment } from "lib/hooks/useStripe";
 
 import api from "shared/api";
 import { Context } from "shared/Context";
@@ -45,6 +46,7 @@ const Apps: React.FC = () => {
   const { currentProject, currentCluster } = useContext(Context);
   const { updateAppStep } = useAppAnalytics();
   const { currentDeploymentTarget } = useDeploymentTarget();
+  const { hasPaymentEnabled } = checkIfProjectHasPayment();
   const history = useHistory();
 
   const [searchValue, setSearchValue] = useState("");
@@ -215,7 +217,7 @@ const Apps: React.FC = () => {
           <Spacer y={0.5} />
           <Text color={"helper"}>Get started by creating an application.</Text>
           <Spacer y={1} />
-          {currentProject?.billing_enabled ? (
+          {currentProject?.billing_enabled && !hasPaymentEnabled ? (
             <Button
               alt
               onClick={() => {
@@ -252,7 +254,6 @@ const Apps: React.FC = () => {
               onCreate={() => {
                 history.push("/apps/new/app");
               }}
-              project_id={currentProject?.id ?? -1}
             />
           )}
         </DashboardPlaceholder>

+ 14 - 9
dashboard/src/main/home/project-settings/BillingPage.tsx

@@ -8,7 +8,10 @@ import Fieldset from "components/porter/Fieldset";
 import Icon from "components/porter/Icon";
 import Spacer from "components/porter/Spacer";
 import Text from "components/porter/Text";
-import { usePaymentMethods } from "lib/hooks/useStripe";
+import {
+  checkIfProjectHasPayment,
+  usePaymentMethods,
+} from "lib/hooks/useStripe";
 
 import { Context } from "shared/Context";
 import cardIcon from "assets/credit-card.svg";
@@ -17,7 +20,7 @@ import trashIcon from "assets/trash.png";
 import BillingModal from "../modals/BillingModal";
 
 function BillingPage(): JSX.Element {
-  const { currentProject, setCurrentOverlay } = useContext(Context);
+  const { setCurrentOverlay } = useContext(Context);
   const [shouldCreate, setShouldCreate] = useState(false);
   const {
     paymentMethodList,
@@ -26,20 +29,22 @@ function BillingPage(): JSX.Element {
     isDeleting,
   } = usePaymentMethods();
 
+  const { refetchPaymentEnabled } = checkIfProjectHasPayment();
+
   const onCreate = async () => {
     setShouldCreate(false);
     refetchPaymentMethods();
+    refetchPaymentEnabled();
+  };
+
+  const onDelete = async (paymentMethodId: string) => {
+    deletePaymentMethod(paymentMethodId);
+    refetchPaymentEnabled();
   };
 
   if (shouldCreate) {
     return (
-      <BillingModal
-        onCreate={onCreate}
-        back={() => {
-          setShouldCreate(false);
-        }}
-        project_id={currentProject?.id}
-      />
+      <BillingModal onCreate={onCreate} back={() => setShouldCreate(false)} />
     );
   }
 

+ 6 - 6
dashboard/src/main/home/project-settings/ProjectSettings.tsx

@@ -85,12 +85,12 @@ function ProjectSettings(props: any) {
         });
       }
 
-      if (currentProject?.billing_enabled) {
-        tabOpts.push({
-          value: "billing",
-          label: "Billing",
-        });
-      }
+      // if (currentProject?.billing_enabled) {
+      tabOpts.push({
+        value: "billing",
+        label: "Billing",
+      });
+      // }
 
       tabOpts.push({
         value: "additional-settings",

+ 62 - 58
dashboard/src/shared/api.tsx

@@ -371,8 +371,9 @@ const getFeedEvents = baseApi<
   }
 >("GET", (pathParams) => {
   const { project_id, cluster_id, stack_name, page } = pathParams;
-  return `/api/projects/${project_id}/clusters/${cluster_id}/applications/${stack_name}/events?page=${page || 1
-    }`;
+  return `/api/projects/${project_id}/clusters/${cluster_id}/applications/${stack_name}/events?page=${
+    page || 1
+  }`;
 });
 
 const createEnvironment = baseApi<
@@ -860,9 +861,11 @@ const detectBuildpack = baseApi<
     branch: string;
   }
 >("GET", (pathParams) => {
-  return `/api/projects/${pathParams.project_id}/gitrepos/${pathParams.git_repo_id
-    }/repos/${pathParams.kind}/${pathParams.owner}/${pathParams.name
-    }/${encodeURIComponent(pathParams.branch)}/buildpack/detect`;
+  return `/api/projects/${pathParams.project_id}/gitrepos/${
+    pathParams.git_repo_id
+  }/repos/${pathParams.kind}/${pathParams.owner}/${
+    pathParams.name
+  }/${encodeURIComponent(pathParams.branch)}/buildpack/detect`;
 });
 
 const detectGitlabBuildpack = baseApi<
@@ -893,9 +896,11 @@ const getBranchContents = baseApi<
     branch: string;
   }
 >("GET", (pathParams) => {
-  return `/api/projects/${pathParams.project_id}/gitrepos/${pathParams.git_repo_id
-    }/repos/${pathParams.kind}/${pathParams.owner}/${pathParams.name
-    }/${encodeURIComponent(pathParams.branch)}/contents`;
+  return `/api/projects/${pathParams.project_id}/gitrepos/${
+    pathParams.git_repo_id
+  }/repos/${pathParams.kind}/${pathParams.owner}/${
+    pathParams.name
+  }/${encodeURIComponent(pathParams.branch)}/contents`;
 });
 
 const getProcfileContents = baseApi<
@@ -911,9 +916,11 @@ const getProcfileContents = baseApi<
     branch: string;
   }
 >("GET", (pathParams) => {
-  return `/api/projects/${pathParams.project_id}/gitrepos/${pathParams.git_repo_id
-    }/repos/${pathParams.kind}/${pathParams.owner}/${pathParams.name
-    }/${encodeURIComponent(pathParams.branch)}/procfile`;
+  return `/api/projects/${pathParams.project_id}/gitrepos/${
+    pathParams.git_repo_id
+  }/repos/${pathParams.kind}/${pathParams.owner}/${
+    pathParams.name
+  }/${encodeURIComponent(pathParams.branch)}/procfile`;
 });
 
 const getPorterYamlContents = baseApi<
@@ -929,9 +936,11 @@ const getPorterYamlContents = baseApi<
     branch: string;
   }
 >("GET", (pathParams) => {
-  return `/api/projects/${pathParams.project_id}/gitrepos/${pathParams.git_repo_id
-    }/repos/${pathParams.kind}/${pathParams.owner}/${pathParams.name
-    }/${encodeURIComponent(pathParams.branch)}/porteryaml`;
+  return `/api/projects/${pathParams.project_id}/gitrepos/${
+    pathParams.git_repo_id
+  }/repos/${pathParams.kind}/${pathParams.owner}/${
+    pathParams.name
+  }/${encodeURIComponent(pathParams.branch)}/porteryaml`;
 });
 
 const parsePorterYaml = baseApi<
@@ -991,30 +1000,32 @@ const getBranchHead = baseApi<
     branch: string;
   }
 >("GET", (pathParams) => {
-  return `/api/projects/${pathParams.project_id}/gitrepos/${pathParams.git_repo_id
-    }/repos/${pathParams.kind}/${pathParams.owner}/${pathParams.name
-    }/${encodeURIComponent(pathParams.branch)}/head`;
+  return `/api/projects/${pathParams.project_id}/gitrepos/${
+    pathParams.git_repo_id
+  }/repos/${pathParams.kind}/${pathParams.owner}/${
+    pathParams.name
+  }/${encodeURIComponent(pathParams.branch)}/head`;
 });
 
 const createApp = baseApi<
   | {
-    name: string;
-    deployment_target_id: string;
-    type: "github";
-    git_repo_id: number;
-    git_branch: string;
-    git_repo_name: string;
-    porter_yaml_path: string;
-  }
+      name: string;
+      deployment_target_id: string;
+      type: "github";
+      git_repo_id: number;
+      git_branch: string;
+      git_repo_name: string;
+      porter_yaml_path: string;
+    }
   | {
-    name: string;
-    deployment_target_id: string;
-    type: "docker-registry";
-    image: {
-      repository: string;
-      tag: string;
-    };
-  },
+      name: string;
+      deployment_target_id: string;
+      type: "docker-registry";
+      image: {
+        repository: string;
+        tag: string;
+      };
+    },
   {
     project_id: number;
     cluster_id: number;
@@ -2229,9 +2240,11 @@ const getEnvGroup = baseApi<
     version?: number;
   }
 >("GET", (pathParams) => {
-  return `/api/projects/${pathParams.id}/clusters/${pathParams.cluster_id
-    }/namespaces/${pathParams.namespace}/envgroup?name=${pathParams.name}${pathParams.version ? "&version=" + pathParams.version : ""
-    }`;
+  return `/api/projects/${pathParams.id}/clusters/${
+    pathParams.cluster_id
+  }/namespaces/${pathParams.namespace}/envgroup?name=${pathParams.name}${
+    pathParams.version ? "&version=" + pathParams.version : ""
+  }`;
 });
 
 const getConfigMap = baseApi<
@@ -2561,17 +2574,6 @@ const getUsage = baseApi<{}, { project_id: number }>(
   ({ project_id }) => `/api/projects/${project_id}/usage`
 );
 
-// Used for billing purposes
-const getCustomerToken = baseApi<{}, { project_id: number }>(
-  "GET",
-  ({ project_id }) => `/api/projects/${project_id}/billing/token`
-);
-
-const getHasBilling = baseApi<{}, { project_id: number }>(
-  "GET",
-  ({ project_id }) => `/api/projects/${project_id}/billing`
-);
-
 const getOnboardingState = baseApi<{}, { project_id: number }>(
   "GET",
   ({ project_id }) => `/api/projects/${project_id}/onboarding`
@@ -3427,16 +3429,18 @@ const removeStackEnvGroup = baseApi<
 // Billing
 const checkBillingCustomerExists = baseApi<
   {
-    user_email?: string,
+    user_email?: string;
   },
   {
     project_id?: number;
   }
->(
-  "POST",
-  ({ project_id }) =>
-    `/api/projects/${project_id}/billing/customer`
+>("POST", ({ project_id }) => `/api/projects/${project_id}/billing/customer`);
+
+const getHasBilling = baseApi<{}, { project_id: number }>(
+  "GET",
+  ({ project_id }) => `/api/projects/${project_id}/billing`
 );
+
 const listPaymentMethod = baseApi<
   {},
   {
@@ -3444,9 +3448,9 @@ const listPaymentMethod = baseApi<
   }
 >(
   "GET",
-  ({ project_id }) =>
-    `/api/projects/${project_id}/billing/payment_method`
+  ({ project_id }) => `/api/projects/${project_id}/billing/payment_method`
 );
+
 const addPaymentMethod = baseApi<
   {},
   {
@@ -3454,9 +3458,9 @@ const addPaymentMethod = baseApi<
   }
 >(
   "POST",
-  ({ project_id }) =>
-    `/api/projects/${project_id}/billing/payment_method`
+  ({ project_id }) => `/api/projects/${project_id}/billing/payment_method`
 );
+
 const updatePaymentMethod = baseApi<
   {},
   {
@@ -3468,6 +3472,7 @@ const updatePaymentMethod = baseApi<
   ({ project_id, payment_method_id }) =>
     `/api/projects/${project_id}/billing/payment_method/${payment_method_id}`
 );
+
 const deletePaymentMethod = baseApi<
   {},
   {
@@ -3480,7 +3485,7 @@ const deletePaymentMethod = baseApi<
     `/api/projects/${project_id}/billing/payment_method/${payment_method_id}`
 );
 
-const getGithubStatus = baseApi<{}, {}>("GET", ({ }) => `/api/status/github`);
+const getGithubStatus = baseApi<{}, {}>("GET", ({}) => `/api/status/github`);
 
 const createSecretAndOpenGitHubPullRequest = baseApi<
   {
@@ -3744,7 +3749,6 @@ export default {
   getPolicyDocument,
   createWebhookToken,
   getUsage,
-  getCustomerToken,
   getHasBilling,
   getOnboardingState,
   saveOnboardingState,

+ 7 - 0
internal/billing/billing.go

@@ -14,6 +14,9 @@ type BillingManager interface {
 	// DeleteCustomer will delete the customer from the billing provider
 	DeleteCustomer(proj *models.Project) (err error)
 
+	// CheckPaymentEnabled will check if the project has a payment method configured
+	CheckPaymentEnabled(proj *models.Project) (paymentEnabled bool, err error)
+
 	// ListPaymentMethod will return all payment methods for the project
 	ListPaymentMethod(proj *models.Project) (paymentMethods []types.PaymentMethod, err error)
 
@@ -36,6 +39,10 @@ func (s *NoopBillingManager) DeleteCustomer(proj *models.Project) (err error) {
 	return nil
 }
 
+func (s *NoopBillingManager) CheckPaymentEnabled(proj *models.Project) (paymentEnabled bool, err error) {
+	return false, nil
+}
+
 // ListPaymentMethod is a no-op
 func (s *NoopBillingManager) ListPaymentMethod(proj *models.Project) (paymentMethods []types.PaymentMethod, err error) {
 	return []types.PaymentMethod{}, nil

+ 13 - 0
internal/billing/stripe.go

@@ -56,6 +56,19 @@ func (s *StripeBillingManager) DeleteCustomer(proj *models.Project) (err error)
 	return nil
 }
 
+// CheckPaymentEnabled will return true if the project has a payment method added, false otherwise
+func (s *StripeBillingManager) CheckPaymentEnabled(proj *models.Project) (paymentEnabled bool, err error) {
+	stripe.Key = s.StripeSecretKey
+
+	params := &stripe.PaymentMethodListParams{
+		Customer: stripe.String(proj.BillingID),
+		Type:     stripe.String(string(stripe.PaymentMethodTypeCard)),
+	}
+	result := paymentmethod.List(params)
+
+	return result.Next(), nil
+}
+
 // ListPaymentMethod will return all payment methods for the project
 func (s *StripeBillingManager) ListPaymentMethod(proj *models.Project) (paymentMethods []types.PaymentMethod, err error) {
 	stripe.Key = s.StripeSecretKey