migrate_vault.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. //go:build ee
  2. // +build ee
  3. package migrate
  4. import (
  5. "fmt"
  6. "github.com/porter-dev/porter/api/server/shared/config/env"
  7. "github.com/porter-dev/porter/ee/integrations/vault"
  8. ints "github.com/porter-dev/porter/internal/models/integrations"
  9. "github.com/porter-dev/porter/internal/repository/credentials"
  10. "gorm.io/gorm"
  11. )
  // stepSize is the page size used when walking each integration table:
  // records are fetched and processed 100 at a time.
  const (
  	stepSize = 100
  )
  14. func MigrateVault(db *gorm.DB, dbConf *env.DBConf, shouldFinalize bool) error {
  15. var vaultClient *vault.Client
  16. if dbConf.VaultAPIKey != "" && dbConf.VaultServerURL != "" && dbConf.VaultPrefix != "" {
  17. vaultClient = vault.NewClient(
  18. dbConf.VaultServerURL,
  19. dbConf.VaultAPIKey,
  20. dbConf.VaultPrefix,
  21. )
  22. } else {
  23. return fmt.Errorf("env variables not properly set for vault migration")
  24. }
  25. err := migrateOAuthIntegrationModel(db, vaultClient, shouldFinalize)
  26. if err != nil {
  27. fmt.Printf("failed on oauth migration: %v\n", err)
  28. return err
  29. }
  30. err = migrateGCPIntegrationModel(db, vaultClient, shouldFinalize)
  31. if err != nil {
  32. fmt.Printf("failed on gcp migration: %v\n", err)
  33. return err
  34. }
  35. err = migrateAWSIntegrationModel(db, vaultClient, shouldFinalize)
  36. if err != nil {
  37. fmt.Printf("failed on aws migration: %v\n", err)
  38. return err
  39. }
  40. err = migrateGitlabIntegrationModel(db, vaultClient, shouldFinalize)
  41. if err != nil {
  42. fmt.Printf("failed on gitlab migration: %v\n", err)
  43. return err
  44. }
  45. return nil
  46. }
  47. func migrateOAuthIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  48. // get count of model
  49. var count int64
  50. if err := db.Model(&ints.OAuthIntegration{}).Count(&count).Error; err != nil {
  51. return err
  52. }
  53. // make a map of ids to errors -- we don't clear the integrations with errors
  54. errors := make(map[uint]error)
  55. // iterate (count / stepSize) + 1 times using Limit and Offset
  56. for i := 0; i < (int(count)/stepSize)+1; i++ {
  57. oauths := []*ints.OAuthIntegration{}
  58. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&oauths).Error; err != nil {
  59. return err
  60. }
  61. // decrypt with the old key
  62. for _, oauth := range oauths {
  63. // Check if record already exists in vault client. If so, we don't write anything to vault,
  64. // since we don't want to overwrite any data that's been written.
  65. if resp, _ := client.GetOAuthCredential(oauth); resp != nil {
  66. continue
  67. }
  68. // write the data to the vault client
  69. if err := client.WriteOAuthCredential(oauth, &credentials.OAuthCredential{
  70. ClientID: oauth.ClientID,
  71. AccessToken: oauth.AccessToken,
  72. RefreshToken: oauth.RefreshToken,
  73. }); err != nil {
  74. errors[oauth.ID] = err
  75. fmt.Printf("oauth vault write error on ID %d: %v\n", oauth.ID, err)
  76. }
  77. }
  78. }
  79. fmt.Printf("migrated %d oauth integrations with %d errors\n", count, len(errors))
  80. if shouldFinalize {
  81. saveErrors := make(map[uint]error, 0)
  82. // iterate a second time, clearing the data
  83. // iterate (count / stepSize) + 1 times using Limit and Offset
  84. for i := 0; i < (int(count)/stepSize)+1; i++ {
  85. oauths := []*ints.OAuthIntegration{}
  86. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&oauths).Error; err != nil {
  87. return err
  88. }
  89. // decrypt with the old key
  90. for _, oauth := range oauths {
  91. if _, found := errors[oauth.ID]; !found {
  92. // clear the data from the db, and save
  93. oauth.ClientID = []byte{}
  94. oauth.AccessToken = []byte{}
  95. oauth.RefreshToken = []byte{}
  96. if err := db.Save(oauth).Error; err != nil {
  97. saveErrors[oauth.ID] = err
  98. }
  99. }
  100. }
  101. }
  102. fmt.Printf("cleared %d oauth integrations with %d errors\n", count, len(saveErrors))
  103. for saveErrorID, saveError := range saveErrors {
  104. fmt.Printf("oauth save error on ID %d: %v\n", saveErrorID, saveError)
  105. }
  106. }
  107. return nil
  108. }
  109. func migrateGCPIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  110. // get count of model
  111. var count int64
  112. if err := db.Model(&ints.GCPIntegration{}).Count(&count).Error; err != nil {
  113. return err
  114. }
  115. // make a map of ids to errors -- we don't clear the integrations with errors
  116. errors := make(map[uint]error)
  117. // iterate (count / stepSize) + 1 times using Limit and Offset
  118. for i := 0; i < (int(count)/stepSize)+1; i++ {
  119. gcps := []*ints.GCPIntegration{}
  120. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&gcps).Error; err != nil {
  121. return err
  122. }
  123. // decrypt with the old key
  124. for _, gcp := range gcps {
  125. // Check if record already exists in vault client. If so, we don't write anything to vault,
  126. // since we don't want to overwrite any data that's been written.
  127. if resp, _ := client.GetGCPCredential(gcp); resp != nil {
  128. continue
  129. }
  130. // write the data to the vault client
  131. if err := client.WriteGCPCredential(gcp, &credentials.GCPCredential{
  132. GCPKeyData: gcp.GCPKeyData,
  133. }); err != nil {
  134. errors[gcp.ID] = err
  135. fmt.Printf("gcp vault write error on ID %d: %v\n", gcp.ID, err)
  136. }
  137. }
  138. }
  139. fmt.Printf("migrated %d gcp integrations with %d errors\n", count, len(errors))
  140. if shouldFinalize {
  141. saveErrors := make(map[uint]error, 0)
  142. // iterate a second time, clearing the data
  143. // iterate (count / stepSize) + 1 times using Limit and Offset
  144. for i := 0; i < (int(count)/stepSize)+1; i++ {
  145. gcps := []*ints.GCPIntegration{}
  146. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&gcps).Error; err != nil {
  147. return err
  148. }
  149. // decrypt with the old key
  150. for _, gcp := range gcps {
  151. if _, found := errors[gcp.ID]; !found {
  152. // clear the data from the db, and save
  153. gcp.GCPKeyData = []byte{}
  154. if err := db.Save(gcp).Error; err != nil {
  155. saveErrors[gcp.ID] = err
  156. }
  157. }
  158. }
  159. }
  160. fmt.Printf("cleared %d gcp integrations with %d errors\n", count, len(saveErrors))
  161. for saveErrorID, saveError := range saveErrors {
  162. fmt.Printf("gcp save error on ID %d: %v\n", saveErrorID, saveError)
  163. }
  164. }
  165. return nil
  166. }
  167. func migrateAWSIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  168. // get count of model
  169. var count int64
  170. if err := db.Model(&ints.AWSIntegration{}).Count(&count).Error; err != nil {
  171. return err
  172. }
  173. // make a map of ids to errors -- we don't clear the integrations with errors
  174. errors := make(map[uint]error)
  175. // iterate (count / stepSize) + 1 times using Limit and Offset
  176. for i := 0; i < (int(count)/stepSize)+1; i++ {
  177. awss := []*ints.AWSIntegration{}
  178. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&awss).Error; err != nil {
  179. return err
  180. }
  181. // decrypt with the old key
  182. for _, aws := range awss {
  183. // Check if record already exists in vault client. If so, we don't write anything to vault,
  184. // since we don't want to overwrite any data that's been written.
  185. if resp, _ := client.GetAWSCredential(aws); resp != nil {
  186. continue
  187. }
  188. // write the data to the vault client
  189. if err := client.WriteAWSCredential(aws, &credentials.AWSCredential{
  190. AWSAccessKeyID: aws.AWSAccessKeyID,
  191. AWSClusterID: aws.AWSClusterID,
  192. AWSSecretAccessKey: aws.AWSSecretAccessKey,
  193. AWSSessionToken: aws.AWSSessionToken,
  194. }); err != nil {
  195. errors[aws.ID] = err
  196. fmt.Printf("aws vault write error on ID %d: %v\n", aws.ID, err)
  197. }
  198. }
  199. }
  200. fmt.Printf("migrated %d aws integrations with %d errors\n", count, len(errors))
  201. if shouldFinalize {
  202. saveErrors := make(map[uint]error, 0)
  203. // iterate a second time, clearing the data
  204. // iterate (count / stepSize) + 1 times using Limit and Offset
  205. for i := 0; i < (int(count)/stepSize)+1; i++ {
  206. awss := []*ints.AWSIntegration{}
  207. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&awss).Error; err != nil {
  208. return err
  209. }
  210. // decrypt with the old key
  211. for _, aws := range awss {
  212. if _, found := errors[aws.ID]; !found {
  213. // clear the data from the db, and save
  214. aws.AWSAccessKeyID = []byte{}
  215. aws.AWSClusterID = []byte{}
  216. aws.AWSSecretAccessKey = []byte{}
  217. aws.AWSSessionToken = []byte{}
  218. if err := db.Save(aws).Error; err != nil {
  219. saveErrors[aws.ID] = err
  220. }
  221. }
  222. }
  223. }
  224. fmt.Printf("cleared %d aws integrations with %d errors\n", count, len(saveErrors))
  225. for saveErrorID, saveError := range saveErrors {
  226. fmt.Printf("aws save error on ID %d: %v\n", saveErrorID, saveError)
  227. }
  228. }
  229. return nil
  230. }
  231. func migrateGitlabIntegrationModel(db *gorm.DB, client *vault.Client, shouldFinalize bool) error {
  232. // get count of model
  233. var count int64
  234. if err := db.Model(&ints.GitlabIntegration{}).Count(&count).Error; err != nil {
  235. return err
  236. }
  237. // make a map of ids to errors -- we don't clear the integrations with errors
  238. errors := make(map[uint]error)
  239. // iterate (count / stepSize) + 1 times using Limit and Offset
  240. for i := 0; i < (int(count)/stepSize)+1; i++ {
  241. giInts := []*ints.GitlabIntegration{}
  242. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&giInts).Error; err != nil {
  243. return err
  244. }
  245. // decrypt with the old key
  246. for _, gi := range giInts {
  247. // Check if record already exists in vault client. If so, we don't write anything to vault,
  248. // since we don't want to overwrite any data that's been written.
  249. if resp, _ := client.GetGitlabCredential(gi); resp != nil {
  250. continue
  251. }
  252. // write the data to the vault client
  253. if err := client.WriteGitlabCredential(gi, &credentials.GitlabCredential{
  254. AppClientID: gi.AppClientID,
  255. AppClientSecret: gi.AppClientSecret,
  256. }); err != nil {
  257. errors[gi.ID] = err
  258. fmt.Printf("gitlab vault write error on ID %d: %v\n", gi.ID, err)
  259. }
  260. }
  261. }
  262. fmt.Printf("migrated %d gitlab integrations with %d errors\n", count, len(errors))
  263. if shouldFinalize {
  264. saveErrors := make(map[uint]error, 0)
  265. // iterate a second time, clearing the data
  266. // iterate (count / stepSize) + 1 times using Limit and Offset
  267. for i := 0; i < (int(count)/stepSize)+1; i++ {
  268. giInts := []*ints.GitlabIntegration{}
  269. if err := db.Order("id asc").Offset(i * stepSize).Limit(stepSize).Find(&giInts).Error; err != nil {
  270. return err
  271. }
  272. // decrypt with the old key
  273. for _, gi := range giInts {
  274. if _, found := errors[gi.ID]; !found {
  275. // clear the data from the db, and save
  276. gi.AppClientID = []byte{}
  277. gi.AppClientSecret = []byte{}
  278. if err := db.Save(gi).Error; err != nil {
  279. saveErrors[gi.ID] = err
  280. }
  281. }
  282. }
  283. }
  284. fmt.Printf("cleared %d gitlab integrations with %d errors\n", count, len(saveErrors))
  285. for saveErrorID, saveError := range saveErrors {
  286. fmt.Printf("gitlab save error on ID %d: %v\n", saveErrorID, saveError)
  287. }
  288. }
  289. return nil
  290. }