get_token.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. package registry
  2. import (
  3. "encoding/base64"
  4. "fmt"
  5. "net/http"
  6. "strings"
  7. "time"
  8. "github.com/aws/aws-sdk-go/service/ecr"
  9. "github.com/bufbuild/connect-go"
  10. porterv1 "github.com/porter-dev/api-contracts/generated/go/porter/v1"
  11. "github.com/porter-dev/porter/api/server/handlers"
  12. "github.com/porter-dev/porter/api/server/shared"
  13. "github.com/porter-dev/porter/api/server/shared/apierrors"
  14. "github.com/porter-dev/porter/api/server/shared/config"
  15. "github.com/porter-dev/porter/api/types"
  16. "github.com/porter-dev/porter/internal/models"
  17. "github.com/porter-dev/porter/internal/oauth"
  18. "github.com/porter-dev/porter/internal/registry"
  19. "github.com/porter-dev/porter/internal/telemetry"
  20. "github.com/aws/aws-sdk-go/aws/arn"
  21. )
  22. type RegistryGetECRTokenHandler struct {
  23. handlers.PorterHandlerReadWriter
  24. }
  25. func NewRegistryGetECRTokenHandler(
  26. config *config.Config,
  27. decoderValidator shared.RequestDecoderValidator,
  28. writer shared.ResultWriter,
  29. ) *RegistryGetECRTokenHandler {
  30. return &RegistryGetECRTokenHandler{
  31. PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer),
  32. }
  33. }
  34. func (c *RegistryGetECRTokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  35. ctx := r.Context()
  36. proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
  37. request := &types.GetRegistryECRTokenRequest{}
  38. if ok := c.DecodeAndValidate(w, r, request); !ok {
  39. return
  40. }
  41. if proj.CapiProvisionerEnabled {
  42. ecrRequest := porterv1.ECRTokenForRegistryRequest{
  43. ProjectId: int64(proj.ID),
  44. Region: request.Region,
  45. AwsAccountId: request.AccountID,
  46. }
  47. ecrResponse, err := c.Config().ClusterControlPlaneClient.ECRTokenForRegistry(ctx, connect.NewRequest(&ecrRequest))
  48. if err != nil {
  49. e := fmt.Errorf("error getting ecr token for capi cluster: %v", err)
  50. c.HandleAPIError(w, r, apierrors.NewErrInternal(e))
  51. return
  52. }
  53. if ecrResponse.Msg == nil {
  54. c.HandleAPIError(w, r, apierrors.NewErrInternal(fmt.Errorf("nil message received for ecr token")))
  55. return
  56. }
  57. expiry := ecrResponse.Msg.Expiry.AsTime()
  58. resp := &types.GetRegistryTokenResponse{
  59. Token: ecrResponse.Msg.Token,
  60. ExpiresAt: &expiry,
  61. }
  62. c.WriteResult(w, r, resp)
  63. return
  64. }
  65. // list registries and find one that matches the region
  66. regs, err := c.Repo().Registry().ListRegistriesByProjectID(proj.ID)
  67. if err != nil {
  68. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  69. return
  70. }
  71. var token string
  72. var expiresAt *time.Time
  73. for _, reg := range regs {
  74. if reg.AWSIntegrationID != 0 {
  75. awsInt, err := c.Repo().AWSIntegration().ReadAWSIntegration(reg.ProjectID, reg.AWSIntegrationID)
  76. if err != nil {
  77. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  78. return
  79. }
  80. // if the aws integration doesn't have an ARN populated, populate it
  81. if awsInt.AWSArn == "" {
  82. err = awsInt.PopulateAWSArn()
  83. if err != nil {
  84. continue
  85. }
  86. }
  87. parsedARN, err := arn.Parse(awsInt.AWSArn)
  88. if err != nil {
  89. continue
  90. }
  91. // if the account id is passed as part of the request, verify the account id matches the account id in the ARN
  92. if awsInt.AWSRegion == request.Region && (request.AccountID == "" || request.AccountID == parsedARN.AccountID) {
  93. // get the aws integration and session
  94. sess, err := awsInt.GetSession()
  95. if err != nil {
  96. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  97. return
  98. }
  99. ecrSvc := ecr.New(sess)
  100. output, err := ecrSvc.GetAuthorizationToken(&ecr.GetAuthorizationTokenInput{})
  101. if err != nil {
  102. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  103. return
  104. }
  105. token = *output.AuthorizationData[0].AuthorizationToken
  106. expiresAt = output.AuthorizationData[0].ExpiresAt
  107. }
  108. }
  109. }
  110. resp := &types.GetRegistryTokenResponse{
  111. Token: token,
  112. ExpiresAt: expiresAt,
  113. }
  114. c.WriteResult(w, r, resp)
  115. }
  116. type RegistryGetGCRTokenHandler struct {
  117. handlers.PorterHandlerReadWriter
  118. }
  119. func NewRegistryGetGCRTokenHandler(
  120. config *config.Config,
  121. decoderValidator shared.RequestDecoderValidator,
  122. writer shared.ResultWriter,
  123. ) *RegistryGetGCRTokenHandler {
  124. return &RegistryGetGCRTokenHandler{
  125. PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer),
  126. }
  127. }
  128. func (c *RegistryGetGCRTokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  129. ctx, span := telemetry.NewSpan(r.Context(), "serve-registry-get-gcr-token")
  130. defer span.End()
  131. proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
  132. request := &types.GetRegistryGCRTokenRequest{}
  133. if ok := c.DecodeAndValidate(w, r, request); !ok {
  134. return
  135. }
  136. // list registries and find one that matches the region
  137. regs, err := c.Repo().Registry().ListRegistriesByProjectID(proj.ID)
  138. if err != nil {
  139. e := telemetry.Error(ctx, span, err, "error listing registries by project id")
  140. c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(e, http.StatusInternalServerError))
  141. return
  142. }
  143. var token string
  144. var expiresAt *time.Time
  145. for _, reg := range regs {
  146. if reg.GCPIntegrationID != 0 && strings.Contains(reg.URL, request.ServerURL) {
  147. _reg := registry.Registry(*reg)
  148. oauthTok, err := _reg.GetGCRToken(ctx, c.Repo())
  149. if err != nil {
  150. // if the oauth token is not nil, we still return the token but log an error
  151. if oauthTok == nil {
  152. e := telemetry.Error(ctx, span, err, "error getting gcr token")
  153. c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(e, http.StatusInternalServerError))
  154. return
  155. }
  156. e := telemetry.Error(ctx, span, err, "error getting gcr token, but token was returned")
  157. c.HandleAPIErrorNoWrite(w, r, apierrors.NewErrInternal(e))
  158. }
  159. token = oauthTok.AccessToken
  160. expiresAt = &oauthTok.Expiry
  161. break
  162. }
  163. }
  164. resp := &types.GetRegistryTokenResponse{
  165. Token: token,
  166. ExpiresAt: expiresAt,
  167. }
  168. c.WriteResult(w, r, resp)
  169. }
  170. type RegistryGetGARTokenHandler struct {
  171. handlers.PorterHandlerReadWriter
  172. }
  173. func NewRegistryGetGARTokenHandler(
  174. config *config.Config,
  175. decoderValidator shared.RequestDecoderValidator,
  176. writer shared.ResultWriter,
  177. ) *RegistryGetGARTokenHandler {
  178. return &RegistryGetGARTokenHandler{
  179. PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer),
  180. }
  181. }
  182. func (c *RegistryGetGARTokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  183. ctx, span := telemetry.NewSpan(r.Context(), "serve-registry-get-gar-token")
  184. defer span.End()
  185. proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
  186. request := &types.GetRegistryGCRTokenRequest{}
  187. if ok := c.DecodeAndValidate(w, r, request); !ok {
  188. return
  189. }
  190. // list registries and find one that matches the region
  191. regs, err := c.Repo().Registry().ListRegistriesByProjectID(proj.ID)
  192. if err != nil {
  193. e := telemetry.Error(ctx, span, err, "error listing registries by project id")
  194. c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(e, http.StatusInternalServerError))
  195. return
  196. }
  197. var token string
  198. var expiresAt *time.Time
  199. for _, reg := range regs {
  200. if reg.GCPIntegrationID != 0 && strings.Contains(reg.URL, request.ServerURL) {
  201. _reg := registry.Registry(*reg)
  202. oauthTok, err := _reg.GetGARToken(ctx, c.Repo())
  203. if err != nil {
  204. // if the oauth token is not nil, we still return the token but log an error
  205. if oauthTok == nil {
  206. e := telemetry.Error(ctx, span, err, "error getting gar token")
  207. c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(e, http.StatusInternalServerError))
  208. return
  209. }
  210. e := telemetry.Error(ctx, span, err, "error getting gar token, but token was returned")
  211. c.HandleAPIErrorNoWrite(w, r, apierrors.NewErrInternal(e))
  212. }
  213. token = oauthTok.AccessToken
  214. expiresAt = &oauthTok.Expiry
  215. break
  216. }
  217. }
  218. resp := &types.GetRegistryTokenResponse{
  219. Token: token,
  220. ExpiresAt: expiresAt,
  221. }
  222. c.WriteResult(w, r, resp)
  223. }
  224. type RegistryGetDOCRTokenHandler struct {
  225. handlers.PorterHandlerReadWriter
  226. }
  227. func NewRegistryGetDOCRTokenHandler(
  228. config *config.Config,
  229. decoderValidator shared.RequestDecoderValidator,
  230. writer shared.ResultWriter,
  231. ) *RegistryGetDOCRTokenHandler {
  232. return &RegistryGetDOCRTokenHandler{
  233. PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer),
  234. }
  235. }
  236. func (c *RegistryGetDOCRTokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  237. proj, _ := r.Context().Value(types.ProjectScope).(*models.Project)
  238. request := &types.GetRegistryDOCRTokenRequest{}
  239. if ok := c.DecodeAndValidate(w, r, request); !ok {
  240. return
  241. }
  242. // list registries and find one that matches the region
  243. regs, err := c.Repo().Registry().ListRegistriesByProjectID(proj.ID)
  244. if err != nil {
  245. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  246. return
  247. }
  248. var token string
  249. var expiresAt *time.Time
  250. for _, reg := range regs {
  251. if reg.DOIntegrationID != 0 && strings.Contains(reg.URL, request.ServerURL) {
  252. oauthInt, err := c.Repo().OAuthIntegration().ReadOAuthIntegration(reg.ProjectID, reg.DOIntegrationID)
  253. if err != nil {
  254. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  255. return
  256. }
  257. tok, expiry, err := oauth.GetAccessToken(
  258. oauthInt.SharedOAuthModel,
  259. c.Config().DOConf,
  260. oauth.MakeUpdateOAuthIntegrationTokenFunction(oauthInt, c.Repo()),
  261. )
  262. if err != nil {
  263. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  264. return
  265. }
  266. token = tok
  267. expiresAt = expiry
  268. break
  269. }
  270. }
  271. resp := &types.GetRegistryTokenResponse{
  272. Token: token,
  273. ExpiresAt: expiresAt,
  274. }
  275. c.WriteResult(w, r, resp)
  276. }
  277. type RegistryGetDockerhubTokenHandler struct {
  278. handlers.PorterHandlerReadWriter
  279. }
  280. func NewRegistryGetDockerhubTokenHandler(
  281. config *config.Config,
  282. decoderValidator shared.RequestDecoderValidator,
  283. writer shared.ResultWriter,
  284. ) *RegistryGetDockerhubTokenHandler {
  285. return &RegistryGetDockerhubTokenHandler{
  286. PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer),
  287. }
  288. }
  289. func (c *RegistryGetDockerhubTokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  290. proj, _ := r.Context().Value(types.ProjectScope).(*models.Project)
  291. // list registries and find one that matches the region
  292. regs, err := c.Repo().Registry().ListRegistriesByProjectID(proj.ID)
  293. if err != nil {
  294. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  295. return
  296. }
  297. var token string
  298. var expiresAt *time.Time
  299. for _, reg := range regs {
  300. if reg.BasicIntegrationID != 0 && strings.Contains(reg.URL, "index.docker.io") {
  301. basic, err := c.Repo().BasicIntegration().ReadBasicIntegration(reg.ProjectID, reg.BasicIntegrationID)
  302. if err != nil {
  303. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  304. return
  305. }
  306. token = base64.StdEncoding.EncodeToString([]byte(string(basic.Username) + ":" + string(basic.Password)))
  307. // we'll just set an arbitrary 30-day expiry time (this is not enforced)
  308. timeExpires := time.Now().Add(30 * 24 * 3600 * time.Second)
  309. expiresAt = &timeExpires
  310. }
  311. }
  312. resp := &types.GetRegistryTokenResponse{
  313. Token: token,
  314. ExpiresAt: expiresAt,
  315. }
  316. c.WriteResult(w, r, resp)
  317. }
  318. type RegistryGetACRTokenHandler struct {
  319. handlers.PorterHandlerReadWriter
  320. }
  321. func NewRegistryGetACRTokenHandler(
  322. config *config.Config,
  323. decoderValidator shared.RequestDecoderValidator,
  324. writer shared.ResultWriter,
  325. ) *RegistryGetACRTokenHandler {
  326. return &RegistryGetACRTokenHandler{
  327. PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer),
  328. }
  329. }
  330. func (c *RegistryGetACRTokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  331. proj, _ := r.Context().Value(types.ProjectScope).(*models.Project)
  332. // list registries and find one that matches the region
  333. regs, err := c.Repo().Registry().ListRegistriesByProjectID(proj.ID)
  334. if err != nil {
  335. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  336. return
  337. }
  338. var token string
  339. var expiresAt *time.Time
  340. for _, reg := range regs {
  341. if reg.AzureIntegrationID != 0 && strings.Contains(reg.URL, "azurecr.io") {
  342. _reg := registry.Registry(*reg)
  343. username, pw, err := _reg.GetACRCredentials(c.Repo())
  344. if err != nil {
  345. c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
  346. continue
  347. }
  348. token = base64.StdEncoding.EncodeToString([]byte(string(username) + ":" + string(pw)))
  349. // we'll just set an arbitrary 30-day expiry time (this is not enforced)
  350. timeExpires := time.Now().Add(30 * 24 * 3600 * time.Second)
  351. expiresAt = &timeExpires
  352. }
  353. }
  354. resp := &types.GetRegistryTokenResponse{
  355. Token: token,
  356. ExpiresAt: expiresAt,
  357. }
  358. c.WriteResult(w, r, resp)
  359. }