user_billing.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. //go:build ee
  2. // +build ee
  3. package gorm
  4. import (
  5. "github.com/porter-dev/porter/ee/models"
  6. "github.com/porter-dev/porter/ee/repository"
  7. "github.com/porter-dev/porter/internal/encryption"
  8. "gorm.io/gorm"
  9. )
  10. // UserBillingRepository uses gorm.DB for querying the database
  11. type UserBillingRepository struct {
  12. db *gorm.DB
  13. key *[32]byte
  14. }
  15. func NewUserBillingRepository(db *gorm.DB, key *[32]byte) repository.UserBillingRepository {
  16. return &UserBillingRepository{db, key}
  17. }
  18. // CreateUserBilling adds a new User row to the Users table in the database
  19. func (repo *UserBillingRepository) CreateUserBilling(userBilling *models.UserBilling) (*models.UserBilling, error) {
  20. err := repo.EncryptUserBillingData(userBilling, repo.key)
  21. if err != nil {
  22. return nil, err
  23. }
  24. if err := repo.db.Create(userBilling).Error; err != nil {
  25. return nil, err
  26. }
  27. err = repo.DecryptUserBillingData(userBilling, repo.key)
  28. if err != nil {
  29. return nil, err
  30. }
  31. return userBilling, nil
  32. }
  33. func (repo *UserBillingRepository) ReadUserBilling(projectID, userID uint) (*models.UserBilling, error) {
  34. userBilling := &models.UserBilling{}
  35. if err := repo.db.Where("project_id = ? AND user_id = ?", projectID, userID).First(&userBilling).Error; err != nil {
  36. return nil, err
  37. }
  38. err := repo.DecryptUserBillingData(userBilling, repo.key)
  39. if err != nil {
  40. return nil, err
  41. }
  42. return userBilling, nil
  43. }
  44. // UpdateUserBilling updates user billing in the db
  45. func (repo *UserBillingRepository) UpdateUserBilling(userBilling *models.UserBilling) (*models.UserBilling, error) {
  46. err := repo.EncryptUserBillingData(userBilling, repo.key)
  47. if err != nil {
  48. return nil, err
  49. }
  50. if err := repo.db.Save(userBilling).Error; err != nil {
  51. return nil, err
  52. }
  53. err = repo.DecryptUserBillingData(userBilling, repo.key)
  54. if err != nil {
  55. return nil, err
  56. }
  57. return userBilling, nil
  58. }
  59. // EncryptUserBillingData will encrypt the user's billing data before writing
  60. // to the DB
  61. func (repo *UserBillingRepository) EncryptUserBillingData(
  62. userBilling *models.UserBilling,
  63. key *[32]byte,
  64. ) error {
  65. if tok := userBilling.Token; len(tok) > 0 {
  66. cipherData, err := encryption.Encrypt(tok, key)
  67. if err != nil {
  68. return err
  69. }
  70. userBilling.Token = cipherData
  71. }
  72. return nil
  73. }
  74. // DecryptUserBillingData will decrypt the user's billing data before returning it
  75. // from the DB
  76. func (repo *UserBillingRepository) DecryptUserBillingData(
  77. userBilling *models.UserBilling,
  78. key *[32]byte,
  79. ) error {
  80. if tok := userBilling.Token; len(tok) > 0 {
  81. plaintext, err := encryption.Decrypt(tok, key)
  82. if err != nil {
  83. return err
  84. }
  85. userBilling.Token = plaintext
  86. }
  87. return nil
  88. }