Przeglądaj źródła

Per project referral

Mauricio Araujo 2 lat temu
rodzic
commit
75c2941cbd

+ 55 - 0
api/server/handlers/billing/create.go

@@ -1,8 +1,10 @@
 package billing
 
 import (
+	"context"
 	"fmt"
 	"net/http"
+	"time"
 
 	"github.com/porter-dev/porter/api/server/handlers"
 	"github.com/porter-dev/porter/api/server/shared"
@@ -15,6 +17,15 @@ import (
 	"github.com/porter-dev/porter/internal/telemetry"
 )
 
+const (
+	// defaultRewardAmountCents is the default amount in USD cents rewarded to users
+	// who successfully refer a new user
+	defaultRewardAmountCents = 1000
+	// defaultPaidAmountCents is the amount paid by the user to get the credits
+	// grant, if set to 0 it means they are free
+	defaultPaidAmountCents = 0
+)
+
 // CreateBillingHandler is a handler for creating payment methods
 type CreateBillingHandler struct {
 	handlers.PorterHandlerWriter
@@ -41,6 +52,7 @@ func (c *CreateBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 	defer span.End()
 
 	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
+	user, _ := ctx.Value(types.UserScope).(*models.User)
 
 	clientSecret, err := c.Config().BillingManager.StripeClient.CreatePaymentMethod(ctx, proj.BillingID)
 	if err != nil {
@@ -54,6 +66,15 @@ func (c *CreateBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 		telemetry.AttributeKV{Key: "customer-id", Value: proj.BillingID},
 	)
 
+	if proj.EnableSandbox {
+		// Grant a reward to the project that referred this user after linking a payment method
+		err = c.grantRewardIfReferral(ctx, user.ID)
+		if err != nil {
+			// Only log the error in case the reward grant fails, but don't return an error to the fe
+			telemetry.Error(ctx, span, err, "error granting credits reward")
+		}
+	}
+
 	c.WriteResult(w, r, clientSecret)
 }
 
@@ -104,3 +125,37 @@ func (c *SetDefaultBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ
 
 	c.WriteResult(w, r, "")
 }
+
+func (c *CreateBillingHandler) grantRewardIfReferral(ctx context.Context, referredUserID uint) (err error) {
+	ctx, span := telemetry.NewSpan(ctx, "grant-referral-reward")
+	defer span.End()
+
+	referral, err := c.Repo().Referral().GetReferralByReferredID(referredUserID)
+	if err != nil {
+		return telemetry.Error(ctx, span, err, "failed to find referral by referred id")
+	}
+
+	referrerProject, err := c.Repo().Project().ReadProject(referral.ProjectID)
+	if err != nil {
+		return telemetry.Error(ctx, span, err, "failed to find referrer project")
+	}
+
+	if referral != nil && referral.Status != models.ReferralStatusCompleted {
+		// Metronome requires an expiration to be passed in, so we set it to 5 years which in
+		// practice will mean the credits will most likely run out before expiring
+		expiresAt := time.Now().AddDate(5, 0, 0).Format(time.RFC3339)
+		reason := "Referral reward"
+		err := c.Config().BillingManager.MetronomeClient.CreateCreditsGrant(ctx, referrerProject.UsageID, reason, defaultRewardAmountCents, defaultPaidAmountCents, expiresAt)
+		if err != nil {
+			return telemetry.Error(ctx, span, err, "failed to grand credits reward")
+		}
+
+		referral.Status = models.ReferralStatusCompleted
+		_, err = c.Repo().Referral().UpdateReferral(referral)
+		if err != nil {
+			return telemetry.Error(ctx, span, err, "error while updating referral")
+		}
+	}
+
+	return nil
+}

+ 0 - 80
api/server/handlers/billing/credits.go

@@ -2,7 +2,6 @@ package billing
 
 import (
 	"net/http"
-	"time"
 
 	"github.com/porter-dev/porter/api/server/handlers"
 	"github.com/porter-dev/porter/api/server/shared"
@@ -13,18 +12,6 @@ import (
 	"github.com/porter-dev/porter/internal/telemetry"
 )
 
-const (
-	// referralRewardRequirement is the number of referred users required to
-	// be granted a credits reward
-	referralRewardRequirement = 5
-	// defaultRewardAmountUSD is the default amount in USD rewarded to users
-	// who reach the reward requirement
-	defaultRewardAmountCents = 2000
-	// defaultPaidAmountUSD is the amount paid by the user to get the credits
-	// grant, if set to 0 it means they were free
-	defaultPaidAmountCents = 0
-)
-
 // ListCreditsHandler is a handler for getting available credits
 type ListCreditsHandler struct {
 	handlers.PorterHandlerWriter
@@ -70,70 +57,3 @@ func (c *ListCreditsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 	c.WriteResult(w, r, credits)
 }
-
-// ClaimReferralRewardHandler is a handler for granting credits
-type ClaimReferralRewardHandler struct {
-	handlers.PorterHandlerWriter
-}
-
-// NewClaimReferralReward will create a new GrantCreditsHandler
-func NewClaimReferralReward(
-	config *config.Config,
-	writer shared.ResultWriter,
-) *ClaimReferralRewardHandler {
-	return &ClaimReferralRewardHandler{
-		PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
-	}
-}
-
-func (c *ClaimReferralRewardHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "serve-claim-credits-reward")
-	defer span.End()
-
-	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
-	user, _ := ctx.Value(types.UserScope).(*models.User)
-
-	if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) {
-		c.WriteResult(w, r, "")
-
-		telemetry.WithAttributes(span,
-			telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded},
-			telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)},
-		)
-		return
-	}
-
-	telemetry.WithAttributes(span,
-		telemetry.AttributeKV{Key: "metronome-enabled", Value: true},
-		telemetry.AttributeKV{Key: "usage-id", Value: proj.UsageID},
-		telemetry.AttributeKV{Key: "referral-code", Value: user.ReferralCode},
-		telemetry.AttributeKV{Key: "referral-reward-received", Value: user.ReferralRewardClaimed},
-	)
-
-	// Check if the user is eligible for the referral reward
-	referralCount, err := c.Repo().Referral().GetReferralCountByUserID(user.ID)
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
-		return
-	}
-
-	if !user.ReferralRewardClaimed && referralCount >= referralRewardRequirement {
-		// Metronome requires an expiration to be passed in, so we set it to 5 years which in
-		// practice will mean the credits will run out before expiring
-		expiresAt := time.Now().AddDate(5, 0, 0).Format(time.RFC3339)
-		err := c.Config().BillingManager.MetronomeClient.CreateCreditsGrant(ctx, proj.UsageID, defaultRewardAmountCents, defaultPaidAmountCents, expiresAt)
-		if err != nil {
-			c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
-			return
-		}
-
-		user.ReferralRewardClaimed = true
-		_, err = c.Repo().User().UpdateUser(user)
-		if err != nil {
-			c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
-			return
-		}
-	}
-
-	c.WriteResult(w, r, "")
-}

