auth.go 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227
  1. package middleware
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "github.com/google/go-github/github"
  8. "golang.org/x/oauth2"
  9. "io/ioutil"
  10. "net/http"
  11. "net/url"
  12. "strconv"
  13. "strings"
  14. "github.com/go-chi/chi"
  15. "github.com/gorilla/sessions"
  16. "github.com/porter-dev/porter/internal/auth/token"
  17. "github.com/porter-dev/porter/internal/models"
  18. "github.com/porter-dev/porter/internal/repository"
  19. )
  20. // Auth implements the authorization functions
  21. type Auth struct {
  22. store sessions.Store
  23. cookieName string
  24. tokenConf *token.TokenGeneratorConf
  25. repo *repository.Repository
  26. GithubProjectConf *oauth2.Config
  27. }
  28. // NewAuth returns a new Auth instance
  29. func NewAuth(
  30. store sessions.Store,
  31. cookieName string,
  32. tokenConf *token.TokenGeneratorConf,
  33. repo *repository.Repository,
  34. GithubProjectConf *oauth2.Config,
  35. ) *Auth {
  36. return &Auth{store, cookieName, tokenConf, repo, GithubProjectConf}
  37. }
  38. // BasicAuthenticate just checks that a user is logged in
  39. func (auth *Auth) BasicAuthenticate(next http.Handler) http.Handler {
  40. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  41. if auth.isLoggedIn(w, r) {
  42. next.ServeHTTP(w, r)
  43. } else {
  44. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  45. return
  46. }
  47. return
  48. })
  49. }
  50. // BasicAuthenticateWithRedirect checks that a user is logged in, and if they're not, the
  51. // user is redirected to the login page with the redirect path stored in the session
  52. func (auth *Auth) BasicAuthenticateWithRedirect(next http.Handler) http.Handler {
  53. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  54. if auth.isLoggedIn(w, r) {
  55. next.ServeHTTP(w, r)
  56. } else {
  57. session, err := auth.store.Get(r, auth.cookieName)
  58. if err != nil {
  59. http.Redirect(w, r, "/dashboard", 302)
  60. }
  61. // need state parameter to validate when redirected
  62. if r.URL.RawQuery == "" {
  63. session.Values["redirect"] = r.URL.Path
  64. } else {
  65. session.Values["redirect"] = r.URL.Path + "?" + r.URL.RawQuery
  66. }
  67. session.Save(r, w)
  68. http.Redirect(w, r, "/dashboard", 302)
  69. return
  70. }
  71. return
  72. })
  73. }
  74. // IDLocation represents the location of the ID to use for authentication
  75. type IDLocation uint
  76. const (
  77. // URLParam location looks for a parameter in the URL endpoint
  78. URLParam IDLocation = iota
  79. // BodyParam location looks for a parameter in the body
  80. BodyParam
  81. // QueryParam location looks for a parameter in the query string
  82. QueryParam
  83. )
  84. type bodyUserID struct {
  85. UserID uint64 `json:"user_id"`
  86. }
  87. type bodyProjectID struct {
  88. ProjectID uint64 `json:"project_id"`
  89. }
  90. type bodyClusterID struct {
  91. ClusterID uint64 `json:"cluster_id"`
  92. }
  93. type bodyRegistryID struct {
  94. RegistryID uint64 `json:"registry_id"`
  95. }
  96. type bodyGitRepoID struct {
  97. GitRepoID uint64 `json:"git_repo_id"`
  98. }
  99. type bodyInfraID struct {
  100. InfraID uint64 `json:"infra_id"`
  101. }
  102. type bodyInviteID struct {
  103. InviteID uint64 `json:"invite_id"`
  104. }
  105. type bodyAWSIntegrationID struct {
  106. AWSIntegrationID uint64 `json:"aws_integration_id"`
  107. }
  108. type bodyGCPIntegrationID struct {
  109. GCPIntegrationID uint64 `json:"gcp_integration_id"`
  110. }
  111. type bodyDOIntegrationID struct {
  112. DOIntegrationID uint64 `json:"do_integration_id"`
  113. }
  114. // DoesUserIDMatch checks the id URL parameter and verifies that it matches
  115. // the one stored in the session
  116. func (auth *Auth) DoesUserIDMatch(next http.Handler, loc IDLocation) http.Handler {
  117. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  118. var err error
  119. id, err := findUserIDInRequest(r, loc)
  120. // first check for token
  121. tok := auth.getTokenFromRequest(r)
  122. if err != nil {
  123. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  124. return
  125. } else if tok != nil && tok.IBy == uint(id) {
  126. next.ServeHTTP(w, r)
  127. return
  128. } else if auth.doesSessionMatchID(r, uint(id)) {
  129. next.ServeHTTP(w, r)
  130. } else {
  131. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  132. return
  133. }
  134. return
  135. })
  136. }
  137. // AccessType represents the various access types for a project
  138. type AccessType string
  139. // The various access types
  140. const (
  141. AdminAccess AccessType = "admin"
  142. ReadAccess AccessType = "read"
  143. WriteAccess AccessType = "write"
  144. )
  145. // DoesUserHaveProjectAccess looks for a project_id parameter and checks that the
  146. // user has access via the specified accessType
  147. func (auth *Auth) DoesUserHaveProjectAccess(
  148. next http.Handler,
  149. projLoc IDLocation,
  150. accessType AccessType,
  151. ) http.Handler {
  152. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  153. var err error
  154. projID, err := findProjIDInRequest(r, projLoc)
  155. if err != nil {
  156. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  157. return
  158. }
  159. // first check for token
  160. tok := auth.getTokenFromRequest(r)
  161. var userID uint
  162. if tok != nil && tok.ProjectID != 0 && tok.ProjectID == uint(projID) {
  163. next.ServeHTTP(w, r)
  164. return
  165. } else if tok != nil {
  166. userID = tok.IBy
  167. } else {
  168. session, err := auth.store.Get(r, auth.cookieName)
  169. if err != nil {
  170. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  171. return
  172. }
  173. sessionUserID, ok := session.Values["user_id"]
  174. userID = sessionUserID.(uint)
  175. if !ok {
  176. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  177. return
  178. }
  179. }
  180. // get the project
  181. proj, err := auth.repo.Project.ReadProject(uint(projID))
  182. if err != nil {
  183. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  184. return
  185. }
  186. // look for the user role in the project
  187. for _, role := range proj.Roles {
  188. if role.UserID == userID {
  189. if accessType == AdminAccess {
  190. if role.Kind == models.RoleAdmin {
  191. next.ServeHTTP(w, r)
  192. return
  193. }
  194. } else if accessType == WriteAccess {
  195. if role.Kind == models.RoleAdmin || role.Kind == models.RoleDeveloper {
  196. next.ServeHTTP(w, r)
  197. return
  198. }
  199. } else if accessType == ReadAccess {
  200. next.ServeHTTP(w, r)
  201. return
  202. }
  203. }
  204. }
  205. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  206. return
  207. })
  208. }
  209. // DoesUserHaveClusterAccess looks for a project_id parameter and a
  210. // cluster_id parameter, and verifies that the cluster belongs
  211. // to the project
  212. func (auth *Auth) DoesUserHaveClusterAccess(
  213. next http.Handler,
  214. projLoc IDLocation,
  215. clusterLoc IDLocation,
  216. ) http.Handler {
  217. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  218. clusterID, err := findClusterIDInRequest(r, clusterLoc)
  219. if err != nil {
  220. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  221. return
  222. }
  223. projID, err := findProjIDInRequest(r, projLoc)
  224. if err != nil {
  225. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  226. return
  227. }
  228. // get the service accounts belonging to the project
  229. clusters, err := auth.repo.Cluster.ListClustersByProjectID(uint(projID))
  230. if err != nil {
  231. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  232. return
  233. }
  234. doesExist := false
  235. for _, cluster := range clusters {
  236. if cluster.ID == uint(clusterID) {
  237. doesExist = true
  238. break
  239. }
  240. }
  241. if doesExist {
  242. next.ServeHTTP(w, r)
  243. return
  244. }
  245. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  246. return
  247. })
  248. }
  249. // DoesUserHaveInviteAccess looks for a project_id parameter and a
  250. // invite_id parameter, and verifies that the invite belongs
  251. // to the project
  252. func (auth *Auth) DoesUserHaveInviteAccess(
  253. next http.Handler,
  254. projLoc IDLocation,
  255. inviteLoc IDLocation,
  256. ) http.Handler {
  257. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  258. inviteID, err := findInviteIDInRequest(r, inviteLoc)
  259. if err != nil {
  260. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  261. return
  262. }
  263. projID, err := findProjIDInRequest(r, projLoc)
  264. if err != nil {
  265. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  266. return
  267. }
  268. // get the service accounts belonging to the project
  269. invites, err := auth.repo.Invite.ListInvitesByProjectID(uint(projID))
  270. if err != nil {
  271. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  272. return
  273. }
  274. doesExist := false
  275. for _, invite := range invites {
  276. if invite.ID == uint(inviteID) {
  277. doesExist = true
  278. break
  279. }
  280. }
  281. if doesExist {
  282. next.ServeHTTP(w, r)
  283. return
  284. }
  285. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  286. return
  287. })
  288. }
  289. // DoesUserHaveRegistryAccess looks for a project_id parameter and a
  290. // registry_id parameter, and verifies that the registry belongs
  291. // to the project
  292. func (auth *Auth) DoesUserHaveRegistryAccess(
  293. next http.Handler,
  294. projLoc IDLocation,
  295. registryLoc IDLocation,
  296. ) http.Handler {
  297. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  298. regID, err := findRegistryIDInRequest(r, registryLoc)
  299. if err != nil {
  300. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  301. return
  302. }
  303. projID, err := findProjIDInRequest(r, projLoc)
  304. if err != nil {
  305. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  306. return
  307. }
  308. // get the service accounts belonging to the project
  309. regs, err := auth.repo.Registry.ListRegistriesByProjectID(uint(projID))
  310. if err != nil {
  311. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  312. return
  313. }
  314. doesExist := false
  315. for _, reg := range regs {
  316. if reg.ID == uint(regID) {
  317. doesExist = true
  318. break
  319. }
  320. }
  321. if doesExist {
  322. next.ServeHTTP(w, r)
  323. return
  324. }
  325. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  326. return
  327. })
  328. }
  329. // DoesUserHaveGitInstallationAccess checks that a user has access to an installation id
  330. // by ensuring the installation id exists for one org or account they have access to
  331. // note that this makes a github API request, but the endpoint is fast so this doesn't add
  332. // much overhead
  333. func (auth *Auth) DoesUserHaveGitInstallationAccess(
  334. next http.Handler,
  335. gitRepoLoc IDLocation,
  336. ) http.Handler {
  337. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  338. // TODO: needs to use new github integration implementation
  339. grID, err := findGitInstallationIDInRequest(r, gitRepoLoc)
  340. if err != nil {
  341. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  342. return
  343. }
  344. tok := auth.getTokenFromRequest(r)
  345. var userID uint
  346. if tok != nil {
  347. userID = tok.IBy
  348. } else {
  349. session, err := auth.store.Get(r, auth.cookieName)
  350. if err != nil {
  351. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  352. return
  353. }
  354. sessionUserID, ok := session.Values["user_id"]
  355. userID = sessionUserID.(uint)
  356. if !ok {
  357. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  358. return
  359. }
  360. }
  361. user, err := auth.repo.User.ReadUser(userID)
  362. if err != nil {
  363. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  364. return
  365. }
  366. oauthInt, err := auth.repo.GithubAppOAuthIntegration.ReadGithubAppOauthIntegration(user.GithubAppIntegrationID)
  367. if err != nil {
  368. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  369. return
  370. }
  371. client := github.NewClient(auth.GithubProjectConf.Client(oauth2.NoContext, &oauth2.Token{
  372. AccessToken: string(oauthInt.AccessToken),
  373. RefreshToken: string(oauthInt.RefreshToken),
  374. TokenType: "Bearer",
  375. }))
  376. accountIDs := make([]int64, 0)
  377. AuthUser, _, err := client.Users.Get(context.Background(), "")
  378. if err != nil {
  379. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  380. return
  381. }
  382. accountIDs = append(accountIDs, *AuthUser.ID)
  383. opts := &github.ListOptions{
  384. PerPage: 100,
  385. Page: 1,
  386. }
  387. for {
  388. orgs, pages, err := client.Organizations.List(context.Background(), "", opts)
  389. if err != nil {
  390. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  391. return
  392. }
  393. for _, org := range orgs {
  394. accountIDs = append(accountIDs, *org.ID)
  395. }
  396. if pages.NextPage == 0 {
  397. break
  398. }
  399. }
  400. installations, err := auth.repo.GithubAppInstallation.ReadGithubAppInstallationByAccountIDs(accountIDs)
  401. for _, installation := range installations {
  402. if uint64(installation.InstallationID) == grID {
  403. next.ServeHTTP(w, r)
  404. return
  405. }
  406. }
  407. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  408. })
  409. }
  410. // DoesUserHaveInfraAccess looks for a project_id parameter and an
  411. // infra_id parameter, and verifies that the infra belongs
  412. // to the project
  413. func (auth *Auth) DoesUserHaveInfraAccess(
  414. next http.Handler,
  415. projLoc IDLocation,
  416. infraLoc IDLocation,
  417. ) http.Handler {
  418. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  419. infraID, err := findInfraIDInRequest(r, infraLoc)
  420. if err != nil {
  421. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  422. return
  423. }
  424. projID, err := findProjIDInRequest(r, projLoc)
  425. if err != nil {
  426. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  427. return
  428. }
  429. infras, err := auth.repo.Infra.ListInfrasByProjectID(uint(projID))
  430. if err != nil {
  431. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  432. return
  433. }
  434. doesExist := false
  435. for _, infra := range infras {
  436. if infra.ID == uint(infraID) {
  437. doesExist = true
  438. break
  439. }
  440. }
  441. if doesExist {
  442. next.ServeHTTP(w, r)
  443. return
  444. }
  445. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  446. return
  447. })
  448. }
  449. // DoesUserHaveAWSIntegrationAccess looks for a project_id parameter and an
  450. // aws_integration_id parameter, and verifies that the infra belongs
  451. // to the project
  452. func (auth *Auth) DoesUserHaveAWSIntegrationAccess(
  453. next http.Handler,
  454. projLoc IDLocation,
  455. awsLoc IDLocation,
  456. optional bool,
  457. ) http.Handler {
  458. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  459. awsID, err := findAWSIntegrationIDInRequest(r, awsLoc)
  460. if awsID == 0 && optional {
  461. next.ServeHTTP(w, r)
  462. return
  463. }
  464. if awsID == 0 || err != nil {
  465. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  466. return
  467. }
  468. projID, err := findProjIDInRequest(r, projLoc)
  469. if err != nil {
  470. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  471. return
  472. }
  473. awsInts, err := auth.repo.AWSIntegration.ListAWSIntegrationsByProjectID(uint(projID))
  474. if err != nil {
  475. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  476. return
  477. }
  478. doesExist := false
  479. for _, awsInt := range awsInts {
  480. if awsInt.ID == uint(awsID) {
  481. doesExist = true
  482. break
  483. }
  484. }
  485. if doesExist {
  486. next.ServeHTTP(w, r)
  487. return
  488. }
  489. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  490. return
  491. })
  492. }
  493. // DoesUserHaveGCPIntegrationAccess looks for a project_id parameter and an
  494. // gcp_integration_id parameter, and verifies that the infra belongs
  495. // to the project
  496. func (auth *Auth) DoesUserHaveGCPIntegrationAccess(
  497. next http.Handler,
  498. projLoc IDLocation,
  499. gcpLoc IDLocation,
  500. optional bool,
  501. ) http.Handler {
  502. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  503. gcpID, err := findGCPIntegrationIDInRequest(r, gcpLoc)
  504. if gcpID == 0 && optional {
  505. next.ServeHTTP(w, r)
  506. return
  507. }
  508. if gcpID == 0 || err != nil {
  509. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  510. return
  511. }
  512. projID, err := findProjIDInRequest(r, projLoc)
  513. if err != nil {
  514. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  515. return
  516. }
  517. gcpInts, err := auth.repo.GCPIntegration.ListGCPIntegrationsByProjectID(uint(projID))
  518. if err != nil {
  519. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  520. return
  521. }
  522. doesExist := false
  523. for _, awsInt := range gcpInts {
  524. if awsInt.ID == uint(gcpID) {
  525. doesExist = true
  526. break
  527. }
  528. }
  529. if doesExist {
  530. next.ServeHTTP(w, r)
  531. return
  532. }
  533. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  534. return
  535. })
  536. }
  537. // DoesUserHaveDOIntegrationAccess looks for a project_id parameter and an
  538. // do_integration_id parameter, and verifies that the infra belongs
  539. // to the project
  540. func (auth *Auth) DoesUserHaveDOIntegrationAccess(
  541. next http.Handler,
  542. projLoc IDLocation,
  543. doLoc IDLocation,
  544. optional bool,
  545. ) http.Handler {
  546. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  547. doID, err := findDOIntegrationIDInRequest(r, doLoc)
  548. if doID == 0 && optional {
  549. next.ServeHTTP(w, r)
  550. return
  551. }
  552. if doID == 0 || err != nil {
  553. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  554. return
  555. }
  556. projID, err := findProjIDInRequest(r, projLoc)
  557. if err != nil {
  558. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  559. return
  560. }
  561. oauthInts, err := auth.repo.OAuthIntegration.ListOAuthIntegrationsByProjectID(uint(projID))
  562. if err != nil {
  563. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  564. return
  565. }
  566. doesExist := false
  567. for _, oauthInt := range oauthInts {
  568. if oauthInt.ID == uint(doID) {
  569. doesExist = true
  570. break
  571. }
  572. }
  573. if doesExist {
  574. next.ServeHTTP(w, r)
  575. return
  576. }
  577. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  578. return
  579. })
  580. }
  581. // Helpers
  582. func (auth *Auth) doesSessionMatchID(r *http.Request, id uint) bool {
  583. session, _ := auth.store.Get(r, auth.cookieName)
  584. if sessID, ok := session.Values["user_id"].(uint); !ok || sessID != id {
  585. return false
  586. }
  587. return true
  588. }
  589. func (auth *Auth) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
  590. // first check for Bearer token
  591. tok := auth.getTokenFromRequest(r)
  592. if tok != nil {
  593. return true
  594. }
  595. session, err := auth.store.Get(r, auth.cookieName)
  596. if err != nil {
  597. session.Values["authenticated"] = false
  598. if err := session.Save(r, w); err != nil {
  599. return false
  600. }
  601. return false
  602. }
  603. if auth, ok := session.Values["authenticated"].(bool); !auth || !ok {
  604. return false
  605. }
  606. return true
  607. }
  608. func (auth *Auth) getTokenFromRequest(r *http.Request) *token.Token {
  609. reqToken := r.Header.Get("Authorization")
  610. splitToken := strings.Split(reqToken, "Bearer")
  611. if len(splitToken) != 2 {
  612. return nil
  613. }
  614. reqToken = strings.TrimSpace(splitToken[1])
  615. tok, err := token.GetTokenFromEncoded(reqToken, auth.tokenConf)
  616. if err != nil {
  617. return nil
  618. }
  619. return tok
  620. }
  621. func findUserIDInRequest(r *http.Request, userLoc IDLocation) (uint64, error) {
  622. var userID uint64
  623. var err error
  624. if userLoc == URLParam {
  625. userID, err = strconv.ParseUint(chi.URLParam(r, "user_id"), 0, 64)
  626. if err != nil {
  627. return 0, err
  628. }
  629. } else if userLoc == BodyParam {
  630. form := &bodyUserID{}
  631. body, err := ioutil.ReadAll(r.Body)
  632. if err != nil {
  633. return 0, err
  634. }
  635. err = json.Unmarshal(body, form)
  636. if err != nil {
  637. return 0, err
  638. }
  639. userID = form.UserID
  640. // need to create a new stream for the body
  641. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  642. } else {
  643. vals, err := url.ParseQuery(r.URL.RawQuery)
  644. if err != nil {
  645. return 0, err
  646. }
  647. if userStrArr, ok := vals["user_id"]; ok && len(userStrArr) == 1 {
  648. userID, err = strconv.ParseUint(userStrArr[0], 10, 64)
  649. } else {
  650. return 0, errors.New("user id not found")
  651. }
  652. }
  653. return userID, nil
  654. }
  655. func findProjIDInRequest(r *http.Request, projLoc IDLocation) (uint64, error) {
  656. var projID uint64
  657. var err error
  658. if projLoc == URLParam {
  659. projID, err = strconv.ParseUint(chi.URLParam(r, "project_id"), 0, 64)
  660. if err != nil {
  661. return 0, err
  662. }
  663. } else if projLoc == BodyParam {
  664. form := &bodyProjectID{}
  665. body, err := ioutil.ReadAll(r.Body)
  666. if err != nil {
  667. return 0, err
  668. }
  669. err = json.Unmarshal(body, form)
  670. if err != nil {
  671. return 0, err
  672. }
  673. projID = form.ProjectID
  674. // need to create a new stream for the body
  675. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  676. } else {
  677. vals, err := url.ParseQuery(r.URL.RawQuery)
  678. if err != nil {
  679. return 0, err
  680. }
  681. if projStrArr, ok := vals["project_id"]; ok && len(projStrArr) == 1 {
  682. projID, err = strconv.ParseUint(projStrArr[0], 10, 64)
  683. } else {
  684. return 0, errors.New("project id not found")
  685. }
  686. }
  687. return projID, nil
  688. }
  689. func findClusterIDInRequest(r *http.Request, clusterLoc IDLocation) (uint64, error) {
  690. var clusterID uint64
  691. var err error
  692. if clusterLoc == URLParam {
  693. clusterID, err = strconv.ParseUint(chi.URLParam(r, "cluster_id"), 0, 64)
  694. if err != nil {
  695. return 0, err
  696. }
  697. } else if clusterLoc == BodyParam {
  698. form := &bodyClusterID{}
  699. body, err := ioutil.ReadAll(r.Body)
  700. if err != nil {
  701. return 0, err
  702. }
  703. err = json.Unmarshal(body, form)
  704. if err != nil {
  705. return 0, err
  706. }
  707. clusterID = form.ClusterID
  708. // need to create a new stream for the body
  709. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  710. } else {
  711. vals, err := url.ParseQuery(r.URL.RawQuery)
  712. if err != nil {
  713. return 0, err
  714. }
  715. if clStrArr, ok := vals["cluster_id"]; ok && len(clStrArr) == 1 {
  716. clusterID, err = strconv.ParseUint(clStrArr[0], 10, 64)
  717. } else {
  718. return 0, errors.New("cluster id not found")
  719. }
  720. }
  721. return clusterID, nil
  722. }
  723. func findInviteIDInRequest(r *http.Request, inviteLoc IDLocation) (uint64, error) {
  724. var inviteID uint64
  725. var err error
  726. if inviteLoc == URLParam {
  727. inviteID, err = strconv.ParseUint(chi.URLParam(r, "invite_id"), 0, 64)
  728. if err != nil {
  729. return 0, err
  730. }
  731. } else if inviteLoc == BodyParam {
  732. form := &bodyInviteID{}
  733. body, err := ioutil.ReadAll(r.Body)
  734. if err != nil {
  735. return 0, err
  736. }
  737. err = json.Unmarshal(body, form)
  738. if err != nil {
  739. return 0, err
  740. }
  741. inviteID = form.InviteID
  742. // need to create a new stream for the body
  743. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  744. } else {
  745. vals, err := url.ParseQuery(r.URL.RawQuery)
  746. if err != nil {
  747. return 0, err
  748. }
  749. if invStrArr, ok := vals["invite_id"]; ok && len(invStrArr) == 1 {
  750. inviteID, err = strconv.ParseUint(invStrArr[0], 10, 64)
  751. } else {
  752. return 0, errors.New("invite id not found")
  753. }
  754. }
  755. return inviteID, nil
  756. }
  757. func findRegistryIDInRequest(r *http.Request, registryLoc IDLocation) (uint64, error) {
  758. var regID uint64
  759. var err error
  760. if registryLoc == URLParam {
  761. regID, err = strconv.ParseUint(chi.URLParam(r, "registry_id"), 0, 64)
  762. if err != nil {
  763. return 0, err
  764. }
  765. } else if registryLoc == BodyParam {
  766. form := &bodyRegistryID{}
  767. body, err := ioutil.ReadAll(r.Body)
  768. if err != nil {
  769. return 0, err
  770. }
  771. err = json.Unmarshal(body, form)
  772. if err != nil {
  773. return 0, err
  774. }
  775. regID = form.RegistryID
  776. // need to create a new stream for the body
  777. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  778. } else {
  779. vals, err := url.ParseQuery(r.URL.RawQuery)
  780. if err != nil {
  781. return 0, err
  782. }
  783. if regStrArr, ok := vals["registry_id"]; ok && len(regStrArr) == 1 {
  784. regID, err = strconv.ParseUint(regStrArr[0], 10, 64)
  785. } else {
  786. return 0, errors.New("registry id not found")
  787. }
  788. }
  789. return regID, nil
  790. }
  791. // findGitInstallationIDInRequest extracts and installation ID from a request
  792. func findGitInstallationIDInRequest(r *http.Request, gitRepoLoc IDLocation) (uint64, error) {
  793. var grID uint64
  794. var err error
  795. if gitRepoLoc == URLParam {
  796. grID, err = strconv.ParseUint(chi.URLParam(r, "installation_id"), 0, 64)
  797. if err != nil {
  798. return 0, err
  799. }
  800. } else if gitRepoLoc == BodyParam {
  801. form := &bodyGitRepoID{}
  802. body, err := ioutil.ReadAll(r.Body)
  803. if err != nil {
  804. return 0, err
  805. }
  806. err = json.Unmarshal(body, form)
  807. if err != nil {
  808. return 0, err
  809. }
  810. grID = form.GitRepoID
  811. // need to create a new stream for the body
  812. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  813. } else {
  814. vals, err := url.ParseQuery(r.URL.RawQuery)
  815. if err != nil {
  816. return 0, err
  817. }
  818. if regStrArr, ok := vals["installation_id"]; ok && len(regStrArr) == 1 {
  819. grID, err = strconv.ParseUint(regStrArr[0], 10, 64)
  820. } else {
  821. return 0, errors.New("git app installation id not found")
  822. }
  823. }
  824. return grID, nil
  825. }
  826. func findInfraIDInRequest(r *http.Request, infraLoc IDLocation) (uint64, error) {
  827. var infraID uint64
  828. var err error
  829. if infraLoc == URLParam {
  830. infraID, err = strconv.ParseUint(chi.URLParam(r, "infra_id"), 0, 64)
  831. if err != nil {
  832. return 0, err
  833. }
  834. } else if infraLoc == BodyParam {
  835. form := &bodyInfraID{}
  836. body, err := ioutil.ReadAll(r.Body)
  837. if err != nil {
  838. return 0, err
  839. }
  840. err = json.Unmarshal(body, form)
  841. if err != nil {
  842. return 0, err
  843. }
  844. infraID = form.InfraID
  845. // need to create a new stream for the body
  846. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  847. } else {
  848. vals, err := url.ParseQuery(r.URL.RawQuery)
  849. if err != nil {
  850. return 0, err
  851. }
  852. if regStrArr, ok := vals["infra_id"]; ok && len(regStrArr) == 1 {
  853. infraID, err = strconv.ParseUint(regStrArr[0], 10, 64)
  854. } else {
  855. return 0, errors.New("infra id not found")
  856. }
  857. }
  858. return infraID, nil
  859. }
  860. func findAWSIntegrationIDInRequest(r *http.Request, awsLoc IDLocation) (uint64, error) {
  861. var awsID uint64
  862. var err error
  863. if awsLoc == URLParam {
  864. awsID, err = strconv.ParseUint(chi.URLParam(r, "aws_integration_id"), 0, 64)
  865. if err != nil {
  866. return 0, err
  867. }
  868. } else if awsLoc == BodyParam {
  869. form := &bodyAWSIntegrationID{}
  870. body, err := ioutil.ReadAll(r.Body)
  871. if err != nil {
  872. return 0, err
  873. }
  874. err = json.Unmarshal(body, form)
  875. if err != nil {
  876. return 0, err
  877. }
  878. awsID = form.AWSIntegrationID
  879. // need to create a new stream for the body
  880. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  881. } else {
  882. vals, err := url.ParseQuery(r.URL.RawQuery)
  883. if err != nil {
  884. return 0, err
  885. }
  886. if regStrArr, ok := vals["aws_integration_id"]; ok && len(regStrArr) == 1 {
  887. awsID, err = strconv.ParseUint(regStrArr[0], 10, 64)
  888. } else {
  889. return 0, errors.New("aws integration id not found")
  890. }
  891. }
  892. return awsID, nil
  893. }
  894. func findGCPIntegrationIDInRequest(r *http.Request, gcpLoc IDLocation) (uint64, error) {
  895. var gcpID uint64
  896. var err error
  897. if gcpLoc == URLParam {
  898. gcpID, err = strconv.ParseUint(chi.URLParam(r, "gcp_integration_id"), 0, 64)
  899. if err != nil {
  900. return 0, err
  901. }
  902. } else if gcpLoc == BodyParam {
  903. form := &bodyGCPIntegrationID{}
  904. body, err := ioutil.ReadAll(r.Body)
  905. if err != nil {
  906. return 0, err
  907. }
  908. err = json.Unmarshal(body, form)
  909. if err != nil {
  910. return 0, err
  911. }
  912. gcpID = form.GCPIntegrationID
  913. // need to create a new stream for the body
  914. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  915. } else {
  916. vals, err := url.ParseQuery(r.URL.RawQuery)
  917. if err != nil {
  918. return 0, err
  919. }
  920. if regStrArr, ok := vals["gcp_integration_id"]; ok && len(regStrArr) == 1 {
  921. gcpID, err = strconv.ParseUint(regStrArr[0], 10, 64)
  922. } else {
  923. return 0, errors.New("gcp integration id not found")
  924. }
  925. }
  926. return gcpID, nil
  927. }
  928. func findDOIntegrationIDInRequest(r *http.Request, doLoc IDLocation) (uint64, error) {
  929. var doID uint64
  930. var err error
  931. if doLoc == URLParam {
  932. doID, err = strconv.ParseUint(chi.URLParam(r, "do_integration_id"), 0, 64)
  933. if err != nil {
  934. return 0, err
  935. }
  936. } else if doLoc == BodyParam {
  937. form := &bodyDOIntegrationID{}
  938. body, err := ioutil.ReadAll(r.Body)
  939. if err != nil {
  940. return 0, err
  941. }
  942. err = json.Unmarshal(body, form)
  943. if err != nil {
  944. return 0, err
  945. }
  946. doID = form.DOIntegrationID
  947. // need to create a new stream for the body
  948. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  949. } else {
  950. vals, err := url.ParseQuery(r.URL.RawQuery)
  951. if err != nil {
  952. return 0, err
  953. }
  954. if regStrArr, ok := vals["do_integration_id"]; ok && len(regStrArr) == 1 {
  955. doID, err = strconv.ParseUint(regStrArr[0], 10, 64)
  956. } else {
  957. return 0, errors.New("do integration id not found")
  958. }
  959. }
  960. return doID, nil
  961. }