infra.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package gorm
  2. import (
  3. "github.com/porter-dev/porter/internal/models"
  4. "github.com/porter-dev/porter/internal/repository"
  5. "gorm.io/gorm"
  6. )
  7. // InfraRepository uses gorm.DB for querying the database
  8. type InfraRepository struct {
  9. db *gorm.DB
  10. key *[32]byte
  11. }
  12. // NewInfraRepository returns a InfraRepository which uses
  13. // gorm.DB for querying the database
  14. func NewInfraRepository(db *gorm.DB, key *[32]byte) repository.InfraRepository {
  15. return &InfraRepository{db, key}
  16. }
  17. // CreateInfra creates a new aws infra
  18. func (repo *InfraRepository) CreateInfra(infra *models.Infra) (*models.Infra, error) {
  19. err := repo.EncryptInfraData(infra, repo.key)
  20. if err != nil {
  21. return nil, err
  22. }
  23. project := &models.Project{}
  24. if err := repo.db.Where("id = ?", infra.ProjectID).First(&project).Error; err != nil {
  25. return nil, err
  26. }
  27. assoc := repo.db.Model(&project).Association("Infras")
  28. if assoc.Error != nil {
  29. return nil, assoc.Error
  30. }
  31. if err := assoc.Append(infra); err != nil {
  32. return nil, err
  33. }
  34. err = repo.DecryptInfraData(infra, repo.key)
  35. if err != nil {
  36. return nil, err
  37. }
  38. return infra, nil
  39. }
  40. // ReadInfra gets a aws infra specified by a unique id
  41. func (repo *InfraRepository) ReadInfra(id uint) (*models.Infra, error) {
  42. infra := &models.Infra{}
  43. if err := repo.db.Where("id = ?", id).First(&infra).Error; err != nil {
  44. return nil, err
  45. }
  46. err := repo.DecryptInfraData(infra, repo.key)
  47. if err != nil {
  48. return nil, err
  49. }
  50. return infra, nil
  51. }
  52. // ListInfrasByProjectID finds all aws infras
  53. // for a given project id
  54. func (repo *InfraRepository) ListInfrasByProjectID(
  55. projectID uint,
  56. ) ([]*models.Infra, error) {
  57. infras := []*models.Infra{}
  58. if err := repo.db.Where("project_id = ?", projectID).Find(&infras).Error; err != nil {
  59. return nil, err
  60. }
  61. for _, infra := range infras {
  62. repo.DecryptInfraData(infra, repo.key)
  63. }
  64. return infras, nil
  65. }
  66. // UpdateInfra modifies an existing Infra in the database
  67. func (repo *InfraRepository) UpdateInfra(
  68. ai *models.Infra,
  69. ) (*models.Infra, error) {
  70. err := repo.EncryptInfraData(ai, repo.key)
  71. if err != nil {
  72. return nil, err
  73. }
  74. if err := repo.db.Save(ai).Error; err != nil {
  75. return nil, err
  76. }
  77. err = repo.DecryptInfraData(ai, repo.key)
  78. if err != nil {
  79. return nil, err
  80. }
  81. return ai, nil
  82. }
  83. // EncryptInfraData will encrypt the infra data before
  84. // writing to the DB
  85. func (repo *InfraRepository) EncryptInfraData(
  86. infra *models.Infra,
  87. key *[32]byte,
  88. ) error {
  89. if len(infra.LastApplied) > 0 {
  90. cipherData, err := repository.Encrypt(infra.LastApplied, key)
  91. if err != nil {
  92. return err
  93. }
  94. infra.LastApplied = cipherData
  95. }
  96. return nil
  97. }
  98. // DecryptInfraData will decrypt the user's infra data before
  99. // returning it from the DB
  100. func (repo *InfraRepository) DecryptInfraData(
  101. infra *models.Infra,
  102. key *[32]byte,
  103. ) error {
  104. if len(infra.LastApplied) > 0 {
  105. plaintext, err := repository.Decrypt(infra.LastApplied, key)
  106. if err != nil {
  107. return err
  108. }
  109. infra.LastApplied = plaintext
  110. }
  111. return nil
  112. }