+ 3 - 0
api/server/handlers/project/create.go

@@ -67,6 +67,9 @@ func (p *ProjectCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 
 	if p.Config().ServerConf.EnableSandbox {
 		step = types.StepCleanUp
+
+		// Generate referral code for porter cloud projects
+		proj.ReferralCode = models.NewReferralCode()
 	}
 
 	// create onboarding flow set to the first step. Read in env var

+ 79 - 0
api/server/handlers/project/referrals.go

@@ -0,0 +1,79 @@
+package project
+
+import (
+	"net/http"
+
+	"github.com/google/uuid"
+	"github.com/porter-dev/porter/api/server/handlers"
+	"github.com/porter-dev/porter/api/server/shared"
+	"github.com/porter-dev/porter/api/server/shared/apierrors"
+	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/types"
+	"github.com/porter-dev/porter/internal/models"
+	"github.com/porter-dev/porter/internal/telemetry"
+)
+
+// GetProjectReferralDetailsHandler is a handler for getting a project's referral code
+type GetProjectReferralDetailsHandler struct {
+	handlers.PorterHandlerWriter
+}
+
+// NewGetProjectReferralDetailsHandler returns an instance of GetProjectReferralDetailsHandler
+func NewGetProjectReferralDetailsHandler(
+	config *config.Config,
+	writer shared.ResultWriter,
+) *GetProjectReferralDetailsHandler {
+	return &GetProjectReferralDetailsHandler{
+		PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
+	}
+}
+
+func (c *GetProjectReferralDetailsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-get-project-referral-details")
+	defer span.End()
+
+	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
+
+	if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) ||
+		proj.UsageID == uuid.Nil || proj.EnableSandbox {
+		c.WriteResult(w, r, "")
+
+		telemetry.WithAttributes(span,
+			telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded},
+			telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)},
+		)
+		return
+	}
+
+	if proj.ReferralCode == "" {
+		telemetry.WithAttributes(span,
+			telemetry.AttributeKV{Key: "referral-code-exists", Value: false},
+		)
+
+		// Generate referral code for project if not present
+		proj.ReferralCode = models.NewReferralCode()
+		_, err := c.Repo().Project().UpdateProject(proj)
+		if err != nil {
+			err := telemetry.Error(ctx, span, err, "error updating project")
+			c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+			return
+		}
+	}
+
+	referralCount, err := c.Repo().Referral().CountReferralsByProjectID(proj.ID, models.ReferralStatusCompleted)
+	if err != nil {
+		err := telemetry.Error(ctx, span, err, "error listing referrals by project id")
+		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		return
+	}
+
+	referralCodeResponse := struct {
+		Code          string `json:"code"`
+		ReferralCount int64  `json:"referral_count"`
+	}{
+		Code:          proj.ReferralCode,
+		ReferralCount: referralCount,
+	}
+
+	c.WriteResult(w, r, referralCodeResponse)
+}

