migrate_vault.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. // +build ee
  2. package migrate
  3. import (
  4. "fmt"
  5. "github.com/porter-dev/porter/api/server/shared/config/env"
  6. "github.com/porter-dev/porter/ee/integrations/vault"
  7. ints "github.com/porter-dev/porter/internal/models/integrations"
  8. "github.com/porter-dev/porter/internal/repository/credentials"
  9. "gorm.io/gorm"
  10. )
  11. // process 100 records at a time
  12. const stepSize = 100
  13. func MigrateVault(db *gorm.DB, dbConf *env.DBConf, shouldFinalize bool) error {
  14. var vaultClient *vault.Client
  15. if dbConf.VaultAPIKey != "" && dbConf.VaultServerURL != "" && dbConf.VaultPrefix != "" {
  16. vaultClient = vault.NewClient(
  17. dbConf.VaultServerURL,
  18. dbConf.VaultAPIKey,
  19. dbConf.VaultPrefix,
  20. )
  21. } else {
  22. return fmt.Errorf("env variables not properly set for vault migration")
  23. }
  24. err := migrateOAuthIntegrationModel(db, vaultClient, shouldFinalize)
  25. if err != nil {
  26. fmt.Printf("failed on oauth migration: %v\n", err)
  27. return err
  28. }
  29. err = migrateGCPIntegrationModel(db, vaultClient, shouldFinalize)
  30. if err != nil {
  31. fmt.Printf("failed on gcp migration: %v\n", err)
  32. return err
  33. }
  34. err = migrateAWSIntegrationModel(db, vaultClient, shouldFinalize)
  35. if err != nil {
  36. fmt.Printf("failed on aws migration: %v\n", err)
  37. return err
  38. }
  39. return nil
  40. }
  41. func migrateOAuthIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  42. // get count of model
  43. var count int64
  44. if err := db.Model(&ints.OAuthIntegration{}).Count(&count).Error; err != nil {
  45. return err
  46. }
  47. // make a map of ids to errors -- we don't clear the integrations with errors
  48. errors := make(map[uint]error)
  49. // iterate (count / stepSize) + 1 times using Limit and Offset
  50. for i := 0; i < (int(count)/stepSize)+1; i++ {
  51. oauths := []*ints.OAuthIntegration{}
  52. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&oauths).Error; err != nil {
  53. return err
  54. }
  55. // decrypt with the old key
  56. for _, oauth := range oauths {
  57. // Check if record already exists in vault client. If so, we don't write anything to vault,
  58. // since we don't want to overwrite any data that's been written.
  59. if resp, _ := client.GetOAuthCredential(oauth); resp != nil {
  60. continue
  61. }
  62. // write the data to the vault client
  63. if err := client.WriteOAuthCredential(oauth, &credentials.OAuthCredential{
  64. ClientID: oauth.ClientID,
  65. AccessToken: oauth.AccessToken,
  66. RefreshToken: oauth.RefreshToken,
  67. }); err != nil {
  68. errors[oauth.ID] = err
  69. fmt.Printf("oauth vault write error on ID %d: %v\n", oauth.ID, err)
  70. }
  71. }
  72. }
  73. fmt.Printf("migrated %d oauth integrations with %d errors\n", count, len(errors))
  74. if shouldFinalize {
  75. saveErrors := make(map[uint]error, 0)
  76. // iterate a second time, clearing the data
  77. // iterate (count / stepSize) + 1 times using Limit and Offset
  78. for i := 0; i < (int(count)/stepSize)+1; i++ {
  79. oauths := []*ints.OAuthIntegration{}
  80. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&oauths).Error; err != nil {
  81. return err
  82. }
  83. // decrypt with the old key
  84. for _, oauth := range oauths {
  85. if _, found := errors[oauth.ID]; !found {
  86. // clear the data from the db, and save
  87. oauth.ClientID = []byte{}
  88. oauth.AccessToken = []byte{}
  89. oauth.RefreshToken = []byte{}
  90. if err := db.Save(oauth).Error; err != nil {
  91. saveErrors[oauth.ID] = err
  92. }
  93. }
  94. }
  95. }
  96. fmt.Printf("cleared %d oauth integrations with %d errors\n", count, len(saveErrors))
  97. for saveErrorID, saveError := range saveErrors {
  98. fmt.Printf("oauth save error on ID %d: %v\n", saveErrorID, saveError)
  99. }
  100. }
  101. return nil
  102. }
  103. func migrateGCPIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  104. // get count of model
  105. var count int64
  106. if err := db.Model(&ints.GCPIntegration{}).Count(&count).Error; err != nil {
  107. return err
  108. }
  109. // make a map of ids to errors -- we don't clear the integrations with errors
  110. errors := make(map[uint]error)
  111. // iterate (count / stepSize) + 1 times using Limit and Offset
  112. for i := 0; i < (int(count)/stepSize)+1; i++ {
  113. gcps := []*ints.GCPIntegration{}
  114. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&gcps).Error; err != nil {
  115. return err
  116. }
  117. // decrypt with the old key
  118. for _, gcp := range gcps {
  119. // Check if record already exists in vault client. If so, we don't write anything to vault,
  120. // since we don't want to overwrite any data that's been written.
  121. if resp, _ := client.GetGCPCredential(gcp); resp != nil {
  122. continue
  123. }
  124. // write the data to the vault client
  125. if err := client.WriteGCPCredential(gcp, &credentials.GCPCredential{
  126. GCPKeyData: gcp.GCPKeyData,
  127. }); err != nil {
  128. errors[gcp.ID] = err
  129. fmt.Printf("gcp vault write error on ID %d: %v\n", gcp.ID, err)
  130. }
  131. }
  132. }
  133. fmt.Printf("migrated %d gcp integrations with %d errors\n", count, len(errors))
  134. if shouldFinalize {
  135. saveErrors := make(map[uint]error, 0)
  136. // iterate a second time, clearing the data
  137. // iterate (count / stepSize) + 1 times using Limit and Offset
  138. for i := 0; i < (int(count)/stepSize)+1; i++ {
  139. gcps := []*ints.GCPIntegration{}
  140. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&gcps).Error; err != nil {
  141. return err
  142. }
  143. // decrypt with the old key
  144. for _, gcp := range gcps {
  145. if _, found := errors[gcp.ID]; !found {
  146. // clear the data from the db, and save
  147. gcp.GCPKeyData = []byte{}
  148. if err := db.Save(gcp).Error; err != nil {
  149. saveErrors[gcp.ID] = err
  150. }
  151. }
  152. }
  153. }
  154. fmt.Printf("cleared %d gcp integrations with %d errors\n", count, len(saveErrors))
  155. for saveErrorID, saveError := range saveErrors {
  156. fmt.Printf("gcp save error on ID %d: %v\n", saveErrorID, saveError)
  157. }
  158. }
  159. return nil
  160. }
  161. func migrateAWSIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  162. // get count of model
  163. var count int64
  164. if err := db.Model(&ints.AWSIntegration{}).Count(&count).Error; err != nil {
  165. return err
  166. }
  167. // make a map of ids to errors -- we don't clear the integrations with errors
  168. errors := make(map[uint]error)
  169. // iterate (count / stepSize) + 1 times using Limit and Offset
  170. for i := 0; i < (int(count)/stepSize)+1; i++ {
  171. awss := []*ints.AWSIntegration{}
  172. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&awss).Error; err != nil {
  173. return err
  174. }
  175. // decrypt with the old key
  176. for _, aws := range awss {
  177. // Check if record already exists in vault client. If so, we don't write anything to vault,
  178. // since we don't want to overwrite any data that's been written.
  179. if resp, _ := client.GetAWSCredential(aws); resp != nil {
  180. continue
  181. }
  182. // write the data to the vault client
  183. if err := client.WriteAWSCredential(aws, &credentials.AWSCredential{
  184. AWSAccessKeyID: aws.AWSAccessKeyID,
  185. AWSClusterID: aws.AWSClusterID,
  186. AWSSecretAccessKey: aws.AWSSecretAccessKey,
  187. AWSSessionToken: aws.AWSSessionToken,
  188. }); err != nil {
  189. errors[aws.ID] = err
  190. fmt.Printf("aws vault write error on ID %d: %v\n", aws.ID, err)
  191. }
  192. }
  193. }
  194. fmt.Printf("migrated %d aws integrations with %d errors\n", count, len(errors))
  195. if shouldFinalize {
  196. saveErrors := make(map[uint]error, 0)
  197. // iterate a second time, clearing the data
  198. // iterate (count / stepSize) + 1 times using Limit and Offset
  199. for i := 0; i < (int(count)/stepSize)+1; i++ {
  200. awss := []*ints.AWSIntegration{}
  201. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&awss).Error; err != nil {
  202. return err
  203. }
  204. // decrypt with the old key
  205. for _, aws := range awss {
  206. if _, found := errors[aws.ID]; !found {
  207. // clear the data from the db, and save
  208. aws.AWSAccessKeyID = []byte{}
  209. aws.AWSClusterID = []byte{}
  210. aws.AWSSecretAccessKey = []byte{}
  211. aws.AWSSessionToken = []byte{}
  212. if err := db.Save(aws).Error; err != nil {
  213. saveErrors[aws.ID] = err
  214. }
  215. }
  216. }
  217. }
  218. fmt.Printf("cleared %d aws integrations with %d errors\n", count, len(saveErrors))
  219. for saveErrorID, saveError := range saveErrors {
  220. fmt.Printf("aws save error on ID %d: %v\n", saveErrorID, saveError)
  221. }
  222. }
  223. return nil
  224. }