Prechádzať zdrojové kódy

Finish logic for creating referrals and granting rewards

Mauricio Araujo 2 rokov pred
rodič
commit
18315fe5ef

+ 128 - 0
api/server/handlers/billing/credits.go

@@ -0,0 +1,128 @@
+package billing
+
+import (
+	"net/http"
+
+	"github.com/porter-dev/porter/api/server/handlers"
+	"github.com/porter-dev/porter/api/server/shared"
+	"github.com/porter-dev/porter/api/server/shared/apierrors"
+	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/types"
+	"github.com/porter-dev/porter/internal/models"
+	"github.com/porter-dev/porter/internal/telemetry"
+)
+
+const (
+	// referralRewardRequirement is the number of referred users required to
+	// be granted a credits reward
+	referralRewardRequirement = 5
+	// defaultRewardAmountUSD is the default amount in USD rewarded to users
+	// who reach the reward requirement
+	defaultRewardAmountUSD = 20
+	// defaultPaidAmountUSD is the amount paid by the user to get the credits
+	// grant, if set to 0 it means they were free
+	defaultPaidAmountUSD = 0
+)
+
+// ListCreditsHandler is a handler for getting available credits
+type ListCreditsHandler struct {
+	handlers.PorterHandlerWriter
+}
+
+// NewListCreditsHandler will create a new ListCreditsHandler
+func NewListCreditsHandler(
+	config *config.Config,
+	writer shared.ResultWriter,
+) *ListCreditsHandler {
+	return &ListCreditsHandler{
+		PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
+	}
+}
+
+func (c *ListCreditsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-list-credits")
+	defer span.End()
+
+	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
+
+	if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) {
+		c.WriteResult(w, r, "")
+
+		telemetry.WithAttributes(span,
+			telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded},
+			telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)},
+		)
+		return
+	}
+
+	credits, err := c.Config().BillingManager.MetronomeClient.ListCustomerCredits(ctx, proj.UsageID)
+	if err != nil {
+		err := telemetry.Error(ctx, span, err, "error listing credits")
+		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		return
+	}
+
+	telemetry.WithAttributes(span,
+		telemetry.AttributeKV{Key: "metronome-enabled", Value: true},
+		telemetry.AttributeKV{Key: "usage-id", Value: proj.UsageID},
+	)
+
+	c.WriteResult(w, r, credits)
+}
+
+// ClaimReferralRewardHandler is a handler for granting credits
+type ClaimReferralRewardHandler struct {
+	handlers.PorterHandlerWriter
+}
+
+// NewClaimReferralReward will create a new GrantCreditsHandler
+func NewClaimReferralReward(
+	config *config.Config,
+	writer shared.ResultWriter,
+) *ClaimReferralRewardHandler {
+	return &ClaimReferralRewardHandler{
+		PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
+	}
+}
+
+func (c *ClaimReferralRewardHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-claim-credits-reward")
+	defer span.End()
+
+	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
+	user, _ := ctx.Value(types.UserScope).(*models.User)
+
+	if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) {
+		c.WriteResult(w, r, "")
+
+		telemetry.WithAttributes(span,
+			telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded},
+			telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)},
+		)
+		return
+	}
+
+	telemetry.WithAttributes(span,
+		telemetry.AttributeKV{Key: "metronome-enabled", Value: true},
+		telemetry.AttributeKV{Key: "usage-id", Value: proj.UsageID},
+		telemetry.AttributeKV{Key: "referral-code", Value: user.ReferralCode},
+		telemetry.AttributeKV{Key: "referral-reward-received", Value: user.ReferralRewardClaimed},
+	)
+
+	if !user.ReferralRewardClaimed {
+		err := c.Config().BillingManager.MetronomeClient.CreateCreditsGrant(ctx, proj.UsageID, defaultRewardAmountUSD, defaultPaidAmountUSD)
+		if err != nil {
+			c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+			return
+		}
+
+		user.ReferralRewardClaimed = true
+		_, err = c.Repo().User().UpdateUser(user)
+		if err != nil {
+			c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+			return
+		}
+	}
+
+	c.WriteResult(w, r, "")
+}

