auth.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. package middleware
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io/ioutil"
  8. "net/http"
  9. "net/url"
  10. "strconv"
  11. "github.com/go-chi/chi"
  12. "github.com/gorilla/sessions"
  13. "github.com/porter-dev/porter/internal/models"
  14. "github.com/porter-dev/porter/internal/repository"
  15. )
  16. // Auth implements the authorization functions
  17. type Auth struct {
  18. store sessions.Store
  19. cookieName string
  20. repo *repository.Repository
  21. }
  22. // NewAuth returns a new Auth instance
  23. func NewAuth(
  24. store sessions.Store,
  25. cookieName string,
  26. repo *repository.Repository,
  27. ) *Auth {
  28. return &Auth{store, cookieName, repo}
  29. }
  30. // BasicAuthenticate just checks that a user is logged in
  31. func (auth *Auth) BasicAuthenticate(next http.Handler) http.Handler {
  32. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  33. if auth.isLoggedIn(w, r) {
  34. next.ServeHTTP(w, r)
  35. } else {
  36. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  37. return
  38. }
  39. return
  40. })
  41. }
  42. // IDLocation represents the location of the ID to use for authentication
  43. type IDLocation uint
  44. const (
  45. // URLParam location looks for a parameter in the URL endpoint
  46. URLParam IDLocation = iota
  47. // BodyParam location looks for a parameter in the body
  48. BodyParam
  49. // QueryParam location looks for a parameter in the query string
  50. QueryParam
  51. )
  52. type bodyUserID struct {
  53. UserID uint64 `json:"user_id"`
  54. }
  55. type bodyProjectID struct {
  56. ProjectID uint64 `json:"project_id"`
  57. }
  58. type bodyClusterID struct {
  59. ClusterID uint64 `json:"cluster_id"`
  60. }
  61. type bodyRegistryID struct {
  62. RegistryID uint64 `json:"registry_id"`
  63. }
  64. // DoesUserIDMatch checks the id URL parameter and verifies that it matches
  65. // the one stored in the session
  66. func (auth *Auth) DoesUserIDMatch(next http.Handler, loc IDLocation) http.Handler {
  67. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  68. var err error
  69. id, err := findUserIDInRequest(r, loc)
  70. if err == nil && auth.doesSessionMatchID(r, uint(id)) {
  71. next.ServeHTTP(w, r)
  72. } else {
  73. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  74. return
  75. }
  76. return
  77. })
  78. }
  79. // AccessType represents the various access types for a project
  80. type AccessType string
  81. // The various access types
  82. const (
  83. ReadAccess AccessType = "read"
  84. WriteAccess AccessType = "write"
  85. )
  86. // DoesUserHaveProjectAccess looks for a project_id parameter and checks that the
  87. // user has access via the specified accessType
  88. func (auth *Auth) DoesUserHaveProjectAccess(
  89. next http.Handler,
  90. projLoc IDLocation,
  91. accessType AccessType,
  92. ) http.Handler {
  93. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  94. var err error
  95. projID, err := findProjIDInRequest(r, projLoc)
  96. if err != nil {
  97. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  98. return
  99. }
  100. session, err := auth.store.Get(r, auth.cookieName)
  101. if err != nil {
  102. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  103. return
  104. }
  105. userID, ok := session.Values["user_id"].(uint)
  106. if !ok {
  107. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  108. return
  109. }
  110. // get the project
  111. proj, err := auth.repo.Project.ReadProject(uint(projID))
  112. if err != nil {
  113. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  114. return
  115. }
  116. // look for the user role in the project
  117. for _, role := range proj.Roles {
  118. if role.UserID == userID {
  119. if accessType == ReadAccess {
  120. next.ServeHTTP(w, r)
  121. return
  122. } else if accessType == WriteAccess {
  123. if role.Kind == models.RoleAdmin {
  124. next.ServeHTTP(w, r)
  125. return
  126. }
  127. }
  128. }
  129. }
  130. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  131. return
  132. })
  133. }
  134. // DoesUserHaveClusterAccess looks for a project_id parameter and a
  135. // cluster_id parameter, and verifies that the cluster belongs
  136. // to the project
  137. func (auth *Auth) DoesUserHaveClusterAccess(
  138. next http.Handler,
  139. projLoc IDLocation,
  140. clusterLoc IDLocation,
  141. ) http.Handler {
  142. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  143. clusterID, err := findClusterIDInRequest(r, clusterLoc)
  144. if err != nil {
  145. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  146. return
  147. }
  148. projID, err := findProjIDInRequest(r, projLoc)
  149. if err != nil {
  150. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  151. return
  152. }
  153. // get the service accounts belonging to the project
  154. clusters, err := auth.repo.Cluster.ListClustersByProjectID(uint(projID))
  155. if err != nil {
  156. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  157. return
  158. }
  159. doesExist := false
  160. for _, cluster := range clusters {
  161. if cluster.ID == uint(clusterID) {
  162. doesExist = true
  163. break
  164. }
  165. }
  166. if doesExist {
  167. next.ServeHTTP(w, r)
  168. return
  169. }
  170. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  171. return
  172. })
  173. }
  174. // DoesUserHaveRegistryAccess looks for a project_id parameter and a
  175. // registry_id parameter, and verifies that the registry belongs
  176. // to the project
  177. func (auth *Auth) DoesUserHaveRegistryAccess(
  178. next http.Handler,
  179. projLoc IDLocation,
  180. registryLoc IDLocation,
  181. ) http.Handler {
  182. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  183. regID, err := findRegistryIDInRequest(r, registryLoc)
  184. if err != nil {
  185. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  186. return
  187. }
  188. projID, err := findProjIDInRequest(r, projLoc)
  189. if err != nil {
  190. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  191. return
  192. }
  193. // get the service accounts belonging to the project
  194. regs, err := auth.repo.Registry.ListRegistriesByProjectID(uint(projID))
  195. if err != nil {
  196. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  197. return
  198. }
  199. doesExist := false
  200. for _, reg := range regs {
  201. if reg.ID == uint(regID) {
  202. doesExist = true
  203. break
  204. }
  205. }
  206. if doesExist {
  207. next.ServeHTTP(w, r)
  208. return
  209. }
  210. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  211. return
  212. })
  213. }
  214. // Helpers
  215. func (auth *Auth) doesSessionMatchID(r *http.Request, id uint) bool {
  216. session, _ := auth.store.Get(r, auth.cookieName)
  217. if sessID, ok := session.Values["user_id"].(uint); !ok || sessID != id {
  218. return false
  219. }
  220. return true
  221. }
  222. func (auth *Auth) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
  223. session, err := auth.store.Get(r, auth.cookieName)
  224. if err != nil {
  225. session.Values["authenticated"] = false
  226. if err := session.Save(r, w); err != nil {
  227. fmt.Println("error while saving session in isLoggedIn", err)
  228. }
  229. return false
  230. }
  231. if auth, ok := session.Values["authenticated"].(bool); !auth || !ok {
  232. return false
  233. }
  234. return true
  235. }
  236. func findUserIDInRequest(r *http.Request, userLoc IDLocation) (uint64, error) {
  237. var userID uint64
  238. var err error
  239. if userLoc == URLParam {
  240. userID, err = strconv.ParseUint(chi.URLParam(r, "user_id"), 0, 64)
  241. if err != nil {
  242. return 0, err
  243. }
  244. } else if userLoc == BodyParam {
  245. form := &bodyUserID{}
  246. body, err := ioutil.ReadAll(r.Body)
  247. if err != nil {
  248. return 0, err
  249. }
  250. err = json.Unmarshal(body, form)
  251. if err != nil {
  252. return 0, err
  253. }
  254. userID = form.UserID
  255. // need to create a new stream for the body
  256. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  257. } else {
  258. vals, err := url.ParseQuery(r.URL.RawQuery)
  259. if err != nil {
  260. return 0, err
  261. }
  262. if userStrArr, ok := vals["user_id"]; ok && len(userStrArr) == 1 {
  263. userID, err = strconv.ParseUint(userStrArr[0], 10, 64)
  264. } else {
  265. return 0, errors.New("user id not found")
  266. }
  267. }
  268. return userID, nil
  269. }
  270. func findProjIDInRequest(r *http.Request, projLoc IDLocation) (uint64, error) {
  271. var projID uint64
  272. var err error
  273. if projLoc == URLParam {
  274. projID, err = strconv.ParseUint(chi.URLParam(r, "project_id"), 0, 64)
  275. if err != nil {
  276. return 0, err
  277. }
  278. } else if projLoc == BodyParam {
  279. form := &bodyProjectID{}
  280. body, err := ioutil.ReadAll(r.Body)
  281. if err != nil {
  282. return 0, err
  283. }
  284. err = json.Unmarshal(body, form)
  285. if err != nil {
  286. return 0, err
  287. }
  288. projID = form.ProjectID
  289. // need to create a new stream for the body
  290. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  291. } else {
  292. vals, err := url.ParseQuery(r.URL.RawQuery)
  293. if err != nil {
  294. return 0, err
  295. }
  296. if projStrArr, ok := vals["project_id"]; ok && len(projStrArr) == 1 {
  297. projID, err = strconv.ParseUint(projStrArr[0], 10, 64)
  298. } else {
  299. return 0, errors.New("project id not found")
  300. }
  301. }
  302. return projID, nil
  303. }
  304. func findClusterIDInRequest(r *http.Request, clusterLoc IDLocation) (uint64, error) {
  305. var clusterID uint64
  306. var err error
  307. if clusterLoc == URLParam {
  308. clusterID, err = strconv.ParseUint(chi.URLParam(r, "cluster_id"), 0, 64)
  309. if err != nil {
  310. return 0, err
  311. }
  312. } else if clusterLoc == BodyParam {
  313. form := &bodyClusterID{}
  314. body, err := ioutil.ReadAll(r.Body)
  315. if err != nil {
  316. return 0, err
  317. }
  318. err = json.Unmarshal(body, form)
  319. if err != nil {
  320. return 0, err
  321. }
  322. clusterID = form.ClusterID
  323. // need to create a new stream for the body
  324. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  325. } else {
  326. vals, err := url.ParseQuery(r.URL.RawQuery)
  327. if err != nil {
  328. return 0, err
  329. }
  330. if clStrArr, ok := vals["cluster_id"]; ok && len(clStrArr) == 1 {
  331. clusterID, err = strconv.ParseUint(clStrArr[0], 10, 64)
  332. } else {
  333. return 0, errors.New("cluster id not found")
  334. }
  335. }
  336. return clusterID, nil
  337. }
  338. func findRegistryIDInRequest(r *http.Request, registryLoc IDLocation) (uint64, error) {
  339. var regID uint64
  340. var err error
  341. if registryLoc == URLParam {
  342. regID, err = strconv.ParseUint(chi.URLParam(r, "registry_id"), 0, 64)
  343. if err != nil {
  344. return 0, err
  345. }
  346. } else if registryLoc == BodyParam {
  347. form := &bodyRegistryID{}
  348. body, err := ioutil.ReadAll(r.Body)
  349. if err != nil {
  350. return 0, err
  351. }
  352. err = json.Unmarshal(body, form)
  353. if err != nil {
  354. return 0, err
  355. }
  356. regID = form.RegistryID
  357. // need to create a new stream for the body
  358. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  359. } else {
  360. vals, err := url.ParseQuery(r.URL.RawQuery)
  361. if err != nil {
  362. return 0, err
  363. }
  364. if regStrArr, ok := vals["registry_id"]; ok && len(regStrArr) == 1 {
  365. regID, err = strconv.ParseUint(regStrArr[0], 10, 64)
  366. } else {
  367. return 0, errors.New("registry id not found")
  368. }
  369. }
  370. return regID, nil
  371. }