auth.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. package middleware
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io/ioutil"
  8. "net/http"
  9. "net/url"
  10. "strconv"
  11. "github.com/go-chi/chi"
  12. "github.com/gorilla/sessions"
  13. "github.com/porter-dev/porter/internal/models"
  14. "github.com/porter-dev/porter/internal/repository"
  15. )
  16. // Auth implements the authorization functions
  17. type Auth struct {
  18. store sessions.Store
  19. cookieName string
  20. repo *repository.Repository
  21. }
  22. // NewAuth returns a new Auth instance
  23. func NewAuth(
  24. store sessions.Store,
  25. cookieName string,
  26. repo *repository.Repository,
  27. ) *Auth {
  28. return &Auth{store, cookieName, repo}
  29. }
  30. // BasicAuthenticate just checks that a user is logged in
  31. func (auth *Auth) BasicAuthenticate(next http.Handler) http.Handler {
  32. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  33. if auth.isLoggedIn(w, r) {
  34. next.ServeHTTP(w, r)
  35. } else {
  36. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  37. return
  38. }
  39. return
  40. })
  41. }
  42. // IDLocation represents the location of the ID to use for authentication
  43. type IDLocation uint
  44. const (
  45. // URLParam location looks for a parameter in the URL endpoint
  46. URLParam IDLocation = iota
  47. // BodyParam location looks for a parameter in the body
  48. BodyParam
  49. // QueryParam location looks for a parameter in the query string
  50. QueryParam
  51. )
  52. type bodyUserID struct {
  53. UserID uint64 `json:"user_id"`
  54. }
  55. type bodyProjectID struct {
  56. ProjectID uint64 `json:"project_id"`
  57. }
  58. type bodyServiceAccountID struct {
  59. ServiceAccountID uint64 `json:"service_account_id"`
  60. }
  61. // DoesUserIDMatch checks the id URL parameter and verifies that it matches
  62. // the one stored in the session
  63. func (auth *Auth) DoesUserIDMatch(next http.Handler, loc IDLocation) http.Handler {
  64. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  65. var err error
  66. id, err := findUserIDInRequest(r, loc)
  67. if err == nil && auth.doesSessionMatchID(r, uint(id)) {
  68. next.ServeHTTP(w, r)
  69. } else {
  70. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  71. return
  72. }
  73. return
  74. })
  75. }
  76. // AccessType represents the various access types for a project
  77. type AccessType string
  78. // The various access types
  79. const (
  80. ReadAccess AccessType = "read"
  81. WriteAccess AccessType = "write"
  82. )
  83. // DoesUserHaveProjectAccess looks for a project_id parameter and checks that the
  84. // user has access via the specified accessType
  85. func (auth *Auth) DoesUserHaveProjectAccess(
  86. next http.Handler,
  87. projLoc IDLocation,
  88. accessType AccessType,
  89. ) http.Handler {
  90. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  91. var err error
  92. projID, err := findProjIDInRequest(r, projLoc)
  93. if err != nil {
  94. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  95. return
  96. }
  97. session, err := auth.store.Get(r, auth.cookieName)
  98. if err != nil {
  99. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  100. return
  101. }
  102. userID, ok := session.Values["user_id"].(uint)
  103. if !ok {
  104. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  105. return
  106. }
  107. // get the project
  108. proj, err := auth.repo.Project.ReadProject(uint(projID))
  109. if err != nil {
  110. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  111. return
  112. }
  113. // look for the user role in the project
  114. for _, role := range proj.Roles {
  115. if role.UserID == userID {
  116. if accessType == ReadAccess {
  117. next.ServeHTTP(w, r)
  118. return
  119. } else if accessType == WriteAccess {
  120. if role.Kind == models.RoleAdmin {
  121. next.ServeHTTP(w, r)
  122. return
  123. }
  124. }
  125. }
  126. }
  127. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  128. return
  129. })
  130. }
  131. // DoesUserHaveServiceAccountAccess looks for a project_id parameter and a
  132. // service_account_id parameter, and verifies that the service account belongs
  133. // to the project
  134. func (auth *Auth) DoesUserHaveServiceAccountAccess(
  135. next http.Handler,
  136. projLoc IDLocation,
  137. saLoc IDLocation,
  138. ) http.Handler {
  139. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  140. serviceAccountID, err := findServiceAccountIDInRequest(r, saLoc)
  141. if err != nil {
  142. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  143. return
  144. }
  145. projID, err := findProjIDInRequest(r, projLoc)
  146. if err != nil {
  147. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  148. return
  149. }
  150. // get the service accounts belonging to the project
  151. serviceAccounts, err := auth.repo.ServiceAccount.ListServiceAccountsByProjectID(uint(projID))
  152. if err != nil {
  153. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  154. return
  155. }
  156. doesExist := false
  157. for _, sa := range serviceAccounts {
  158. if sa.ID == uint(serviceAccountID) {
  159. doesExist = true
  160. break
  161. }
  162. }
  163. if doesExist {
  164. next.ServeHTTP(w, r)
  165. return
  166. }
  167. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  168. return
  169. })
  170. }
  171. // Helpers
  172. func (auth *Auth) doesSessionMatchID(r *http.Request, id uint) bool {
  173. session, _ := auth.store.Get(r, auth.cookieName)
  174. if sessID, ok := session.Values["user_id"].(uint); !ok || sessID != id {
  175. return false
  176. }
  177. return true
  178. }
  179. func (auth *Auth) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
  180. session, err := auth.store.Get(r, auth.cookieName)
  181. if err != nil {
  182. session.Values["authenticated"] = false
  183. if err := session.Save(r, w); err != nil {
  184. fmt.Println("error while saving session in isLoggedIn", err)
  185. }
  186. return false
  187. }
  188. if auth, ok := session.Values["authenticated"].(bool); !auth || !ok {
  189. return false
  190. }
  191. return true
  192. }
  193. func findUserIDInRequest(r *http.Request, userLoc IDLocation) (uint64, error) {
  194. var userID uint64
  195. var err error
  196. if userLoc == URLParam {
  197. userID, err = strconv.ParseUint(chi.URLParam(r, "user_id"), 0, 64)
  198. if err != nil {
  199. return 0, err
  200. }
  201. } else if userLoc == BodyParam {
  202. form := &bodyUserID{}
  203. body, err := ioutil.ReadAll(r.Body)
  204. if err != nil {
  205. return 0, err
  206. }
  207. err = json.Unmarshal(body, form)
  208. if err != nil {
  209. return 0, err
  210. }
  211. userID = form.UserID
  212. // need to create a new stream for the body
  213. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  214. } else {
  215. vals, err := url.ParseQuery(r.URL.RawQuery)
  216. if err != nil {
  217. return 0, err
  218. }
  219. if userStrArr, ok := vals["user_id"]; ok && len(userStrArr) == 1 {
  220. userID, err = strconv.ParseUint(userStrArr[0], 10, 64)
  221. } else {
  222. return 0, errors.New("user id not found")
  223. }
  224. }
  225. return userID, nil
  226. }
  227. func findProjIDInRequest(r *http.Request, projLoc IDLocation) (uint64, error) {
  228. var projID uint64
  229. var err error
  230. if projLoc == URLParam {
  231. projID, err = strconv.ParseUint(chi.URLParam(r, "project_id"), 0, 64)
  232. if err != nil {
  233. return 0, err
  234. }
  235. } else if projLoc == BodyParam {
  236. form := &bodyProjectID{}
  237. body, err := ioutil.ReadAll(r.Body)
  238. if err != nil {
  239. return 0, err
  240. }
  241. err = json.Unmarshal(body, form)
  242. if err != nil {
  243. return 0, err
  244. }
  245. projID = form.ProjectID
  246. // need to create a new stream for the body
  247. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  248. } else {
  249. vals, err := url.ParseQuery(r.URL.RawQuery)
  250. if err != nil {
  251. return 0, err
  252. }
  253. if projStrArr, ok := vals["project_id"]; ok && len(projStrArr) == 1 {
  254. projID, err = strconv.ParseUint(projStrArr[0], 10, 64)
  255. } else {
  256. return 0, errors.New("project id not found")
  257. }
  258. }
  259. return projID, nil
  260. }
  261. func findServiceAccountIDInRequest(r *http.Request, saLoc IDLocation) (uint64, error) {
  262. var saID uint64
  263. var err error
  264. if saLoc == URLParam {
  265. saID, err = strconv.ParseUint(chi.URLParam(r, "service_account_id"), 0, 64)
  266. if err != nil {
  267. return 0, err
  268. }
  269. } else if saLoc == BodyParam {
  270. form := &bodyServiceAccountID{}
  271. body, err := ioutil.ReadAll(r.Body)
  272. if err != nil {
  273. return 0, err
  274. }
  275. err = json.Unmarshal(body, form)
  276. if err != nil {
  277. return 0, err
  278. }
  279. saID = form.ServiceAccountID
  280. // need to create a new stream for the body
  281. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  282. } else {
  283. vals, err := url.ParseQuery(r.URL.RawQuery)
  284. if err != nil {
  285. return 0, err
  286. }
  287. if saStrArr, ok := vals["service_account_id"]; ok && len(saStrArr) == 1 {
  288. saID, err = strconv.ParseUint(saStrArr[0], 10, 64)
  289. } else {
  290. return 0, errors.New("service account id not found")
  291. }
  292. }
  293. return saID, nil
  294. }