+ 0 - 46
api/server/handlers/billing/plan.go

@@ -58,52 +58,6 @@ func (c *ListPlansHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	c.WriteResult(w, r, plan)
 }
 
-// ListCreditsHandler is a handler for getting available credits
-type ListCreditsHandler struct {
-	handlers.PorterHandlerWriter
-}
-
-// NewListCreditsHandler will create a new ListCreditsHandler
-func NewListCreditsHandler(
-	config *config.Config,
-	writer shared.ResultWriter,
-) *ListCreditsHandler {
-	return &ListCreditsHandler{
-		PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
-	}
-}
-
-func (c *ListCreditsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "serve-list-credits")
-	defer span.End()
-
-	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
-
-	if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) {
-		c.WriteResult(w, r, "")
-
-		telemetry.WithAttributes(span,
-			telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded},
-			telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)},
-		)
-		return
-	}
-
-	credits, err := c.Config().BillingManager.MetronomeClient.ListCustomerCredits(ctx, proj.UsageID)
-	if err != nil {
-		err := telemetry.Error(ctx, span, err, "error listing credits")
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
-	telemetry.WithAttributes(span,
-		telemetry.AttributeKV{Key: "metronome-enabled", Value: true},
-		telemetry.AttributeKV{Key: "usage-id", Value: proj.UsageID},
-	)
-
-	c.WriteResult(w, r, credits)
-}
-
 // ListCustomerUsageHandler returns customer usage aggregations like CPU and RAM hours.
 type ListCustomerUsageHandler struct {
 	handlers.PorterHandlerReadWriter

+ 1 - 1
api/server/handlers/user/create.go

@@ -108,7 +108,7 @@ func (u *UserCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			ReferredUserID: user.ID,
 		}
 
-		_, err := u.Repo().Referral().CreateReferral(referral)
+		_, err = u.Repo().Referral().CreateReferral(referral)
 		if err != nil {
 			u.HandleAPIErrorNoWrite(w, r, apierrors.NewErrInternal(err))
 		}

+ 51 - 9
api/server/handlers/user/referrals.go

@@ -5,7 +5,11 @@ import (
 
 	"github.com/porter-dev/porter/api/server/handlers"
 	"github.com/porter-dev/porter/api/server/shared"
+	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/types"
+	"github.com/porter-dev/porter/internal/models"
+	"github.com/porter-dev/porter/internal/telemetry"
 )
 
 // ListUserReferralsHandler is a handler for getting a list of user referrals
@@ -24,16 +28,54 @@ func NewListUserReferralsHandler(
 }
 
 func (u *ListUserReferralsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	// ctx, span := telemetry.NewSpan(r.Context(), "serve-list-user-referrals")
-	// defer span.End()
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-list-user-referrals")
+	defer span.End()
 
-	// user, _ := ctx.Value(types.UserScope).(*models.User)
+	user, _ := ctx.Value(types.UserScope).(*models.User)
 
-	// referrals, err := u.Repo().Referral().ListReferralsByUserID(user.ID)
-	// if err != nil {
-	// 	u.HandleAPIError(w, r, err)
-	// 	return
-	// }
+	referralCount, err := u.Repo().Referral().GetReferralCountByUserID(user.ID)
+	if err != nil {
+		u.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+		return
+	}
+
+	referralResponse := struct {
+		ReferralCount int `json:"count"`
+	}{
+		ReferralCount: referralCount,
+	}
+
+	u.WriteResult(w, r, referralResponse)
+}
+
+// GetUserReferralDetailsHandler is a handler for getting a user's referral code
+type GetUserReferralDetailsHandler struct {
+	handlers.PorterHandlerWriter
+}
+
+// NewGetUserReferralDetailsHandler returns an instance of GetUserReferralCodeHandler
+func NewGetUserReferralDetailsHandler(
+	config *config.Config,
+	writer shared.ResultWriter,
+) *GetUserReferralDetailsHandler {
+	return &GetUserReferralDetailsHandler{
+		PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
+	}
+}
+
+func (u *GetUserReferralDetailsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-get-user-referral-details")
+	defer span.End()
+
+	user, _ := ctx.Value(types.UserScope).(*models.User)
+
+	referralCodeResponse := struct {
+		Code          string `json:"code"`
+		RewardClaimed bool   `json:"reward_claimed"`
+	}{
+		Code:          user.ReferralCode,
+		RewardClaimed: user.ReferralRewardClaimed,
+	}
 
-	u.WriteResult(w, r, "")
+	u.WriteResult(w, r, referralCodeResponse)
 }

+ 27 - 0
api/server/router/project.go

@@ -397,6 +397,33 @@ func getProjectRoutes(
 		Router:   r,
 	})
 
