auth.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. package middleware
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io/ioutil"
  8. "net/http"
  9. "net/url"
  10. "strconv"
  11. "github.com/go-chi/chi"
  12. "github.com/gorilla/sessions"
  13. "github.com/porter-dev/porter/internal/models"
  14. "github.com/porter-dev/porter/internal/repository"
  15. )
  16. // Auth implements the authorization functions
  17. type Auth struct {
  18. store sessions.Store
  19. cookieName string
  20. repo *repository.Repository
  21. }
  22. // NewAuth returns a new Auth instance
  23. func NewAuth(
  24. store sessions.Store,
  25. cookieName string,
  26. repo *repository.Repository,
  27. ) *Auth {
  28. return &Auth{store, cookieName, repo}
  29. }
  // BasicAuthenticate returns middleware that only checks that the caller is
  // logged in: it forwards to next when isLoggedIn reports true and otherwise
  // responds 403 Forbidden.
  func (auth *Auth) BasicAuthenticate(next http.Handler) http.Handler {
  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  		if !auth.isLoggedIn(w, r) {
  			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  			return
  		}
  		next.ServeHTTP(w, r)
  	})
  }
  // IDLocation selects where in an HTTP request an ID is read from. The
  // find*InRequest helpers treat any value other than URLParam or BodyParam
  // as QueryParam.
  type IDLocation uint

  const (
  	// URLParam reads the ID from a chi URL path parameter.
  	URLParam IDLocation = 0
  	// BodyParam reads the ID from a field of the JSON request body.
  	BodyParam IDLocation = 1
  	// QueryParam reads the ID from the URL query string.
  	QueryParam IDLocation = 2
  )
  // JSON body shapes decoded when an ID is read with BodyParam. Each carries a
  // single unsigned ID under its snake_case JSON key; other body fields are
  // ignored by the decoder.
  type (
  	bodyUserID struct {
  		UserID uint64 `json:"user_id"`
  	}

  	bodyProjectID struct {
  		ProjectID uint64 `json:"project_id"`
  	}

  	bodyClusterID struct {
  		ClusterID uint64 `json:"cluster_id"`
  	}

  	bodyRegistryID struct {
  		RegistryID uint64 `json:"registry_id"`
  	}

  	bodyGitRepoID struct {
  		GitRepoID uint64 `json:"git_repo_id"`
  	}

  	bodyInfraID struct {
  		InfraID uint64 `json:"infra_id"`
  	}
  )
  70. // DoesUserIDMatch checks the id URL parameter and verifies that it matches
  71. // the one stored in the session
  72. func (auth *Auth) DoesUserIDMatch(next http.Handler, loc IDLocation) http.Handler {
  73. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  74. var err error
  75. id, err := findUserIDInRequest(r, loc)
  76. if err == nil && auth.doesSessionMatchID(r, uint(id)) {
  77. next.ServeHTTP(w, r)
  78. } else {
  79. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  80. return
  81. }
  82. return
  83. })
  84. }
  // AccessType names the level of project access a route requires.
  type AccessType string

  // Access levels understood by DoesUserHaveProjectAccess.
  const (
  	// ReadAccess is granted to any user holding a role in the project.
  	ReadAccess = AccessType("read")
  	// WriteAccess is granted only to users whose project role is admin.
  	WriteAccess = AccessType("write")
  )
  92. // DoesUserHaveProjectAccess looks for a project_id parameter and checks that the
  93. // user has access via the specified accessType
  94. func (auth *Auth) DoesUserHaveProjectAccess(
  95. next http.Handler,
  96. projLoc IDLocation,
  97. accessType AccessType,
  98. ) http.Handler {
  99. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  100. var err error
  101. projID, err := findProjIDInRequest(r, projLoc)
  102. if err != nil {
  103. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  104. return
  105. }
  106. session, err := auth.store.Get(r, auth.cookieName)
  107. if err != nil {
  108. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  109. return
  110. }
  111. userID, ok := session.Values["user_id"].(uint)
  112. if !ok {
  113. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  114. return
  115. }
  116. // get the project
  117. proj, err := auth.repo.Project.ReadProject(uint(projID))
  118. if err != nil {
  119. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  120. return
  121. }
  122. // look for the user role in the project
  123. for _, role := range proj.Roles {
  124. if role.UserID == userID {
  125. if accessType == ReadAccess {
  126. next.ServeHTTP(w, r)
  127. return
  128. } else if accessType == WriteAccess {
  129. if role.Kind == models.RoleAdmin {
  130. next.ServeHTTP(w, r)
  131. return
  132. }
  133. }
  134. }
  135. }
  136. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  137. return
  138. })
  139. }
  // DoesUserHaveClusterAccess returns middleware that reads a cluster ID from
  // clusterLoc and a project ID from projLoc, in that order. It forwards the
  // request to next only when the cluster is among the project's clusters.
  //
  // It responds 403 Forbidden when either ID cannot be read or the cluster is
  // not listed under the project, and 500 Internal Server Error when listing
  // the project's clusters fails. It does not inspect the session.
  // NOTE(review): presumably chained after DoesUserHaveProjectAccess — confirm
  // at the router.
  func (auth *Auth) DoesUserHaveClusterAccess(
  	next http.Handler,
  	projLoc IDLocation,
  	clusterLoc IDLocation,
  ) http.Handler {
  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  		forbidden := func() {
  			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  		}

  		clusterID, err := findClusterIDInRequest(r, clusterLoc)
  		if err != nil {
  			forbidden()
  			return
  		}

  		projID, err := findProjIDInRequest(r, projLoc)
  		if err != nil {
  			forbidden()
  			return
  		}

  		// Clusters owned by the project.
  		projectClusters, err := auth.repo.Cluster.ListClustersByProjectID(uint(projID))
  		if err != nil {
  			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  			return
  		}

  		for _, c := range projectClusters {
  			if c.ID == uint(clusterID) {
  				next.ServeHTTP(w, r)
  				return
  			}
  		}

  		forbidden()
  	})
  }
  // DoesUserHaveRegistryAccess returns middleware that reads a registry ID
  // from registryLoc and a project ID from projLoc, in that order. It forwards
  // the request to next only when the registry is among the project's
  // registries.
  //
  // It responds 403 Forbidden when either ID cannot be read or the registry is
  // not listed under the project, and 500 Internal Server Error when listing
  // the project's registries fails. It does not inspect the session.
  // NOTE(review): presumably chained after DoesUserHaveProjectAccess — confirm
  // at the router.
  func (auth *Auth) DoesUserHaveRegistryAccess(
  	next http.Handler,
  	projLoc IDLocation,
  	registryLoc IDLocation,
  ) http.Handler {
  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  		forbidden := func() {
  			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  		}

  		registryID, err := findRegistryIDInRequest(r, registryLoc)
  		if err != nil {
  			forbidden()
  			return
  		}

  		projID, err := findProjIDInRequest(r, projLoc)
  		if err != nil {
  			forbidden()
  			return
  		}

  		// Registries owned by the project.
  		projectRegs, err := auth.repo.Registry.ListRegistriesByProjectID(uint(projID))
  		if err != nil {
  			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  			return
  		}

  		for _, reg := range projectRegs {
  			if reg.ID == uint(registryID) {
  				next.ServeHTTP(w, r)
  				return
  			}
  		}

  		forbidden()
  	})
  }
  // DoesUserHaveGitRepoAccess returns middleware that reads a git repo ID from
  // gitRepoLoc and a project ID from projLoc, in that order. It forwards the
  // request to next only when the git repo is among the project's git repos.
  //
  // It responds 403 Forbidden when either ID cannot be read or the git repo is
  // not listed under the project, and 500 Internal Server Error when listing
  // the project's git repos fails. It does not inspect the session.
  // NOTE(review): presumably chained after DoesUserHaveProjectAccess — confirm
  // at the router.
  func (auth *Auth) DoesUserHaveGitRepoAccess(
  	next http.Handler,
  	projLoc IDLocation,
  	gitRepoLoc IDLocation,
  ) http.Handler {
  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  		forbidden := func() {
  			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  		}

  		gitRepoID, err := findGitRepoIDInRequest(r, gitRepoLoc)
  		if err != nil {
  			forbidden()
  			return
  		}

  		projID, err := findProjIDInRequest(r, projLoc)
  		if err != nil {
  			forbidden()
  			return
  		}

  		// Git repos owned by the project.
  		projectRepos, err := auth.repo.GitRepo.ListGitReposByProjectID(uint(projID))
  		if err != nil {
  			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  			return
  		}

  		for _, repo := range projectRepos {
  			if repo.ID == uint(gitRepoID) {
  				next.ServeHTTP(w, r)
  				return
  			}
  		}

  		forbidden()
  	})
  }
  // DoesUserHaveInfraAccess returns middleware that reads an infra ID (the
  // "infra_id" parameter) from infraLoc and a project ID from projLoc, in that
  // order. It forwards the request to next only when the infra is among the
  // project's AWS infras.
  //
  // It responds 403 Forbidden when either ID cannot be read or the infra is not
  // listed under the project, and 500 Internal Server Error when listing the
  // project's infras fails. It does not inspect the session.
  // NOTE(review): presumably chained after DoesUserHaveProjectAccess — confirm
  // at the router.
  func (auth *Auth) DoesUserHaveInfraAccess(
  	next http.Handler,
  	projLoc IDLocation,
  	infraLoc IDLocation,
  ) http.Handler {
  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  		forbidden := func() {
  			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  		}

  		// The infra ID lives under "infra_id", so it must be read with the
  		// infra-specific helper rather than the git repo one.
  		infraID, err := findInfraIDInRequest(r, infraLoc)
  		if err != nil {
  			forbidden()
  			return
  		}

  		projID, err := findProjIDInRequest(r, projLoc)
  		if err != nil {
  			forbidden()
  			return
  		}

  		// AWS infras owned by the project.
  		infras, err := auth.repo.AWSInfra.ListAWSInfrasByProjectID(uint(projID))
  		if err != nil {
  			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  			return
  		}

  		for _, infra := range infras {
  			if infra.ID == uint(infraID) {
  				next.ServeHTTP(w, r)
  				return
  			}
  		}

  		forbidden()
  	})
  }
  // Helpers

  // doesSessionMatchID reports whether the session stored under
  // auth.cookieName has a uint "user_id" equal to id. It returns false when
  // the session cannot be loaded, so a store error never grants access.
  func (auth *Auth) doesSessionMatchID(r *http.Request, id uint) bool {
  	session, err := auth.store.Get(r, auth.cookieName)
  	if err != nil || session == nil {
  		// Fail closed: do not trust a session the store reported an error for,
  		// and never dereference a nil one.
  		return false
  	}
  	sessID, ok := session.Values["user_id"].(uint)
  	return ok && sessID == id
  }
  // isLoggedIn reports whether the caller's session has "authenticated" set to
  // the bool true.
  //
  // When the store returns an error, it marks the returned session (if any) as
  // unauthenticated and saves it to w. A save failure is printed and otherwise
  // ignored, and the function reports false.
  func (auth *Auth) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
  	session, err := auth.store.Get(r, auth.cookieName)
  	if err != nil {
  		// The sessions.Store interface does not promise a non-nil session
  		// alongside an error; only reset and persist one we actually have.
  		if session != nil {
  			session.Values["authenticated"] = false
  			if saveErr := session.Save(r, w); saveErr != nil {
  				fmt.Println("error while saving session in isLoggedIn", saveErr)
  			}
  		}
  		return false
  	}
  	if session == nil {
  		return false
  	}

  	// Named isAuthed (not "auth") so the receiver is not shadowed.
  	isAuthed, ok := session.Values["authenticated"].(bool)
  	return ok && isAuthed
  }
  // findUserIDInRequest extracts the "user_id" value from r according to
  // userLoc:
  //   - URLParam: the chi URL parameter "user_id";
  //   - BodyParam: the "user_id" field of a JSON request body. r.Body is
  //     replaced with a fresh reader over the same bytes so later handlers
  //     can read it again. An absent field decodes to 0 with a nil error;
  //   - any other value: the "user_id" query parameter, which must appear
  //     exactly once.
  //
  // URL and query values must be base-10 unsigned integers. On any failure it
  // returns 0 and a non-nil error.
  func findUserIDInRequest(r *http.Request, userLoc IDLocation) (uint64, error) {
  	var raw string

  	switch userLoc {
  	case URLParam:
  		raw = chi.URLParam(r, "user_id")
  	case BodyParam:
  		body, err := ioutil.ReadAll(r.Body)
  		if err != nil {
  			return 0, err
  		}
  		// Restore the stream before decoding so the body stays readable
  		// downstream even if decoding fails.
  		r.Body = ioutil.NopCloser(bytes.NewReader(body))

  		form := &bodyUserID{}
  		if err = json.Unmarshal(body, form); err != nil {
  			return 0, err
  		}
  		return form.UserID, nil
  	default:
  		vals, err := url.ParseQuery(r.URL.RawQuery)
  		if err != nil {
  			return 0, err
  		}
  		userStrArr, ok := vals["user_id"]
  		if !ok || len(userStrArr) != 1 {
  			return 0, errors.New("user id not found")
  		}
  		raw = userStrArr[0]
  	}

  	// Base 10 for both URL and query values: base 0 would read "010" as octal
  	// 8 and accept "0x"/"0b" prefixes and "_" separators. The parse error is
  	// always checked so a malformed value never turns into ID 0.
  	userID, err := strconv.ParseUint(raw, 10, 64)
  	if err != nil {
  		return 0, err
  	}
  	return userID, nil
  }
  // findProjIDInRequest extracts the "project_id" value from r according to
  // projLoc:
  //   - URLParam: the chi URL parameter "project_id";
  //   - BodyParam: the "project_id" field of a JSON request body. r.Body is
  //     replaced with a fresh reader over the same bytes so later handlers
  //     can read it again. An absent field decodes to 0 with a nil error;
  //   - any other value: the "project_id" query parameter, which must appear
  //     exactly once.
  //
  // URL and query values must be base-10 unsigned integers. On any failure it
  // returns 0 and a non-nil error.
  func findProjIDInRequest(r *http.Request, projLoc IDLocation) (uint64, error) {
  	var raw string

  	switch projLoc {
  	case URLParam:
  		raw = chi.URLParam(r, "project_id")
  	case BodyParam:
  		body, err := ioutil.ReadAll(r.Body)
  		if err != nil {
  			return 0, err
  		}
  		// Restore the stream before decoding so the body stays readable
  		// downstream even if decoding fails.
  		r.Body = ioutil.NopCloser(bytes.NewReader(body))

  		form := &bodyProjectID{}
  		if err = json.Unmarshal(body, form); err != nil {
  			return 0, err
  		}
  		return form.ProjectID, nil
  	default:
  		vals, err := url.ParseQuery(r.URL.RawQuery)
  		if err != nil {
  			return 0, err
  		}
  		projStrArr, ok := vals["project_id"]
  		if !ok || len(projStrArr) != 1 {
  			return 0, errors.New("project id not found")
  		}
  		raw = projStrArr[0]
  	}

  	// Base 10 for both URL and query values: base 0 would read "010" as octal
  	// 8 and accept "0x"/"0b" prefixes and "_" separators. The parse error is
  	// always checked so a malformed value never turns into ID 0.
  	projID, err := strconv.ParseUint(raw, 10, 64)
  	if err != nil {
  		return 0, err
  	}
  	return projID, nil
  }
  // findClusterIDInRequest extracts the "cluster_id" value from r according to
  // clusterLoc:
  //   - URLParam: the chi URL parameter "cluster_id";
  //   - BodyParam: the "cluster_id" field of a JSON request body. r.Body is
  //     replaced with a fresh reader over the same bytes so later handlers
  //     can read it again. An absent field decodes to 0 with a nil error;
  //   - any other value: the "cluster_id" query parameter, which must appear
  //     exactly once.
  //
  // URL and query values must be base-10 unsigned integers. On any failure it
  // returns 0 and a non-nil error.
  func findClusterIDInRequest(r *http.Request, clusterLoc IDLocation) (uint64, error) {
  	var raw string

  	switch clusterLoc {
  	case URLParam:
  		raw = chi.URLParam(r, "cluster_id")
  	case BodyParam:
  		body, err := ioutil.ReadAll(r.Body)
  		if err != nil {
  			return 0, err
  		}
  		// Restore the stream before decoding so the body stays readable
  		// downstream even if decoding fails.
  		r.Body = ioutil.NopCloser(bytes.NewReader(body))

  		form := &bodyClusterID{}
  		if err = json.Unmarshal(body, form); err != nil {
  			return 0, err
  		}
  		return form.ClusterID, nil
  	default:
  		vals, err := url.ParseQuery(r.URL.RawQuery)
  		if err != nil {
  			return 0, err
  		}
  		clStrArr, ok := vals["cluster_id"]
  		if !ok || len(clStrArr) != 1 {
  			return 0, errors.New("cluster id not found")
  		}
  		raw = clStrArr[0]
  	}

  	// Base 10 for both URL and query values: base 0 would read "010" as octal
  	// 8 and accept "0x"/"0b" prefixes and "_" separators. The parse error is
  	// always checked so a malformed value never turns into ID 0.
  	clusterID, err := strconv.ParseUint(raw, 10, 64)
  	if err != nil {
  		return 0, err
  	}
  	return clusterID, nil
  }
  // findRegistryIDInRequest extracts the "registry_id" value from r according
  // to registryLoc:
  //   - URLParam: the chi URL parameter "registry_id";
  //   - BodyParam: the "registry_id" field of a JSON request body. r.Body is
  //     replaced with a fresh reader over the same bytes so later handlers
  //     can read it again. An absent field decodes to 0 with a nil error;
  //   - any other value: the "registry_id" query parameter, which must appear
  //     exactly once.
  //
  // URL and query values must be base-10 unsigned integers. On any failure it
  // returns 0 and a non-nil error.
  func findRegistryIDInRequest(r *http.Request, registryLoc IDLocation) (uint64, error) {
  	var raw string

  	switch registryLoc {
  	case URLParam:
  		raw = chi.URLParam(r, "registry_id")
  	case BodyParam:
  		body, err := ioutil.ReadAll(r.Body)
  		if err != nil {
  			return 0, err
  		}
  		// Restore the stream before decoding so the body stays readable
  		// downstream even if decoding fails.
  		r.Body = ioutil.NopCloser(bytes.NewReader(body))

  		form := &bodyRegistryID{}
  		if err = json.Unmarshal(body, form); err != nil {
  			return 0, err
  		}
  		return form.RegistryID, nil
  	default:
  		vals, err := url.ParseQuery(r.URL.RawQuery)
  		if err != nil {
  			return 0, err
  		}
  		regStrArr, ok := vals["registry_id"]
  		if !ok || len(regStrArr) != 1 {
  			return 0, errors.New("registry id not found")
  		}
  		raw = regStrArr[0]
  	}

  	// Base 10 for both URL and query values: base 0 would read "010" as octal
  	// 8 and accept "0x"/"0b" prefixes and "_" separators. The parse error is
  	// always checked so a malformed value never turns into ID 0.
  	regID, err := strconv.ParseUint(raw, 10, 64)
  	if err != nil {
  		return 0, err
  	}
  	return regID, nil
  }
  // findGitRepoIDInRequest extracts the "git_repo_id" value from r according
  // to gitRepoLoc:
  //   - URLParam: the chi URL parameter "git_repo_id";
  //   - BodyParam: the "git_repo_id" field of a JSON request body. r.Body is
  //     replaced with a fresh reader over the same bytes so later handlers
  //     can read it again. An absent field decodes to 0 with a nil error;
  //   - any other value: the "git_repo_id" query parameter, which must appear
  //     exactly once.
  //
  // URL and query values must be base-10 unsigned integers. On any failure it
  // returns 0 and a non-nil error.
  func findGitRepoIDInRequest(r *http.Request, gitRepoLoc IDLocation) (uint64, error) {
  	var raw string

  	switch gitRepoLoc {
  	case URLParam:
  		raw = chi.URLParam(r, "git_repo_id")
  	case BodyParam:
  		body, err := ioutil.ReadAll(r.Body)
  		if err != nil {
  			return 0, err
  		}
  		// Restore the stream before decoding so the body stays readable
  		// downstream even if decoding fails.
  		r.Body = ioutil.NopCloser(bytes.NewReader(body))

  		form := &bodyGitRepoID{}
  		if err = json.Unmarshal(body, form); err != nil {
  			return 0, err
  		}
  		return form.GitRepoID, nil
  	default:
  		vals, err := url.ParseQuery(r.URL.RawQuery)
  		if err != nil {
  			return 0, err
  		}
  		grStrArr, ok := vals["git_repo_id"]
  		if !ok || len(grStrArr) != 1 {
  			return 0, errors.New("git repo id not found")
  		}
  		raw = grStrArr[0]
  	}

  	// Base 10 for both URL and query values: base 0 would read "010" as octal
  	// 8 and accept "0x"/"0b" prefixes and "_" separators. The parse error is
  	// always checked so a malformed value never turns into ID 0.
  	grID, err := strconv.ParseUint(raw, 10, 64)
  	if err != nil {
  		return 0, err
  	}
  	return grID, nil
  }
  // findInfraIDInRequest extracts the "infra_id" value from r according to
  // infraLoc:
  //   - URLParam: the chi URL parameter "infra_id";
  //   - BodyParam: the "infra_id" field of a JSON request body. r.Body is
  //     replaced with a fresh reader over the same bytes so later handlers
  //     can read it again. An absent field decodes to 0 with a nil error;
  //   - any other value: the "infra_id" query parameter, which must appear
  //     exactly once.
  //
  // URL and query values must be base-10 unsigned integers. On any failure it
  // returns 0 and a non-nil error.
  func findInfraIDInRequest(r *http.Request, infraLoc IDLocation) (uint64, error) {
  	var raw string

  	switch infraLoc {
  	case URLParam:
  		raw = chi.URLParam(r, "infra_id")
  	case BodyParam:
  		body, err := ioutil.ReadAll(r.Body)
  		if err != nil {
  			return 0, err
  		}
  		// Restore the stream before decoding so the body stays readable
  		// downstream even if decoding fails.
  		r.Body = ioutil.NopCloser(bytes.NewReader(body))

  		form := &bodyInfraID{}
  		if err = json.Unmarshal(body, form); err != nil {
  			return 0, err
  		}
  		return form.InfraID, nil
  	default:
  		vals, err := url.ParseQuery(r.URL.RawQuery)
  		if err != nil {
  			return 0, err
  		}
  		infraStrArr, ok := vals["infra_id"]
  		if !ok || len(infraStrArr) != 1 {
  			return 0, errors.New("infra id not found")
  		}
  		raw = infraStrArr[0]
  	}

  	// Base 10 for both URL and query values: base 0 would read "010" as octal
  	// 8 and accept "0x"/"0b" prefixes and "_" separators. The parse error is
  	// always checked so a malformed value never turns into ID 0.
  	infraID, err := strconv.ParseUint(raw, 10, 64)
  	if err != nil {
  		return 0, err
  	}
  	return infraID, nil
  }