neon.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package gorm
  2. import (
  3. "context"
  4. "github.com/porter-dev/porter/internal/encryption"
  5. ints "github.com/porter-dev/porter/internal/models/integrations"
  6. "github.com/porter-dev/porter/internal/repository"
  7. "github.com/porter-dev/porter/internal/telemetry"
  8. "gorm.io/gorm"
  9. )
  10. // NeonIntegrationRepository is a repository that manages neon integrations
  11. type NeonIntegrationRepository struct {
  12. db *gorm.DB
  13. key *[32]byte
  14. }
  15. // NewNeonIntegrationRepository returns a NeonIntegrationRepository
  16. func NewNeonIntegrationRepository(db *gorm.DB, key *[32]byte) repository.NeonIntegrationRepository {
  17. return &NeonIntegrationRepository{db, key}
  18. }
  19. // Insert creates a new neon integration
  20. func (repo *NeonIntegrationRepository) Insert(
  21. ctx context.Context, neonInt ints.NeonIntegration,
  22. ) (ints.NeonIntegration, error) {
  23. ctx, span := telemetry.NewSpan(ctx, "gorm-create-neon-integration")
  24. defer span.End()
  25. var created ints.NeonIntegration
  26. encrypted, err := repo.EncryptNeonIntegration(neonInt, repo.key)
  27. if err != nil {
  28. return created, telemetry.Error(ctx, span, err, "failed to encrypt")
  29. }
  30. if err := repo.db.Create(&encrypted).Error; err != nil {
  31. return created, telemetry.Error(ctx, span, err, "failed to create neon integration")
  32. }
  33. return created, nil
  34. }
  35. // Integrations returns all neon integrations for a given project
  36. func (repo *NeonIntegrationRepository) Integrations(
  37. ctx context.Context, projectID uint,
  38. ) ([]ints.NeonIntegration, error) {
  39. ctx, span := telemetry.NewSpan(ctx, "gorm-list-neon-integrations")
  40. defer span.End()
  41. var integrations []ints.NeonIntegration
  42. if err := repo.db.Where("project_id = ?", projectID).Find(&integrations).Error; err != nil {
  43. return integrations, telemetry.Error(ctx, span, err, "failed to list neon integrations")
  44. }
  45. for i, integration := range integrations {
  46. decrypted, err := repo.DecryptNeonIntegration(integration, repo.key)
  47. if err != nil {
  48. return integrations, telemetry.Error(ctx, span, err, "failed to decrypt")
  49. }
  50. integrations[i] = decrypted
  51. }
  52. return integrations, nil
  53. }
  54. // EncryptNeonIntegration will encrypt the neon integration data before
  55. // writing to the DB
  56. func (repo *NeonIntegrationRepository) EncryptNeonIntegration(
  57. neonInt ints.NeonIntegration,
  58. key *[32]byte,
  59. ) (ints.NeonIntegration, error) {
  60. encrypted := neonInt
  61. if len(encrypted.ClientID) > 0 {
  62. cipherData, err := encryption.Encrypt(encrypted.ClientID, key)
  63. if err != nil {
  64. return encrypted, err
  65. }
  66. encrypted.ClientID = cipherData
  67. }
  68. if len(encrypted.AccessToken) > 0 {
  69. cipherData, err := encryption.Encrypt(encrypted.AccessToken, key)
  70. if err != nil {
  71. return encrypted, err
  72. }
  73. encrypted.AccessToken = cipherData
  74. }
  75. if len(encrypted.RefreshToken) > 0 {
  76. cipherData, err := encryption.Encrypt(encrypted.RefreshToken, key)
  77. if err != nil {
  78. return encrypted, err
  79. }
  80. encrypted.RefreshToken = cipherData
  81. }
  82. return encrypted, nil
  83. }
  84. // DecryptNeonIntegration will decrypt the neon integration data before
  85. // returning it from the DB
  86. func (repo *NeonIntegrationRepository) DecryptNeonIntegration(
  87. neonInt ints.NeonIntegration,
  88. key *[32]byte,
  89. ) (ints.NeonIntegration, error) {
  90. decrypted := neonInt
  91. if len(decrypted.ClientID) > 0 {
  92. plaintext, err := encryption.Decrypt(decrypted.ClientID, key)
  93. if err != nil {
  94. return decrypted, err
  95. }
  96. decrypted.ClientID = plaintext
  97. }
  98. if len(decrypted.AccessToken) > 0 {
  99. plaintext, err := encryption.Decrypt(decrypted.AccessToken, key)
  100. if err != nil {
  101. return decrypted, err
  102. }
  103. decrypted.AccessToken = plaintext
  104. }
  105. if len(decrypted.RefreshToken) > 0 {
  106. plaintext, err := encryption.Decrypt(decrypted.RefreshToken, key)
  107. if err != nil {
  108. return decrypted, err
  109. }
  110. decrypted.RefreshToken = plaintext
  111. }
  112. return decrypted, nil
  113. }