auth.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package middleware
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "io/ioutil"
  6. "net/http"
  7. "strconv"
  8. "github.com/go-chi/chi"
  9. "github.com/gorilla/sessions"
  10. )
  11. // Auth implements the authorization functions
  12. type Auth struct {
  13. store sessions.Store
  14. cookieName string
  15. }
  16. // NewAuth returns a new Auth instance
  17. func NewAuth(
  18. store sessions.Store,
  19. cookieName string,
  20. ) *Auth {
  21. return &Auth{store, cookieName}
  22. }
  23. // BasicAuthenticate just checks that a user is logged in
  24. func (auth *Auth) BasicAuthenticate(next http.Handler) http.Handler {
  25. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  26. if auth.isLoggedIn(r) {
  27. next.ServeHTTP(w, r)
  28. } else {
  29. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  30. return
  31. }
  32. return
  33. })
  34. }
  // IDLocation selects where DoesUserIDMatch looks for the user ID that is
// compared against the session.
type IDLocation uint

// Supported IDLocation values.
const (
	// URLParam reads the ID from the {id} URL parameter.
	URLParam IDLocation = 0
	// BodyParam reads the ID from the user_id field of the JSON body.
	BodyParam IDLocation = 1
)
  // bodyID is the minimal JSON request-body shape decoded when the user ID
// is taken from the body; only the user_id key is read.
type bodyID struct {
	// UserID holds the value of the "user_id" JSON key.
	UserID uint64 `json:"user_id"`
}
  46. // DoesUserIDMatch checks the id URL parameter and verifies that it matches
  47. // the one stored in the session
  48. func (auth *Auth) DoesUserIDMatch(next http.Handler, loc IDLocation) http.Handler {
  49. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  50. var id uint64
  51. var err error
  52. if loc == URLParam {
  53. id, err = strconv.ParseUint(chi.URLParam(r, "id"), 0, 64)
  54. } else if loc == BodyParam {
  55. form := &bodyID{}
  56. body, _ := ioutil.ReadAll(r.Body)
  57. err = json.Unmarshal(body, form)
  58. id = form.UserID
  59. // need to create a new stream for the body
  60. r.Body = ioutil.NopCloser(bytes.NewReader(body))
  61. }
  62. if err == nil && auth.doesSessionMatchID(r, uint(id)) {
  63. next.ServeHTTP(w, r)
  64. } else {
  65. http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
  66. return
  67. }
  68. return
  69. })
  70. }
  71. // Helpers
  72. func (auth *Auth) doesSessionMatchID(r *http.Request, id uint) bool {
  73. session, _ := auth.store.Get(r, auth.cookieName)
  74. if sessID, ok := session.Values["user_id"].(uint); !ok || sessID != id {
  75. return false
  76. }
  77. return true
  78. }
  79. func (auth *Auth) isLoggedIn(r *http.Request) bool {
  80. session, _ := auth.store.Get(r, auth.cookieName)
  81. if auth, ok := session.Values["authenticated"].(bool); !auth || !ok {
  82. return false
  83. }
  84. return true
  85. }