referrals.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. package gorm
  2. import (
  3. "errors"
  4. "github.com/porter-dev/porter/internal/models"
  5. "github.com/porter-dev/porter/internal/repository"
  6. "gorm.io/gorm"
  7. )
  8. // ReferralRepository uses gorm.DB for querying the database
  9. type ReferralRepository struct {
  10. db *gorm.DB
  11. }
  12. // NewReferralRepository returns a ReferralRepository which uses
  13. // gorm.DB for querying the database
  14. func NewReferralRepository(db *gorm.DB) repository.ReferralRepository {
  15. return &ReferralRepository{db}
  16. }
  17. // CreateReferral creates a new referral in the database
  18. func (repo *ReferralRepository) CreateReferral(referral *models.Referral) (*models.Referral, error) {
  19. project := &models.Project{}
  20. if err := repo.db.Where("referral_code = ?", referral.Code).First(&project).Error; err != nil {
  21. return nil, err
  22. }
  23. assoc := repo.db.Model(&project).Association("Referrals")
  24. if assoc.Error != nil {
  25. return nil, assoc.Error
  26. }
  27. if err := assoc.Append(referral); err != nil {
  28. return nil, err
  29. }
  30. return referral, nil
  31. }
  32. // CountReferralsByProjectID returns the number of referrals a user has made
  33. func (repo *ReferralRepository) CountReferralsByProjectID(projectID uint, status string) (int64, error) {
  34. var count int64
  35. if err := repo.db.Model(&models.Referral{}).Where("project_id = ? AND status = ?", projectID, status).Count(&count).Error; err != nil {
  36. return 0, err
  37. }
  38. return count, nil
  39. }
  40. // GetReferralByReferredID returns a referral by the referred user's ID
  41. func (repo *ReferralRepository) GetReferralByReferredID(referredID uint) (*models.Referral, error) {
  42. referral := &models.Referral{}
  43. err := repo.db.Where("referred_user_id = ?", referredID).First(&referral).Error
  44. if errors.Is(err, gorm.ErrRecordNotFound) {
  45. return nil, nil
  46. }
  47. if err != nil {
  48. return &models.Referral{}, err
  49. }
  50. return referral, nil
  51. }
  52. // UpdateReferral updates a referral in the database
  53. func (repo *ReferralRepository) UpdateReferral(referral *models.Referral) (*models.Referral, error) {
  54. if err := repo.db.Save(referral).Error; err != nil {
  55. return nil, err
  56. }
  57. return referral, nil
  58. }