+	// GET /api/projects/{project_id}/billing/credits/claim_referral -> project.NewGetCreditsHandler
+	claimReferralRewardEndpoint := factory.NewAPIEndpoint(
+		&types.APIRequestMetadata{
+			Verb:   types.APIVerbCreate,
+			Method: types.HTTPVerbPost,
+			Path: &types.Path{
+				Parent:       basePath,
+				RelativePath: relPath + "/billing/credits/claim_referral",
+			},
+			Scopes: []types.PermissionScope{
+				types.UserScope,
+				types.ProjectScope,
+			},
+		},
+	)
+
+	claimReferralRewardHandler := billing.NewClaimReferralReward(
+		config,
+		factory.GetResultWriter(),
+	)
+
+	routes = append(routes, &router.Route{
+		Endpoint: claimReferralRewardEndpoint,
+		Handler:  claimReferralRewardHandler,
+		Router:   r,
+	})
+
 	// POST /api/projects/{project_id}/billing/usage -> project.NewListCustomerUsageHandler
 	listCustomerUsageEndpoint := factory.NewAPIEndpoint(
 		&types.APIRequestMetadata{

+ 26 - 2
api/server/router/user.go

@@ -472,14 +472,14 @@ func getUserRoutes(
 		Router:   r,
 	})
 
-	// GET /api/users/referrals -> user.NewListUserReferralsHandler
+	// GET /api/referrals -> user.NewListUserReferralsHandler
 	listReferralsEndpoint := factory.NewAPIEndpoint(
 		&types.APIRequestMetadata{
 			Verb:   types.APIVerbGet,
 			Method: types.HTTPVerbGet,
 			Path: &types.Path{
 				Parent:       basePath,
-				RelativePath: "/users/referrals",
+				RelativePath: "/referrals",
 			},
 			Scopes: []types.PermissionScope{types.UserScope},
 		},
@@ -496,5 +496,29 @@ func getUserRoutes(
 		Router:   r,
 	})
 
+	// GET /api/referrals/details -> user.NewGetUserReferralDetailsHandler
+	getReferralDetailsEndpoint := factory.NewAPIEndpoint(
+		&types.APIRequestMetadata{
+			Verb:   types.APIVerbGet,
+			Method: types.HTTPVerbGet,
+			Path: &types.Path{
+				Parent:       basePath,
+				RelativePath: "/referrals/details",
+			},
+			Scopes: []types.PermissionScope{types.UserScope},
+		},
+	)
+
+	getReferralDetailsHandler := user.NewGetUserReferralDetailsHandler(
+		config,
+		factory.GetResultWriter(),
+	)
+
+	routes = append(routes, &router.Route{
+		Endpoint: getReferralDetailsEndpoint,
+		Handler:  getReferralDetailsHandler,
+		Router:   r,
+	})
+
 	return routes
 }

+ 18 - 24
api/types/billing_metronome.go

@@ -55,6 +55,18 @@ type EndCustomerPlanRequest struct {
 	VoidStripeInvoices bool `json:"void_stripe_invoices"`
 }
 
