migrate_vault.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. //go:build ee
  2. package migrate
  3. import (
  4. "fmt"
  5. "github.com/porter-dev/porter/api/server/shared/config/env"
  6. "github.com/porter-dev/porter/ee/integrations/vault"
  7. ints "github.com/porter-dev/porter/internal/models/integrations"
  8. "github.com/porter-dev/porter/internal/repository/credentials"
  9. "gorm.io/gorm"
  10. )
  11. // stepSize is the page size for the migration queries: rows are fetched 100 at a time via Limit/Offset.
  12. const stepSize = 100
  13. func MigrateVault(db *gorm.DB, dbConf *env.DBConf, shouldFinalize bool) error {
  14. var vaultClient *vault.Client
  15. if dbConf.VaultAPIKey != "" && dbConf.VaultServerURL != "" && dbConf.VaultPrefix != "" {
  16. vaultClient = vault.NewClient(
  17. dbConf.VaultServerURL,
  18. dbConf.VaultAPIKey,
  19. dbConf.VaultPrefix,
  20. )
  21. } else {
  22. return fmt.Errorf("env variables not properly set for vault migration")
  23. }
  24. err := migrateOAuthIntegrationModel(db, vaultClient, shouldFinalize)
  25. if err != nil {
  26. fmt.Printf("failed on oauth migration: %v\n", err)
  27. return err
  28. }
  29. err = migrateGCPIntegrationModel(db, vaultClient, shouldFinalize)
  30. if err != nil {
  31. fmt.Printf("failed on gcp migration: %v\n", err)
  32. return err
  33. }
  34. err = migrateAWSIntegrationModel(db, vaultClient, shouldFinalize)
  35. if err != nil {
  36. fmt.Printf("failed on aws migration: %v\n", err)
  37. return err
  38. }
  39. err = migrateGitlabIntegrationModel(db, vaultClient, shouldFinalize)
  40. if err != nil {
  41. fmt.Printf("failed on gitlab migration: %v\n", err)
  42. return err
  43. }
  44. return nil
  45. }
  46. func migrateOAuthIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  47. // get count of model
  48. var count int64
  49. if err := db.Model(&ints.OAuthIntegration{}).Count(&count).Error; err != nil {
  50. return err
  51. }
  52. // make a map of ids to errors -- we don't clear the integrations with errors
  53. errors := make(map[uint]error)
  54. // iterate (count / stepSize) + 1 times using Limit and Offset
  55. for i := 0; i < (int(count)/stepSize)+1; i++ {
  56. oauths := []*ints.OAuthIntegration{}
  57. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&oauths).Error; err != nil {
  58. return err
  59. }
  60. // decrypt with the old key
  61. for _, oauth := range oauths {
  62. // Check if record already exists in vault client. If so, we don't write anything to vault,
  63. // since we don't want to overwrite any data that's been written.
  64. if resp, _ := client.GetOAuthCredential(oauth); resp != nil {
  65. continue
  66. }
  67. // write the data to the vault client
  68. if err := client.WriteOAuthCredential(oauth, &credentials.OAuthCredential{
  69. ClientID: oauth.ClientID,
  70. AccessToken: oauth.AccessToken,
  71. RefreshToken: oauth.RefreshToken,
  72. }); err != nil {
  73. errors[oauth.ID] = err
  74. fmt.Printf("oauth vault write error on ID %d: %v\n", oauth.ID, err)
  75. }
  76. }
  77. }
  78. fmt.Printf("migrated %d oauth integrations with %d errors\n", count, len(errors))
  79. if shouldFinalize {
  80. saveErrors := make(map[uint]error, 0)
  81. // iterate a second time, clearing the data
  82. // iterate (count / stepSize) + 1 times using Limit and Offset
  83. for i := 0; i < (int(count)/stepSize)+1; i++ {
  84. oauths := []*ints.OAuthIntegration{}
  85. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&oauths).Error; err != nil {
  86. return err
  87. }
  88. // decrypt with the old key
  89. for _, oauth := range oauths {
  90. if _, found := errors[oauth.ID]; !found {
  91. // clear the data from the db, and save
  92. oauth.ClientID = []byte{}
  93. oauth.AccessToken = []byte{}
  94. oauth.RefreshToken = []byte{}
  95. if err := db.Save(oauth).Error; err != nil {
  96. saveErrors[oauth.ID] = err
  97. }
  98. }
  99. }
  100. }
  101. fmt.Printf("cleared %d oauth integrations with %d errors\n", count, len(saveErrors))
  102. for saveErrorID, saveError := range saveErrors {
  103. fmt.Printf("oauth save error on ID %d: %v\n", saveErrorID, saveError)
  104. }
  105. }
  106. return nil
  107. }
  108. func migrateGCPIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  109. // get count of model
  110. var count int64
  111. if err := db.Model(&ints.GCPIntegration{}).Count(&count).Error; err != nil {
  112. return err
  113. }
  114. // make a map of ids to errors -- we don't clear the integrations with errors
  115. errors := make(map[uint]error)
  116. // iterate (count / stepSize) + 1 times using Limit and Offset
  117. for i := 0; i < (int(count)/stepSize)+1; i++ {
  118. gcps := []*ints.GCPIntegration{}
  119. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&gcps).Error; err != nil {
  120. return err
  121. }
  122. // decrypt with the old key
  123. for _, gcp := range gcps {
  124. // Check if record already exists in vault client. If so, we don't write anything to vault,
  125. // since we don't want to overwrite any data that's been written.
  126. if resp, _ := client.GetGCPCredential(gcp); resp != nil {
  127. continue
  128. }
  129. // write the data to the vault client
  130. if err := client.WriteGCPCredential(gcp, &credentials.GCPCredential{
  131. GCPKeyData: gcp.GCPKeyData,
  132. }); err != nil {
  133. errors[gcp.ID] = err
  134. fmt.Printf("gcp vault write error on ID %d: %v\n", gcp.ID, err)
  135. }
  136. }
  137. }
  138. fmt.Printf("migrated %d gcp integrations with %d errors\n", count, len(errors))
  139. if shouldFinalize {
  140. saveErrors := make(map[uint]error, 0)
  141. // iterate a second time, clearing the data
  142. // iterate (count / stepSize) + 1 times using Limit and Offset
  143. for i := 0; i < (int(count)/stepSize)+1; i++ {
  144. gcps := []*ints.GCPIntegration{}
  145. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&gcps).Error; err != nil {
  146. return err
  147. }
  148. // decrypt with the old key
  149. for _, gcp := range gcps {
  150. if _, found := errors[gcp.ID]; !found {
  151. // clear the data from the db, and save
  152. gcp.GCPKeyData = []byte{}
  153. if err := db.Save(gcp).Error; err != nil {
  154. saveErrors[gcp.ID] = err
  155. }
  156. }
  157. }
  158. }
  159. fmt.Printf("cleared %d gcp integrations with %d errors\n", count, len(saveErrors))
  160. for saveErrorID, saveError := range saveErrors {
  161. fmt.Printf("gcp save error on ID %d: %v\n", saveErrorID, saveError)
  162. }
  163. }
  164. return nil
  165. }
  166. func migrateAWSIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  167. // get count of model
  168. var count int64
  169. if err := db.Model(&ints.AWSIntegration{}).Count(&count).Error; err != nil {
  170. return err
  171. }
  172. // make a map of ids to errors -- we don't clear the integrations with errors
  173. errors := make(map[uint]error)
  174. // iterate (count / stepSize) + 1 times using Limit and Offset
  175. for i := 0; i < (int(count)/stepSize)+1; i++ {
  176. awss := []*ints.AWSIntegration{}
  177. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&awss).Error; err != nil {
  178. return err
  179. }
  180. // decrypt with the old key
  181. for _, aws := range awss {
  182. // Check if record already exists in vault client. If so, we don't write anything to vault,
  183. // since we don't want to overwrite any data that's been written.
  184. if resp, _ := client.GetAWSCredential(aws); resp != nil {
  185. continue
  186. }
  187. // write the data to the vault client
  188. if err := client.WriteAWSCredential(aws, &credentials.AWSCredential{
  189. AWSAccessKeyID: aws.AWSAccessKeyID,
  190. AWSClusterID: aws.AWSClusterID,
  191. AWSSecretAccessKey: aws.AWSSecretAccessKey,
  192. AWSSessionToken: aws.AWSSessionToken,
  193. }); err != nil {
  194. errors[aws.ID] = err
  195. fmt.Printf("aws vault write error on ID %d: %v\n", aws.ID, err)
  196. }
  197. }
  198. }
  199. fmt.Printf("migrated %d aws integrations with %d errors\n", count, len(errors))
  200. if shouldFinalize {
  201. saveErrors := make(map[uint]error, 0)
  202. // iterate a second time, clearing the data
  203. // iterate (count / stepSize) + 1 times using Limit and Offset
  204. for i := 0; i < (int(count)/stepSize)+1; i++ {
  205. awss := []*ints.AWSIntegration{}
  206. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&awss).Error; err != nil {
  207. return err
  208. }
  209. // decrypt with the old key
  210. for _, aws := range awss {
  211. if _, found := errors[aws.ID]; !found {
  212. // clear the data from the db, and save
  213. aws.AWSAccessKeyID = []byte{}
  214. aws.AWSClusterID = []byte{}
  215. aws.AWSSecretAccessKey = []byte{}
  216. aws.AWSSessionToken = []byte{}
  217. if err := db.Save(aws).Error; err != nil {
  218. saveErrors[aws.ID] = err
  219. }
  220. }
  221. }
  222. }
  223. fmt.Printf("cleared %d aws integrations with %d errors\n", count, len(saveErrors))
  224. for saveErrorID, saveError := range saveErrors {
  225. fmt.Printf("aws save error on ID %d: %v\n", saveErrorID, saveError)
  226. }
  227. }
  228. return nil
  229. }
  230. func migrateGitlabIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  231. // get count of model
  232. var count int64
  233. if err := db.Model(&ints.GitlabIntegration{}).Count(&count).Error; err != nil {
  234. return err
  235. }
  236. // make a map of ids to errors -- we don't clear the integrations with errors
  237. errors := make(map[uint]error)
  238. // iterate (count / stepSize) + 1 times using Limit and Offset
  239. for i := 0; i < (int(count)/stepSize)+1; i++ {
  240. giInts := []*ints.GitlabIntegration{}
  241. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&giInts).Error; err != nil {
  242. return err
  243. }
  244. // decrypt with the old key
  245. for _, gi := range giInts {
  246. // Check if record already exists in vault client. If so, we don't write anything to vault,
  247. // since we don't want to overwrite any data that's been written.
  248. if resp, _ := client.GetGitlabCredential(gi); resp != nil {
  249. continue
  250. }
  251. // write the data to the vault client
  252. if err := client.WriteGitlabCredential(gi, &credentials.GitlabCredential{
  253. AppClientID: gi.AppClientID,
  254. AppClientSecret: gi.AppClientSecret,
  255. }); err != nil {
  256. errors[gi.ID] = err
  257. fmt.Printf("gitlab vault write error on ID %d: %v\n", gi.ID, err)
  258. }
  259. }
  260. }
  261. fmt.Printf("migrated %d gitlab integrations with %d errors\n", count, len(errors))
  262. if shouldFinalize {
  263. saveErrors := make(map[uint]error, 0)
  264. // iterate a second time, clearing the data
  265. // iterate (count / stepSize) + 1 times using Limit and Offset
  266. for i := 0; i < (int(count)/stepSize)+1; i++ {
  267. giInts := []*ints.GitlabIntegration{}
  268. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&giInts).Error; err != nil {
  269. return err
  270. }
  271. // decrypt with the old key
  272. for _, gi := range giInts {
  273. if _, found := errors[gi.ID]; !found {
  274. // clear the data from the db, and save
  275. gi.AppClientID = []byte{}
  276. gi.AppClientSecret = []byte{}
  277. if err := db.Save(gi).Error; err != nil {
  278. saveErrors[gi.ID] = err
  279. }
  280. }
  281. }
  282. }
  283. fmt.Printf("cleared %d gitlab integrations with %d errors\n", count, len(saveErrors))
  284. for saveErrorID, saveError := range saveErrors {
  285. fmt.Printf("gitlab save error on ID %d: %v\n", saveErrorID, saveError)
  286. }
  287. }
  288. return nil
  289. }