| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- //go:build ee
- // +build ee
- package gorm
- import (
- "github.com/porter-dev/porter/ee/models"
- "github.com/porter-dev/porter/ee/repository"
- "github.com/porter-dev/porter/internal/encryption"
- "gorm.io/gorm"
- )
- // UserBillingRepository uses gorm.DB for querying the database
- type UserBillingRepository struct {
- db *gorm.DB
- key *[32]byte
- }
- func NewUserBillingRepository(db *gorm.DB, key *[32]byte) repository.UserBillingRepository {
- return &UserBillingRepository{db, key}
- }
- // CreateUserBilling adds a new User row to the Users table in the database
- func (repo *UserBillingRepository) CreateUserBilling(userBilling *models.UserBilling) (*models.UserBilling, error) {
- err := repo.EncryptUserBillingData(userBilling, repo.key)
- if err != nil {
- return nil, err
- }
- if err := repo.db.Create(userBilling).Error; err != nil {
- return nil, err
- }
- err = repo.DecryptUserBillingData(userBilling, repo.key)
- if err != nil {
- return nil, err
- }
- return userBilling, nil
- }
- func (repo *UserBillingRepository) ReadUserBilling(projectID, userID uint) (*models.UserBilling, error) {
- userBilling := &models.UserBilling{}
- if err := repo.db.Where("project_id = ? AND user_id = ?", projectID, userID).First(&userBilling).Error; err != nil {
- return nil, err
- }
- err := repo.DecryptUserBillingData(userBilling, repo.key)
- if err != nil {
- return nil, err
- }
- return userBilling, nil
- }
- // UpdateUserBilling updates user billing in the db
- func (repo *UserBillingRepository) UpdateUserBilling(userBilling *models.UserBilling) (*models.UserBilling, error) {
- err := repo.EncryptUserBillingData(userBilling, repo.key)
- if err != nil {
- return nil, err
- }
- if err := repo.db.Save(userBilling).Error; err != nil {
- return nil, err
- }
- err = repo.DecryptUserBillingData(userBilling, repo.key)
- if err != nil {
- return nil, err
- }
- return userBilling, nil
- }
- // EncryptUserBillingData will encrypt the user's billing data before writing
- // to the DB
- func (repo *UserBillingRepository) EncryptUserBillingData(
- userBilling *models.UserBilling,
- key *[32]byte,
- ) error {
- if tok := userBilling.Token; len(tok) > 0 {
- cipherData, err := encryption.Encrypt(tok, key)
- if err != nil {
- return err
- }
- userBilling.Token = cipherData
- }
- return nil
- }
- // DecryptUserBillingData will decrypt the user's billing data before returning it
- // from the DB
- func (repo *UserBillingRepository) DecryptUserBillingData(
- userBilling *models.UserBilling,
- key *[32]byte,
- ) error {
- if tok := userBilling.Token; len(tok) > 0 {
- plaintext, err := encryption.Decrypt(tok, key)
- if err != nil {
- return err
- }
- userBilling.Token = plaintext
- }
- return nil
- }
|