+// CreateCreditsGrantRequest is the request to create a credit grant for a customer
+type CreateCreditsGrantRequest struct {
+	// CustomerID is the id of the customer
+	CustomerID    uuid.UUID   `json:"customer_id"`
+	UniquenessKey string      `json:"uniqueness_key"`
+	GrantAmount   GrantAmount `json:"grant_amount"`
+	PaidAmount    PaidAmount  `json:"paid_amount"`
+	Name          string      `json:"name"`
+	ExpiresAt     string      `json:"expires_at"`
+	Priority      int         `json:"priority"`
+}
+
 // ListCreditGrantsRequest is the request to list a user's credit grants. Note that only one of
 // CreditTypeIDs, CustomerIDs, or CreditGrantIDs must be specified.
 type ListCreditGrantsRequest struct {
@@ -73,18 +85,6 @@ type ListCreditGrantsResponse struct {
 	GrantedCredits   float64 `json:"granted_credits"`
 }
 
-// EmbeddableDashboardRequest requests an embeddable customer dashboard to Metronome
-type EmbeddableDashboardRequest struct {
-	// CustomerID is the id of the customer
-	CustomerID uuid.UUID `json:"customer_id,omitempty"`
-	// DashboardType is the type of dashboard to retrieve
-	DashboardType string `json:"dashboard"`
-	// Options are optional dashboard specific options
-	Options []DashboardOption `json:"dashboard_options,omitempty"`
-	//  ColorOverrides is an optional list of colors to override
-	ColorOverrides []ColorOverride `json:"color_overrides,omitempty"`
-}
-
 // ListCustomerUsageRequest is the request to list usage for a customer
 type ListCustomerUsageRequest struct {
 	CustomerID       uuid.UUID `json:"customer_id"`
@@ -144,6 +144,12 @@ type GrantAmount struct {
 	CreditType CreditType `json:"credit_type"`
 }
 
+// PaidAmount represents the amount paid by the customer
+type PaidAmount struct {
+	Amount       float64   `json:"amount"`
+	CreditTypeID uuid.UUID `json:"credit_type_id"`
+}
+
 // Balance represents the effective balance of the grant as of the end of the customer's
 // current billing period.
 type Balance struct {
@@ -166,18 +172,6 @@ type CreditGrant struct {
 	ExpiresAt   string      `json:"expires_at"`
 }
 
-// DashboardOption are optional dashboard specific options
-type DashboardOption struct {
-	Key   string `json:"key"`
-	Value string `json:"value"`
-}
-
-// ColorOverride is an optional list of colors to override
-type ColorOverride struct {
-	Name  string `json:"name"`
-	Value string `json:"value"`
-}
-
 // BillingEvent represents a Metronome billing event.
 type BillingEvent struct {
 	CustomerID    string                 `json:"customer_id"`

+ 3 - 0
api/types/user.go

@@ -10,6 +10,9 @@ type User struct {
 
 	// ReferralCode is a unique code that can be shared to referr other users to Porter
 	ReferralCode string `json:"referral_code"`
+	// ReferralRewardClaimed indicates if the user has already received a credits reward
+	// for referring users
+	ReferralRewardClaimed bool `json:"referral_reward_received"`
 }
 
 type CreateUserRequest struct {

+ 10 - 0
dashboard/src/lib/billing/types.tsx

@@ -50,3 +50,13 @@ export const CreditGrantsValidator = z.object({
 });
 
 export const ClientSecretResponse = z.string();
+
+export type ReferralDetails = z.infer<typeof ReferralDetailsValidator>;
+export const ReferralDetailsValidator = z.object({
+  code: z.string(),
+  reward_claimed: z.boolean(),
+}).nullable();
+
+export const ReferralsValidator = z.object({
+  count: z.number(),
+}).nullable();

+ 85 - 14
dashboard/src/lib/hooks/useStripe.tsx

@@ -1,5 +1,5 @@
 import { useContext, useState } from "react";
-import { useQuery, type UseQueryResult } from "@tanstack/react-query";
+import { useQuery, useMutation, type UseQueryResult } from "@tanstack/react-query";
 import { z } from "zod";
 
 import {
@@ -13,6 +13,9 @@ import {
   type PaymentMethod,
   type PaymentMethodList,
   type UsageList,
+  ReferralsValidator,
+  ReferralDetailsValidator,
+  ReferralDetails
 } from "lib/billing/types";
 
 import api from "shared/api";
@@ -60,6 +63,14 @@ type TGetUsage = {
   usage: UsageList | null;
 };
 
+type TGetReferralDetails = {
+  referralDetails: ReferralDetails
+};
+
+type TGetReferrals = {
+  referralsCount: number | null;
+};
+
 export const usePaymentMethods = (): TUsePaymentMethod => {
   const { currentProject } = useContext(Context);
 
@@ -368,36 +379,96 @@ export const useCustomerUsage = (
   };
 };
 
-export const useReferrals = (): TGetPlan => {
-  const { currentProject, user } = useContext(Context);
+export const useReferrals = (): TGetReferrals => {
+  const { currentProject } = useContext(Context);
 
-  // Fetch current plan
-  const planReq = useQuery(
-    ["getReferrals", user?.id],
-    async (): Promise<Plan | null> => {
-      if (!currentProject?.billing_enabled) {
+  // Fetch referrals count
+  const referralsReq = useQuery(
+    ["getReferrals", currentProject?.id],
+    async (): Promise<number | null> => {
+      if (!currentProject?.metronome_enabled) {
         return null;
       }
 
-      if (!user?.id) {
+      try {
+        const res = await api.getReferrals(
+          "<token>",
+          {},
+          {}
+        );
+
+        const referrals = ReferralsValidator.parse(res.data);
+        return referrals?.count ?? null;
+      } catch (error) {
+        return null
+      }
+    });
+
+  return {
+    referralsCount: referralsReq.data ?? null,
+  };
+};
+
+export const useReferralDetails = (): TGetReferralDetails => {
+  const { currentProject } = useContext(Context);
+
+  // Fetch user's referral code
+  const referralsReq = useQuery(
+    ["getReferralDetails", currentProject?.id],
+    async (): Promise<ReferralDetails | null> => {
+      if (!currentProject?.metronome_enabled) {
+        return null;
+      }
+
+      if (!currentProject?.id || currentProject.id === -1) {
         return null;
       }
 
       try {
-        const res = await api.getReferrals(
+        const res = await api.getReferralDetails(
           "<token>",
           {},
-          { user_id: user.id }
+          {}
         );
 
-        const referrals = PlanValidator.parse(res.data);
-        return referrals;
+        const referraldetails = ReferralDetailsValidator.parse(res.data);
+        return referraldetails;
       } catch (error) {
         return null
       }
     });
 
   return {
-    plan: planReq.data ?? null,
+    referralDetails: referralsReq.data ?? null,
   };
+};
+
+export const useClaimReferralReward = (): (() => void) => {
+  const { currentProject } = useContext(Context);
+
+  // Apply credits reward to this project
+  const referralsReq = useMutation(
+    ["claimReferralReward", currentProject?.id],
+    async (): Promise<void> => {
+      if (!currentProject?.metronome_enabled) {
+        return;
+      }
+
+      if (!currentProject?.id || currentProject.id === -1) {
+        return;
+      }
+
+      try {
+        await api.claimReferralReward(
+          "<token>",
+          {},
+          { project_id: currentProject?.id }
+        );
+      } catch (error) {
+        return;
+      }
+    });
+
+  // Return a function that can be called to execute the mutation
+  return () => referralsReq.mutate();
 };

+ 55 - 6
dashboard/src/main/home/project-settings/ReferralsPage.tsx

@@ -1,11 +1,62 @@
-import React, { useContext } from "react";
+import React from "react";
 import Spacer from "components/porter/Spacer";
 import Text from "components/porter/Text";
-import { Context } from "shared/Context";
-import Fieldset from "components/porter/Fieldset";
+import { useClaimReferralReward, useReferralDetails, useReferrals } from "lib/hooks/useStripe";
+import Button from "components/porter/Button";
 
 function ReferralsPage(): JSX.Element {
+    const referralRewardRequirement = 5;
+    const { referralDetails } = useReferralDetails();
+    const { referralsCount } = useReferrals();
+    const claimReferralReward = useClaimReferralReward();
 
+    const eligibleForReward = (): boolean => {
+        if (referralsCount === null) {
+            return false;
+        }
+
+        return referralsCount >= referralRewardRequirement;
+    }
+
+    const claimReward = (): void => {
+        claimReferralReward();
+    }
+
+    const displayReferral = (): JSX.Element => {
+        if (referralDetails === null || referralsCount === null) {
+            return <></>
+        }
+
+        if (!eligibleForReward()) {
+            return (
+                <>
+                    <Text>
+                        Refer {referralRewardRequirement - referralsCount} more people to earn a reward.
+                    </Text>
+                    <Spacer y={1} />
+                </>
+            )
+        }
+
+        if (referralDetails?.reward_claimed) {
+            return (
+                <>
+                    <Text>
+                        You have already claimed a reward for referring people to Porter.
+                    </Text>
+                    <Spacer y={1} />
+                </>
+            )
+        }
+
+        return (
+            <>
+                <Text>You are elegible for claiming a reward on this project.</Text>
+                <Spacer y={0.5} />
+                <Button onClick={claimReward}>Claim Reward</Button>
+            </>
+        )
+    }
 
     return (
         <>
@@ -15,9 +66,7 @@ function ReferralsPage(): JSX.Element {
                 Refer people to Porter to earn credits.
             </Text>
             <Spacer y={1} />
-            <Text>
-                Your referral code is {user?.referralCode}
-            </Text>
+            {displayReferral()}
             <Spacer y={1} />
         </>
     )

+ 1 - 1
dashboard/src/shared/Context.tsx

@@ -139,7 +139,7 @@ class ContextProvider extends Component<PropsType, StateType> {
     user: null,
     setUser: (userId: number, email: string) => {
       this.setState({
-        user: { userId, email, isPorterUser: email?.endsWith("@porter.run"), referralCode: referralCode },
+        user: { userId, email, isPorterUser: email?.endsWith("@porter.run") },
       });
       if (window.intercomSettings) {
         window.intercomSettings["Porter User ID"] = userId;

+ 23 - 5
dashboard/src/shared/api.tsx

@@ -3579,15 +3579,31 @@ const deletePaymentMethod = baseApi<
 
 const getReferrals = baseApi<
   {},
-  {
-    user_id?: number;
-  }
+  {}
+>(
+  "GET",
+  () =>
+    `/api/referrals`
+);
+
+const getReferralDetails = baseApi<
+  {},
+  {}
 >(
   "GET",
-  ({ user_id }) =>
-    `/api/users/${user_id}/referrals`
+  () =>
+    `/api/referrals/details`
 );
 
+const claimReferralReward = baseApi<
+  {},
+  {
+    project_id?: number;
+  }
+>(
+  "POST",
+  ({ project_id }) => `/api/projects/${project_id}/billing/credits/claim_referral`
+);
 const getGithubStatus = baseApi<{}, {}>("GET", ({ }) => `/api/status/github`);
 
 const createSecretAndOpenGitHubPullRequest = baseApi<
@@ -3982,6 +3998,8 @@ export default {
   setDefaultPaymentMethod,
   deletePaymentMethod,
   getReferrals,
+  getReferralDetails,
+  claimReferralReward,
 
   // STATUS
   getGithubStatus,

+ 38 - 0
internal/billing/metronome.go

@@ -242,6 +242,44 @@ func (m MetronomeClient) ListCustomerCredits(ctx context.Context, customerID uui
 	return response, nil
 }
 
+func (m MetronomeClient) CreateCreditsGrant(ctx context.Context, customerID uuid.UUID, grantAmount float64, paidAmount float64) (err error) {
+	ctx, span := telemetry.NewSpan(ctx, "create-credits-grant")
+	defer span.End()
+
+	if customerID == uuid.Nil {
+		return telemetry.Error(ctx, span, err, "customer id empty")
+	}
+
+	path := "credits/createGrant"
+
+	req := types.CreateCreditsGrantRequest{
+		CustomerID:    customerID,
+		UniquenessKey: "porter-credits",
+		GrantAmount: types.GrantAmount{
+			Amount:     grantAmount,
+			CreditType: types.CreditType{},
+		},
+		PaidAmount: types.PaidAmount{
+			Amount:       paidAmount,
+			CreditTypeID: uuid.UUID{},
+		},
+		Name:      "Porter Credits",
+		ExpiresAt: "", // never expires
+		Priority:  1,
+	}
+
+	var result struct {
+		Data []types.CreditGrant `json:"data"`
+	}
+
+	_, err = m.do(http.MethodPost, path, req, &result)
+	if err != nil {
+		return telemetry.Error(ctx, span, err, "failed to create credits grant")
+	}
+
+	return nil
+}
+
 // ListCustomerUsage will return the aggregated usage for a customer
 func (m MetronomeClient) ListCustomerUsage(ctx context.Context, customerID uuid.UUID, startingOn string, endingBefore string, windowsSize string, currentPeriod bool) (usage []types.Usage, err error) {
 	ctx, span := telemetry.NewSpan(ctx, "list-customer-usage")

+ 11 - 7
internal/models/user.go

@@ -26,6 +26,9 @@ type User struct {
 
 	// ReferralCode is a unique code that can be shared to referr other users to Porter
 	ReferralCode string
+	// ReferralRewardClaimed indicates if the user has already received a credits reward
+	// for referring users
+	ReferralRewardClaimed bool
 
 	Referrals []Referral `json:"referrals"`
 }
@@ -33,12 +36,13 @@ type User struct {
 // ToUserType generates an external types.User to be shared over REST
 func (u *User) ToUserType() *types.User {
 	return &types.User{
-		ID:            u.ID,
-		Email:         u.Email,
-		EmailVerified: u.EmailVerified,
-		FirstName:     u.FirstName,
-		LastName:      u.LastName,
-		CompanyName:   u.CompanyName,
-		ReferralCode:  u.ReferralCode,
+		ID:                    u.ID,
+		Email:                 u.Email,
+		EmailVerified:         u.EmailVerified,
+		FirstName:             u.FirstName,
+		LastName:              u.LastName,
+		CompanyName:           u.CompanyName,
+		ReferralCode:          u.ReferralCode,
+		ReferralRewardClaimed: u.ReferralRewardClaimed,
 	}
 }

+ 8 - 0
internal/repository/gorm/referrals.go

@@ -37,3 +37,11 @@ func (repo *ReferralRepository) CreateReferral(referral *models.Referral) (*mode
 
 	return referral, nil
 }
+
+func (repo *ReferralRepository) GetReferralCountByUserID(userID uint) (int, error) {
+	referrals := []models.Referral{}
+	if err := repo.db.Where("user_id = ?", userID).Find(&referrals).Error; err != nil {
+		return 0, err
+	}
+	return len(referrals), nil
+}

+ 1 - 0
internal/repository/referral.go

@@ -7,6 +7,7 @@ import (
 // ReferralRepository represents the set of queries on the Referral model
 type ReferralRepository interface {
 	CreateReferral(referral *models.Referral) (*models.Referral, error)
+	GetReferralCountByUserID(userID uint) (int, error)
 	// ReadReferral(referralID uint) (*models.Referral, error)
 	// ReadReferralByUserID(userID, referralID string) (*models.Referral, error)
 	// ListReferralsByUserID(userID uint) ([]*models.Referral, error)