2
0

auth.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. package middleware
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io/ioutil"
  8. "net/http"
  9. "net/url"
  10. "strconv"
  11. "github.com/go-chi/chi"
  12. "github.com/gorilla/sessions"
  13. "github.com/porter-dev/porter/internal/models"
  14. "github.com/porter-dev/porter/internal/repository"
  15. )
  16. // Auth implements the authorization functions
  17. type Auth struct {
  18. store sessions.Store
  19. cookieName string
  20. repo *repository.Repository
  21. }
  22. // NewAuth returns a new Auth instance
  23. func NewAuth(
  24. store sessions.Store,
  25. cookieName string,
  26. repo *repository.Repository,
  27. ) *Auth {
  28. return &Auth{store, cookieName, repo}
  29. }
  30. // BasicAuthenticate just checks that a user is logged in
  31. func (auth *Auth) BasicAuthenticate(next http.Handler) http.Handler {
  32. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  33. if auth.isLoggedIn(w, r) {
  34. next.ServeHTTP(w, r)
  35. } else {
  36. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  37. return
  38. }
  39. return
  40. })
  41. }
  // IDLocation tells an access check where in the incoming request to look
  // for the ID it needs.
  type IDLocation uint

  // The supported ID locations. Their numeric values are spelled out
  // explicitly so they stay stable.
  const (
  	// URLParam means the ID is a chi URL parameter of the matched route.
  	URLParam IDLocation = 0
  	// BodyParam means the ID is a field of the JSON request body.
  	BodyParam IDLocation = 1
  	// QueryParam means the ID is a query-string parameter.
  	QueryParam IDLocation = 2
  )
  // Minimal JSON shapes used to pull a single ID out of a request body.
  // Any other fields present in the body are ignored when decoding into them.
  type (
  	bodyUserID struct {
  		UserID uint64 `json:"user_id"`
  	}

  	bodyProjectID struct {
  		ProjectID uint64 `json:"project_id"`
  	}

  	bodyClusterID struct {
  		ClusterID uint64 `json:"cluster_id"`
  	}

  	bodyRegistryID struct {
  		RegistryID uint64 `json:"registry_id"`
  	}

  	bodyGitRepoID struct {
  		GitRepoID uint64 `json:"git_repo_id"`
  	}
  )
  67. // DoesUserIDMatch checks the id URL parameter and verifies that it matches
  68. // the one stored in the session
  69. func (auth *Auth) DoesUserIDMatch(next http.Handler, loc IDLocation) http.Handler {
  70. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  71. var err error
  72. id, err := findUserIDInRequest(r, loc)
  73. if err == nil && auth.doesSessionMatchID(r, uint(id)) {
  74. next.ServeHTTP(w, r)
  75. } else {
  76. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  77. return
  78. }
  79. return
  80. })
  81. }
  // AccessType names the level of project access an endpoint requires.
  type AccessType string

  // ReadAccess is satisfied by any role the user holds in the project.
  const ReadAccess AccessType = "read"

  // WriteAccess requires an admin role in the project.
  const WriteAccess AccessType = "write"
  89. // DoesUserHaveProjectAccess looks for a project_id parameter and checks that the
  90. // user has access via the specified accessType
  91. func (auth *Auth) DoesUserHaveProjectAccess(
  92. next http.Handler,
  93. projLoc IDLocation,
  94. accessType AccessType,
  95. ) http.Handler {
  96. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  97. var err error
  98. projID, err := findProjIDInRequest(r, projLoc)
  99. if err != nil {
  100. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  101. return
  102. }
  103. session, err := auth.store.Get(r, auth.cookieName)
  104. if err != nil {
  105. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  106. return
  107. }
  108. userID, ok := session.Values["user_id"].(uint)
  109. if !ok {
  110. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  111. return
  112. }
  113. // get the project
  114. proj, err := auth.repo.Project.ReadProject(uint(projID))
  115. if err != nil {
  116. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  117. return
  118. }
  119. // look for the user role in the project
  120. for _, role := range proj.Roles {
  121. if role.UserID == userID {
  122. if accessType == ReadAccess {
  123. next.ServeHTTP(w, r)
  124. return
  125. } else if accessType == WriteAccess {
  126. if role.Kind == models.RoleAdmin {
  127. next.ServeHTTP(w, r)
  128. return
  129. }
  130. }
  131. }
  132. }
  133. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  134. return
  135. })
  136. }
  137. // DoesUserHaveClusterAccess looks for a project_id parameter and a
  138. // cluster_id parameter, and verifies that the cluster belongs
  139. // to the project
  140. func (auth *Auth) DoesUserHaveClusterAccess(
  141. next http.Handler,
  142. projLoc IDLocation,
  143. clusterLoc IDLocation,
  144. ) http.Handler {
  145. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  146. clusterID, err := findClusterIDInRequest(r, clusterLoc)
  147. if err != nil {
  148. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  149. return
  150. }
  151. projID, err := findProjIDInRequest(r, projLoc)
  152. if err != nil {
  153. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  154. return
  155. }
  156. // get the service accounts belonging to the project
  157. clusters, err := auth.repo.Cluster.ListClustersByProjectID(uint(projID))
  158. if err != nil {
  159. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  160. return
  161. }
  162. doesExist := false
  163. for _, cluster := range clusters {
  164. if cluster.ID == uint(clusterID) {
  165. doesExist = true
  166. break
  167. }
  168. }
  169. if doesExist {
  170. next.ServeHTTP(w, r)
  171. return
  172. }
  173. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  174. return
  175. })
  176. }
  177. // DoesUserHaveRegistryAccess looks for a project_id parameter and a
  178. // registry_id parameter, and verifies that the registry belongs
  179. // to the project
  180. func (auth *Auth) DoesUserHaveRegistryAccess(
  181. next http.Handler,
  182. projLoc IDLocation,
  183. registryLoc IDLocation,
  184. ) http.Handler {
  185. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  186. regID, err := findRegistryIDInRequest(r, registryLoc)
  187. if err != nil {
  188. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  189. return
  190. }
  191. projID, err := findProjIDInRequest(r, projLoc)
  192. if err != nil {
  193. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  194. return
  195. }
  196. // get the service accounts belonging to the project
  197. regs, err := auth.repo.Registry.ListRegistriesByProjectID(uint(projID))
  198. if err != nil {
  199. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  200. return
  201. }
  202. doesExist := false
  203. for _, reg := range regs {
  204. if reg.ID == uint(regID) {
  205. doesExist = true
  206. break
  207. }
  208. }
  209. if doesExist {
  210. next.ServeHTTP(w, r)
  211. return
  212. }
  213. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  214. return
  215. })
  216. }
  217. // DoesUserHaveGitRepoAccess looks for a project_id parameter and a
  218. // git_repo_id parameter, and verifies that the git repo belongs
  219. // to the project
  220. func (auth *Auth) DoesUserHaveGitRepoAccess(
  221. next http.Handler,
  222. projLoc IDLocation,
  223. gitRepoLoc IDLocation,
  224. ) http.Handler {
  225. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  226. grID, err := findGitRepoIDInRequest(r, gitRepoLoc)
  227. if err != nil {
  228. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  229. return
  230. }
  231. projID, err := findProjIDInRequest(r, projLoc)
  232. if err != nil {
  233. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  234. return
  235. }
  236. // get the service accounts belonging to the project
  237. grs, err := auth.repo.GitRepo.ListGitReposByProjectID(uint(projID))
  238. if err != nil {
  239. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  240. return
  241. }
  242. doesExist := false
  243. for _, gr := range grs {
  244. if gr.ID == uint(grID) {
  245. doesExist = true
  246. break
  247. }
  248. }
  249. if doesExist {
  250. next.ServeHTTP(w, r)
  251. return
  252. }
  253. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  254. return
  255. })
  256. }
  257. // Helpers
  258. func (auth *Auth) doesSessionMatchID(r *http.Request, id uint) bool {
  259. session, _ := auth.store.Get(r, auth.cookieName)
  260. if sessID, ok := session.Values["user_id"].(uint); !ok || sessID != id {
  261. return false
  262. }
  263. return true
  264. }
  265. func (auth *Auth) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
  266. session, err := auth.store.Get(r, auth.cookieName)
  267. if err != nil {
  268. session.Values["authenticated"] = false
  269. if err := session.Save(r, w); err != nil {
  270. fmt.Println("error while saving session in isLoggedIn", err)
  271. }
  272. return false
  273. }
  274. if auth, ok := session.Values["authenticated"].(bool); !auth || !ok {
  275. return false
  276. }
  277. return true
  278. }
  279. func findUserIDInRequest(r *http.Request, userLoc IDLocation) (uint64, error) {
  280. var userID uint64
  281. var err error
  282. if userLoc == URLParam {
  283. userID, err = strconv.ParseUint(chi.URLParam(r, "user_id"), 0, 64)
  284. if err != nil {
  285. return 0, err
  286. }
  287. } else if userLoc == BodyParam {
  288. form := &bodyUserID{}
  289. body, err := ioutil.ReadAll(r.Body)
  290. if err != nil {
  291. return 0, err
  292. }
  293. err = json.Unmarshal(body, form)
  294. if err != nil {
  295. return 0, err
  296. }
  297. userID = form.UserID
  298. // need to create a new stream for the body
  299. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  300. } else {
  301. vals, err := url.ParseQuery(r.URL.RawQuery)
  302. if err != nil {
  303. return 0, err
  304. }
  305. if userStrArr, ok := vals["user_id"]; ok && len(userStrArr) == 1 {
  306. userID, err = strconv.ParseUint(userStrArr[0], 10, 64)
  307. } else {
  308. return 0, errors.New("user id not found")
  309. }
  310. }
  311. return userID, nil
  312. }
  313. func findProjIDInRequest(r *http.Request, projLoc IDLocation) (uint64, error) {
  314. var projID uint64
  315. var err error
  316. if projLoc == URLParam {
  317. projID, err = strconv.ParseUint(chi.URLParam(r, "project_id"), 0, 64)
  318. if err != nil {
  319. return 0, err
  320. }
  321. } else if projLoc == BodyParam {
  322. form := &bodyProjectID{}
  323. body, err := ioutil.ReadAll(r.Body)
  324. if err != nil {
  325. return 0, err
  326. }
  327. err = json.Unmarshal(body, form)
  328. if err != nil {
  329. return 0, err
  330. }
  331. projID = form.ProjectID
  332. // need to create a new stream for the body
  333. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  334. } else {
  335. vals, err := url.ParseQuery(r.URL.RawQuery)
  336. if err != nil {
  337. return 0, err
  338. }
  339. if projStrArr, ok := vals["project_id"]; ok && len(projStrArr) == 1 {
  340. projID, err = strconv.ParseUint(projStrArr[0], 10, 64)
  341. } else {
  342. return 0, errors.New("project id not found")
  343. }
  344. }
  345. return projID, nil
  346. }
  347. func findClusterIDInRequest(r *http.Request, clusterLoc IDLocation) (uint64, error) {
  348. var clusterID uint64
  349. var err error
  350. if clusterLoc == URLParam {
  351. clusterID, err = strconv.ParseUint(chi.URLParam(r, "cluster_id"), 0, 64)
  352. if err != nil {
  353. return 0, err
  354. }
  355. } else if clusterLoc == BodyParam {
  356. form := &bodyClusterID{}
  357. body, err := ioutil.ReadAll(r.Body)
  358. if err != nil {
  359. return 0, err
  360. }
  361. err = json.Unmarshal(body, form)
  362. if err != nil {
  363. return 0, err
  364. }
  365. clusterID = form.ClusterID
  366. // need to create a new stream for the body
  367. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  368. } else {
  369. vals, err := url.ParseQuery(r.URL.RawQuery)
  370. if err != nil {
  371. return 0, err
  372. }
  373. if clStrArr, ok := vals["cluster_id"]; ok && len(clStrArr) == 1 {
  374. clusterID, err = strconv.ParseUint(clStrArr[0], 10, 64)
  375. } else {
  376. return 0, errors.New("cluster id not found")
  377. }
  378. }
  379. return clusterID, nil
  380. }
  381. func findRegistryIDInRequest(r *http.Request, registryLoc IDLocation) (uint64, error) {
  382. var regID uint64
  383. var err error
  384. if registryLoc == URLParam {
  385. regID, err = strconv.ParseUint(chi.URLParam(r, "registry_id"), 0, 64)
  386. if err != nil {
  387. return 0, err
  388. }
  389. } else if registryLoc == BodyParam {
  390. form := &bodyRegistryID{}
  391. body, err := ioutil.ReadAll(r.Body)
  392. if err != nil {
  393. return 0, err
  394. }
  395. err = json.Unmarshal(body, form)
  396. if err != nil {
  397. return 0, err
  398. }
  399. regID = form.RegistryID
  400. // need to create a new stream for the body
  401. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  402. } else {
  403. vals, err := url.ParseQuery(r.URL.RawQuery)
  404. if err != nil {
  405. return 0, err
  406. }
  407. if regStrArr, ok := vals["registry_id"]; ok && len(regStrArr) == 1 {
  408. regID, err = strconv.ParseUint(regStrArr[0], 10, 64)
  409. } else {
  410. return 0, errors.New("registry id not found")
  411. }
  412. }
  413. return regID, nil
  414. }
  415. func findGitRepoIDInRequest(r *http.Request, gitRepoLoc IDLocation) (uint64, error) {
  416. var grID uint64
  417. var err error
  418. if gitRepoLoc == URLParam {
  419. grID, err = strconv.ParseUint(chi.URLParam(r, "git_repo_id"), 0, 64)
  420. if err != nil {
  421. return 0, err
  422. }
  423. } else if gitRepoLoc == BodyParam {
  424. form := &bodyGitRepoID{}
  425. body, err := ioutil.ReadAll(r.Body)
  426. if err != nil {
  427. return 0, err
  428. }
  429. err = json.Unmarshal(body, form)
  430. if err != nil {
  431. return 0, err
  432. }
  433. grID = form.GitRepoID
  434. // need to create a new stream for the body
  435. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  436. } else {
  437. vals, err := url.ParseQuery(r.URL.RawQuery)
  438. if err != nil {
  439. return 0, err
  440. }
  441. if regStrArr, ok := vals["git_repo_id"]; ok && len(regStrArr) == 1 {
  442. grID, err = strconv.ParseUint(regStrArr[0], 10, 64)
  443. } else {
  444. return 0, errors.New("git repo id not found")
  445. }
  446. }
  447. return grID, nil
  448. }