+ 1 - 3
api/server/handlers/user/create.go

@@ -70,9 +70,6 @@ func (u *UserCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 	user.Password = string(hashedPw)
 
-	// Generate referral code for user
-	user.ReferralCode = models.NewReferralCode()
-
 	// write the user to the db
 	user, err = u.Repo().User().CreateUser(user)
 	if err != nil {
@@ -106,6 +103,7 @@ func (u *UserCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		referral := &models.Referral{
 			Code:           request.ReferredBy,
 			ReferredUserID: user.ID,
+			Status:         models.ReferralStatusSignedUp,
 		}
 
 		_, err = u.Repo().Referral().CreateReferral(referral)

+ 2 - 56
api/server/handlers/user/create_test.go

@@ -1,7 +1,6 @@
 package user_test
 
 import (
-	"encoding/json"
 	"net/http"
 	"testing"
 
@@ -10,7 +9,6 @@ import (
 	"github.com/porter-dev/porter/api/server/shared/apitest"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/repository/test"
-	"github.com/stretchr/testify/assert"
 )
 
 func TestCreateUserSuccessful(t *testing.T) {
@@ -40,14 +38,7 @@ func TestCreateUserSuccessful(t *testing.T) {
 	// Use a struct that is the same as types.User but without the
 	// referral fields. This is because the referral code is randomly
 	// generated and is tested separately.
-	expUser := &struct {
-		ID            uint   `json:"id"`
-		Email         string `json:"email"`
-		EmailVerified bool   `json:"email_verified"`
-		FirstName     string `json:"first_name"`
-		LastName      string `json:"last_name"`
-		CompanyName   string `json:"company_name"`
-	}{
+	expUser :=  &types.CreateUserResponse{
 		ID:            1,
 		FirstName:     "Mister",
 		LastName:      "Porter",
@@ -56,14 +47,7 @@ func TestCreateUserSuccessful(t *testing.T) {
 		EmailVerified: false,
 	}
 
-	gotUser := &struct {
-		ID            uint   `json:"id"`
-		Email         string `json:"email"`
-		EmailVerified bool   `json:"email_verified"`
-		FirstName     string `json:"first_name"`
-		LastName      string `json:"last_name"`
-		CompanyName   string `json:"company_name"`
-	}{}
+	gotUser :=  &types.CreateUserResponse{}
 
 	apitest.AssertResponseExpected(t, rr, expUser, gotUser)
 }
@@ -210,41 +194,3 @@ func TestFailingCreateSessionMethod(t *testing.T) {
 
 	apitest.AssertResponseInternalServerError(t, rr)
 }
-
-func TestCreateUserReferralCode(t *testing.T) {
-	req, rr := apitest.GetRequestAndRecorder(
-		t,
-		string(types.HTTPVerbPost),
-		"/api/users",
-		&types.CreateUserRequest{
-			FirstName:   "Mister",
-			LastName:    "Porter",
-			CompanyName: "Porter Technologies, Inc.",
-			Email:       "mrp@porter.run",
-			Password:    "somepassword",
-		},
-	)
-
-	config := apitest.LoadConfig(t)
-
-	handler := user.NewUserCreateHandler(
-		config,
-		shared.NewDefaultRequestDecoderValidator(config.Logger, config.Alerter),
-		shared.NewDefaultResultWriter(config.Logger, config.Alerter),
-	)
-
-	handler.ServeHTTP(rr, req)
-	gotUser := &types.CreateUserResponse{}
-
-	// apitest.AssertResponseExpected(t, rr, expUser, gotUser)
-	err := json.NewDecoder(rr.Body).Decode(gotUser)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	// This is the default lenth of a shortuuid
-	desiredLenth := 22
-	assert.NotEmpty(t, gotUser.ReferralCode, "referral code should not be empty")
-	assert.Len(t, gotUser.ReferralCode, desiredLenth, "referral code should be 22 characters long")
-	assert.Equal(t, gotUser.ReferralRewardClaimed, false, "referral reward claimed should be false for new user")
-}

+ 0 - 3
api/server/handlers/user/github_callback.go

@@ -146,9 +146,6 @@ func upsertUserFromToken(config *config.Config, tok *oauth2.Token) (*models.User
 				GithubUserID:  githubUser.GetID(),
 			}
 
-			// Generate referral code for user
-			user.ReferralCode = models.NewReferralCode()
-
 			user, err = config.Repo.User().CreateUser(user)
 			if err != nil {
 				return nil, err

+ 0 - 3
api/server/handlers/user/google_callback.go

@@ -132,9 +132,6 @@ func upsertGoogleUserFromToken(config *config.Config, tok *oauth2.Token) (*model
 				GoogleUserID:  gInfo.Sub,
 			}
 
-			// Generate referral code for user
-			user.ReferralCode = models.NewReferralCode()
-
 			user, err = config.Repo.User().CreateUser(user)
 			if err != nil {
 				return nil, err

+ 0 - 81
api/server/handlers/user/referrals.go

@@ -1,81 +0,0 @@
-package user
-
-import (
-	"net/http"
-
-	"github.com/porter-dev/porter/api/server/handlers"
-	"github.com/porter-dev/porter/api/server/shared"
-	"github.com/porter-dev/porter/api/server/shared/apierrors"
-	"github.com/porter-dev/porter/api/server/shared/config"
-	"github.com/porter-dev/porter/api/types"
-	"github.com/porter-dev/porter/internal/models"
-	"github.com/porter-dev/porter/internal/telemetry"
-)
-
-// ListUserReferralsHandler is a handler for getting a list of user referrals
-type ListUserReferralsHandler struct {
-	handlers.PorterHandlerWriter
-}
-
-// NewListUserReferralsHandler returns an instance of ListUserReferralsHandler
-func NewListUserReferralsHandler(
-	config *config.Config,
-	writer shared.ResultWriter,
-) *ListUserReferralsHandler {
-	return &ListUserReferralsHandler{
-		PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
-	}
-}
-
-func (u *ListUserReferralsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "serve-list-user-referrals")
-	defer span.End()
-
-	user, _ := ctx.Value(types.UserScope).(*models.User)
-
-	referralCount, err := u.Repo().Referral().GetReferralCountByUserID(user.ID)
-	if err != nil {
-		u.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
-		return
-	}
-
-	referralResponse := struct {
-		ReferralCount int `json:"count"`
-	}{
-		ReferralCount: referralCount,
-	}
-
-	u.WriteResult(w, r, referralResponse)
-}
-
-// GetUserReferralDetailsHandler is a handler for getting a user's referral code
-type GetUserReferralDetailsHandler struct {
-	handlers.PorterHandlerWriter
-}
-
-// NewGetUserReferralDetailsHandler returns an instance of GetUserReferralCodeHandler
-func NewGetUserReferralDetailsHandler(
-	config *config.Config,
-	writer shared.ResultWriter,
-) *GetUserReferralDetailsHandler {
-	return &GetUserReferralDetailsHandler{
-		PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
-	}
-}
-
-func (u *GetUserReferralDetailsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx, span := telemetry.NewSpan(r.Context(), "serve-get-user-referral-details")
-	defer span.End()
-
-	user, _ := ctx.Value(types.UserScope).(*models.User)
-
-	referralCodeResponse := struct {
-		Code          string `json:"code"`
-		RewardClaimed bool   `json:"reward_claimed"`
-	}{
-		Code:          user.ReferralCode,
-		RewardClaimed: user.ReferralRewardClaimed,
-	}
-
-	u.WriteResult(w, r, referralCodeResponse)
-}

+ 8 - 8
api/server/router/project.go

@@ -397,14 +397,14 @@ func getProjectRoutes(
 		Router:   r,
 	})
 
-	// GET /api/projects/{project_id}/billing/credits/claim_referral -> project.NewGetCreditsHandler
-	claimReferralRewardEndpoint := factory.NewAPIEndpoint(
+	// GET /api/projects/{project_id}/referrals/details -> user.NewGetUserReferralDetailsHandler
+	getReferralDetailsEndpoint := factory.NewAPIEndpoint(
 		&types.APIRequestMetadata{
-			Verb:   types.APIVerbCreate,
-			Method: types.HTTPVerbPost,
+			Verb:   types.APIVerbGet,
+			Method: types.HTTPVerbGet,
 			Path: &types.Path{
 				Parent:       basePath,
-				RelativePath: relPath + "/billing/credits/claim_referral",
+				RelativePath: relPath + "/referrals/details",
 			},
 			Scopes: []types.PermissionScope{
 				types.UserScope,
@@ -413,14 +413,14 @@ func getProjectRoutes(
 		},
 	)
 
-	claimReferralRewardHandler := billing.NewClaimReferralReward(
+	getReferralDetailsHandler := project.NewGetProjectReferralDetailsHandler(
 		config,
 		factory.GetResultWriter(),
 	)
 
 	routes = append(routes, &router.Route{
-		Endpoint: claimReferralRewardEndpoint,
-		Handler:  claimReferralRewardHandler,
+		Endpoint: getReferralDetailsEndpoint,
+		Handler:  getReferralDetailsHandler,
 		Router:   r,
 	})
 

+ 0 - 48
api/server/router/user.go

@@ -472,53 +472,5 @@ func getUserRoutes(
 		Router:   r,
 	})
 
-	// GET /api/referrals -> user.NewListUserReferralsHandler
-	listReferralsEndpoint := factory.NewAPIEndpoint(
-		&types.APIRequestMetadata{
-			Verb:   types.APIVerbGet,
-			Method: types.HTTPVerbGet,
-			Path: &types.Path{
-				Parent:       basePath,
-				RelativePath: "/referrals",
-			},
-			Scopes: []types.PermissionScope{types.UserScope},
-		},
-	)
-
-	listReferralsHandler := user.NewListUserReferralsHandler(
-		config,
-		factory.GetResultWriter(),
-	)
-
-	routes = append(routes, &router.Route{
-		Endpoint: listReferralsEndpoint,
-		Handler:  listReferralsHandler,
-		Router:   r,
-	})
-
-	// GET /api/referrals/details -> user.NewGetUserReferralDetailsHandler
-	getReferralDetailsEndpoint := factory.NewAPIEndpoint(
-		&types.APIRequestMetadata{
-			Verb:   types.APIVerbGet,
-			Method: types.HTTPVerbGet,
-			Path: &types.Path{
-				Parent:       basePath,
-				RelativePath: "/referrals/details",
-			},
-			Scopes: []types.PermissionScope{types.UserScope},
-		},
-	)
-
-	getReferralDetailsHandler := user.NewGetUserReferralDetailsHandler(
-		config,
-		factory.GetResultWriter(),
-	)
-
-	routes = append(routes, &router.Route{
-		Endpoint: getReferralDetailsEndpoint,
-		Handler:  getReferralDetailsHandler,
-		Router:   r,
-	})
-
 	return routes
 }

+ 1 - 0
api/types/billing_metronome.go

@@ -65,6 +65,7 @@ type CreateCreditsGrantRequest struct {
 	Name          string        `json:"name"`
 	ExpiresAt     string        `json:"expires_at"`
 	Priority      int           `json:"priority"`
+	Reason        string        `json:"reason"`
 }
 
 // ListCreditGrantsRequest is the request to list a user's credit grants. Note that only one of

+ 2 - 0
api/types/project.go

@@ -60,6 +60,8 @@ type Project struct {
 	AdvancedInfraEnabled            bool    `json:"advanced_infra_enabled"`
 	SandboxEnabled                  bool    `json:"sandbox_enabled"`
 	AdvancedRbacEnabled             bool    `json:"advanced_rbac_enabled"`
+	// ReferralCode is a unique code that can be shared to referr other users to Porter
+	ReferralCode string `json:"referral_code"`
 }
 
 // FeatureFlags is a struct that contains old feature flag representations

+ 1 - 7
api/types/user.go

@@ -7,12 +7,6 @@ type User struct {
 	FirstName     string `json:"first_name"`
 	LastName      string `json:"last_name"`
 	CompanyName   string `json:"company_name"`
-
-	// ReferralCode is a unique code that can be shared to referr other users to Porter
-	ReferralCode string `json:"referral_code"`
-	// ReferralRewardClaimed indicates if the user has already received a credits reward
-	// for referring users
-	ReferralRewardClaimed bool `json:"referral_reward_received"`
 }
 
 type CreateUserRequest struct {
@@ -22,7 +16,7 @@ type CreateUserRequest struct {
 	LastName       string `json:"last_name" form:"required,max=255"`
 	CompanyName    string `json:"company_name" form:"required,max=255"`
 	ReferralMethod string `json:"referral_method" form:"max=255"`
-	// ReferredBy is the referral code of the user who referred this user
+	// ReferredBy is the referral code of the project from which this user was referred
 	ReferredBy string `json:"referred_by_code" form:"max=255"`
 }
 

+ 1 - 5
dashboard/src/lib/billing/types.tsx

@@ -54,9 +54,5 @@ export const ClientSecretResponse = z.string();
 export type ReferralDetails = z.infer<typeof ReferralDetailsValidator>;
 export const ReferralDetailsValidator = z.object({
   code: z.string(),
-  reward_claimed: z.boolean(),
-}).nullable();
-
-export const ReferralsValidator = z.object({
-  count: z.number(),
+  referral_count: z.number(),
 }).nullable();

+ 2 - 67
dashboard/src/lib/hooks/useStripe.tsx

@@ -1,5 +1,5 @@
 import { useContext, useState } from "react";
-import { useQuery, useMutation, type UseQueryResult } from "@tanstack/react-query";
+import { useQuery, type UseQueryResult } from "@tanstack/react-query";
 import { z } from "zod";
 
 import {
@@ -13,7 +13,6 @@ import {
   type PaymentMethod,
   type PaymentMethodList,
   type UsageList,
-  ReferralsValidator,
   ReferralDetailsValidator,
   ReferralDetails
 } from "lib/billing/types";
@@ -67,10 +66,6 @@ type TGetReferralDetails = {
   referralDetails: ReferralDetails
 };
 
-type TGetReferrals = {
-  referralsCount: number | null;
-};
-
 export const usePaymentMethods = (): TUsePaymentMethod => {
   const { currentProject } = useContext(Context);
 
@@ -379,36 +374,6 @@ export const useCustomerUsage = (
   };
 };
 
-export const useReferrals = (): TGetReferrals => {
-  const { currentProject } = useContext(Context);
-
-  // Fetch referrals count
-  const referralsReq = useQuery(
-    ["getReferrals", currentProject?.id],
-    async (): Promise<number | null> => {
-      if (!currentProject?.metronome_enabled) {
-        return null;
-      }
-
-      try {
-        const res = await api.getReferrals(
-          "<token>",
-          {},
-          {}
-        );
-
-        const referrals = ReferralsValidator.parse(res.data);
-        return referrals?.count ?? null;
-      } catch (error) {
-        return null
-      }
-    });
-
-  return {
-    referralsCount: referralsReq.data ?? null,
-  };
-};
-
 export const useReferralDetails = (): TGetReferralDetails => {
   const { currentProject } = useContext(Context);
 
@@ -428,7 +393,7 @@ export const useReferralDetails = (): TGetReferralDetails => {
         const res = await api.getReferralDetails(
           "<token>",
           {},
-          {}
+          { project_id: currentProject?.id }
         );
 
         const referraldetails = ReferralDetailsValidator.parse(res.data);
@@ -442,33 +407,3 @@ export const useReferralDetails = (): TGetReferralDetails => {
     referralDetails: referralsReq.data ?? null,
   };
 };
-
-export const useClaimReferralReward = (): (() => void) => {
-  const { currentProject } = useContext(Context);
-
-  // Apply credits reward to this project
-  const referralsReq = useMutation(
-    ["claimReferralReward", currentProject?.id],
-    async (): Promise<void> => {
-      if (!currentProject?.metronome_enabled) {
-        return;
-      }
-
-      if (!currentProject?.id || currentProject.id === -1) {
-        return;
-      }
-
-      try {
-        await api.claimReferralReward(
-          "<token>",
-          {},
-          { project_id: currentProject?.id }
-        );
-      } catch (error) {
-        return;
-      }
-    });
-
-  // Return a function that can be called to execute the mutation
-  return () => referralsReq.mutate();
-};

+ 3 - 55
dashboard/src/main/home/project-settings/ReferralsPage.tsx

@@ -1,65 +1,13 @@
 import React from "react";
 import Spacer from "components/porter/Spacer";
 import Text from "components/porter/Text";
-import { useClaimReferralReward, useReferralDetails, useReferrals } from "lib/hooks/useStripe";
-import Button from "components/porter/Button";
+import { useReferralDetails } from "lib/hooks/useStripe";
 import Link from "components/porter/Link";
 
 function ReferralsPage(): JSX.Element {
-    const referralRewardRequirement = 5;
     const { referralDetails } = useReferralDetails();
-    const { referralsCount } = useReferrals();
-    const claimReferralReward = useClaimReferralReward();
     const baseUrl = window.location.origin;
 
-    const eligibleForReward = (): boolean => {
-        if (referralsCount === null) {
-            return false;
-        }
-
-        return referralsCount >= referralRewardRequirement;
-    }
-
-    const claimReward = (): void => {
-        claimReferralReward();
-    }
-
-    const displayReferral = (): JSX.Element => {
-        if (referralDetails === null || referralsCount === null) {
-            return <></>
-        }
-
-        if (!eligibleForReward()) {
-            return (
-                <>
-                    <Text>
-                        Refer {referralRewardRequirement - referralsCount} more people to earn a reward.
-                    </Text>
-                    <Spacer y={1} />
-                </>
-            )
-        }
-
-        if (referralDetails?.reward_claimed) {
-            return (
-                <>
-                    <Text>
-                        You have already claimed a reward for referring people to Porter.
-                    </Text>
-                    <Spacer y={1} />
-                </>
-            )
-        }
-
-        return (
-            <>
-                <Text>You are elegible for claiming a reward on this project.</Text>
-                <Spacer y={0.5} />
-                <Button onClick={claimReward}>Claim Reward</Button>
-            </>
-        )
-    }
-
     return (
         <>
             <Text size={16}>Referrals</Text>
@@ -74,12 +22,12 @@ function ReferralsPage(): JSX.Element {
                         Your referral link is {" "}
                     </Text>
                     <Link to={baseUrl + "/register?referral=" + referralDetails.code}>{baseUrl + "/register?referral=" + referralDetails.code}</Link>
+                    <Spacer y={1} />
+                    <Text>You have referred {referralDetails.referral_count} users</Text>
                 </>
 
             )}
             <Spacer y={1} />
-            {displayReferral()}
-            <Spacer y={1} />
         </>
     )
 }

+ 4 - 22
dashboard/src/shared/api.tsx

@@ -3577,33 +3577,17 @@ const deletePaymentMethod = baseApi<
     `/api/projects/${project_id}/billing/payment_method/${payment_method_id}`
 );
 
-const getReferrals = baseApi<
-  {},
-  {}
->(
-  "GET",
-  () =>
-    `/api/referrals`
-);
-
 const getReferralDetails = baseApi<
-  {},
-  {}
->(
-  "GET",
-  () =>
-    `/api/referrals/details`
-);
-
-const claimReferralReward = baseApi<
   {},
   {
     project_id?: number;
   }
 >(
-  "POST",
-  ({ project_id }) => `/api/projects/${project_id}/billing/credits/claim_referral`
+  "GET",
+  ({ project_id }) =>
+    `/api/projects/${project_id}/referrals/details`
 );
+
 const getGithubStatus = baseApi<{}, {}>("GET", ({ }) => `/api/status/github`);
 
 const createSecretAndOpenGitHubPullRequest = baseApi<
@@ -3997,9 +3981,7 @@ export default {
   addPaymentMethod,
   setDefaultPaymentMethod,
   deletePaymentMethod,
-  getReferrals,
   getReferralDetails,
-  claimReferralReward,
 
   // STATUS
   getGithubStatus,

+ 12 - 1
internal/billing/metronome.go

@@ -20,6 +20,16 @@ const (
 	defaultCollectionMethod = "charge_automatically"
 	defaultMaxRetries       = 10
 	porterStandardTrialDays = 15
+
+	// referralRewardRequirement is the number of referred users required to
+	// be granted a credits reward
+	referralRewardRequirement = 5
+	// defaultRewardAmountCents is the default amount in USD cents rewarded to users
+	// who reach the reward requirement
+	defaultRewardAmountCents = 1000
+	// defaultPaidAmountCents is the amount paid by the user to get the credits
+	// grant, if set to 0 it means they were free
+	defaultPaidAmountCents = 0
 )
 
 // MetronomeClient is the client used to call the Metronome API
@@ -243,7 +253,7 @@ func (m MetronomeClient) ListCustomerCredits(ctx context.Context, customerID uui
 }
 
 // CreateCreditsGrant will create a new credit grant for the customer with the specified amount
-func (m MetronomeClient) CreateCreditsGrant(ctx context.Context, customerID uuid.UUID, grantAmount float64, paidAmount float64, expiresAt string) (err error) {
+func (m MetronomeClient) CreateCreditsGrant(ctx context.Context, customerID uuid.UUID, reason string, grantAmount float64, paidAmount float64, expiresAt string) (err error) {
 	ctx, span := telemetry.NewSpan(ctx, "create-credits-grant")
 	defer span.End()
 
@@ -272,6 +282,7 @@ func (m MetronomeClient) CreateCreditsGrant(ctx context.Context, customerID uuid
 			CreditTypeID: creditTypeID,
 		},
 		Name:      "Porter Credits",
+		Reason:    reason,
 		ExpiresAt: expiresAt,
 		Priority:  1,
 	}

+ 7 - 0
internal/models/project.go

@@ -226,6 +226,12 @@ type Project struct {
 	EnableReprovision    bool `gorm:"default:false"`
 	AdvancedInfraEnabled bool `gorm:"default:false"`
 	AdvancedRbacEnabled  bool `gorm:"default:false"`
+
+	// ReferralCode is a unique code that can be shared to referr other users to Porter
+	ReferralCode string
+
+	// Referrals is a list of users that have been referred by this project's code
+	Referrals []Referral `json:"referrals"`
 }
 
 // GetFeatureFlag calls launchdarkly for the specified flag
@@ -332,6 +338,7 @@ func (p *Project) ToProjectType(launchDarklyClient *features.Client) types.Proje
 		AdvancedInfraEnabled:            p.GetFeatureFlag(AdvancedInfraEnabled, launchDarklyClient),
 		SandboxEnabled:                  p.EnableSandbox,
 		AdvancedRbacEnabled:             p.GetFeatureFlag(AdvancedRbacEnabled, launchDarklyClient),
+		ReferralCode:                    p.ReferralCode,
 	}
 }
 

+ 9 - 2
internal/models/referral.go

@@ -6,14 +6,21 @@ import (
 	"gorm.io/gorm"
 )
 
+const (
+	// ReferralStatusSignedUp is the status of a referral where the referred user has signed up
+	ReferralStatusSignedUp = "signed_up"
+	// ReferralStatusCompleted is the status of a referral where the referred user has linked a credit card
+	ReferralStatusCompleted = "completed"
+)
+
 // Referral type that extends gorm.Model
 type Referral struct {
 	gorm.Model
 
 	// Code is the referral code that is shared with the referred user
 	Code string
-	// UserID is the ID of the user who made the referral
-	UserID uint
+	// ProjectID is the ID of the project that was used to refer a new user
+	ProjectID uint
 	// ReferredUserID is the ID of the user who was referred
 	ReferredUserID uint
 	// Status is the status of the referral (pending, signed_up, etc.)

+ 6 - 16
internal/models/user.go

@@ -23,26 +23,16 @@ type User struct {
 	// The github user id used for login (optional)
 	GithubUserID int64
 	GoogleUserID string
-
-	// ReferralCode is a unique code that can be shared to referr other users to Porter
-	ReferralCode string
-	// ReferralRewardClaimed indicates if the user has already received a credits reward
-	// for referring users
-	ReferralRewardClaimed bool
-
-	Referrals []Referral `json:"referrals"`
 }
 
 // ToUserType generates an external types.User to be shared over REST
 func (u *User) ToUserType() *types.User {
 	return &types.User{
-		ID:                    u.ID,
-		Email:                 u.Email,
-		EmailVerified:         u.EmailVerified,
-		FirstName:             u.FirstName,
-		LastName:              u.LastName,
-		CompanyName:           u.CompanyName,
-		ReferralCode:          u.ReferralCode,
-		ReferralRewardClaimed: u.ReferralRewardClaimed,
+		ID:            u.ID,
+		Email:         u.Email,
+		EmailVerified: u.EmailVerified,
+		FirstName:     u.FirstName,
+		LastName:      u.LastName,
+		CompanyName:   u.CompanyName,
 	}
 }

+ 27 - 8
internal/repository/gorm/referrals.go

@@ -19,13 +19,13 @@ func NewReferralRepository(db *gorm.DB) repository.ReferralRepository {
 
 // CreateReferral creates a new referral in the database
 func (repo *ReferralRepository) CreateReferral(referral *models.Referral) (*models.Referral, error) {
-	user := &models.User{}
+	project := &models.Project{}
 
-	if err := repo.db.Where("referral_code = ?", referral.Code).First(&user).Error; err != nil {
+	if err := repo.db.Where("referral_code = ?", referral.Code).First(&project).Error; err != nil {
 		return nil, err
 	}
 
-	assoc := repo.db.Model(&user).Association("Referrals")
+	assoc := repo.db.Model(&project).Association("Referrals")
 
 	if assoc.Error != nil {
 		return nil, assoc.Error
@@ -38,11 +38,30 @@ func (repo *ReferralRepository) CreateReferral(referral *models.Referral) (*mode
 	return referral, nil
 }
 
-// GetReferralByCode returns the number of referrals a user has made
-func (repo *ReferralRepository) GetReferralCountByUserID(userID uint) (int, error) {
-	referrals := []models.Referral{}
-	if err := repo.db.Where("user_id = ?", userID).Find(&referrals).Error; err != nil {
+// CountReferralsByProjectID returns the number of referrals a user has made
+func (repo *ReferralRepository) CountReferralsByProjectID(projectID uint, status string) (int64, error) {
+	var count int64
+
+	if err := repo.db.Model(&models.Referral{}).Where("project_id = ? AND status = ?", projectID, status).Count(&count).Error; err != nil {
 		return 0, err
 	}
-	return len(referrals), nil
+
+	return count, nil
+}
+
+// GetReferralByCode returns the number of referrals a user has made
+func (repo *ReferralRepository) GetReferralByReferredID(referredID uint) (*models.Referral, error) {
+	referral := &models.Referral{}
+	if err := repo.db.Where("referred_user_id = ?", referredID).First(&referral).Error; err != nil {
+		return &models.Referral{}, err
+	}
+	return referral, nil
+}
+
+func (repo *ReferralRepository) UpdateReferral(referral *models.Referral) (*models.Referral, error) {
+	if err := repo.db.Save(referral).Error; err != nil {
+		return nil, err
+	}
+
+	return referral, nil
 }

+ 3 - 1
internal/repository/referral.go

@@ -7,5 +7,7 @@ import (
 // ReferralRepository represents the set of queries on the Referral model
 type ReferralRepository interface {
 	CreateReferral(referral *models.Referral) (*models.Referral, error)
-	GetReferralCountByUserID(userID uint) (int, error)
+	GetReferralByReferredID(referredID uint) (*models.Referral, error)
+	CountReferralsByProjectID(projectID uint, status string) (int64, error)
+	UpdateReferral(referral *models.Referral) (*models.Referral, error)
 }

+ 9 - 1
internal/repository/test/referrral.go

@@ -19,6 +19,14 @@ func (repo *ReferralRepository) CreateReferral(referral *models.Referral) (*mode
 	return referral, errors.New("cannot read database")
 }
 
-func (repo *ReferralRepository) GetReferralCountByUserID(userID uint) (int, error) {
+func (repo *ReferralRepository) CountReferralsByProjectID(projectID uint, status string) (int64, error) {
 	return 0, errors.New("cannot read database")
 }
+
+func (repo *ReferralRepository) GetReferralByReferredID(referredID uint) (*models.Referral, error) {
+	return &models.Referral{}, errors.New("cannot read database")
+}
+
+func (repo *ReferralRepository) UpdateReferral(referral *models.Referral) (*models.Referral, error) {
+	return referral, errors.New("cannot read database")
+}