|
|
@@ -10,20 +10,23 @@ import (
|
|
|
|
|
|
"github.com/go-chi/chi"
|
|
|
"github.com/gorilla/sessions"
|
|
|
+ "github.com/porter-dev/porter/internal/repository"
|
|
|
)
|
|
|
|
|
|
// Auth implements the authorization functions
|
|
|
type Auth struct {
|
|
|
store sessions.Store
|
|
|
cookieName string
|
|
|
+ repo *repository.Repository
|
|
|
}
|
|
|
|
|
|
// NewAuth returns a new Auth instance
|
|
|
func NewAuth(
|
|
|
store sessions.Store,
|
|
|
cookieName string,
|
|
|
+ repo *repository.Repository,
|
|
|
) *Auth {
|
|
|
- return &Auth{store, cookieName}
|
|
|
+ return &Auth{store, cookieName, repo}
|
|
|
}
|
|
|
|
|
|
// BasicAuthenticate just checks that a user is logged in
|
|
|
@@ -50,37 +53,74 @@ const (
|
|
|
BodyParam
|
|
|
)
|
|
|
|
|
|
-type bodyID struct {
|
|
|
+type bodyUserID struct {
|
|
|
UserID uint64 `json:"user_id"`
|
|
|
}
|
|
|
|
|
|
+type bodyProjectID struct {
|
|
|
+ ProjectID uint64 `json:"project_id"`
|
|
|
+}
|
|
|
+
|
|
|
// DoesUserIDMatch checks the id URL parameter and verifies that it matches
|
|
|
// the one stored in the session
|
|
|
func (auth *Auth) DoesUserIDMatch(next http.Handler, loc IDLocation) http.Handler {
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
- var id uint64
|
|
|
var err error
|
|
|
+ id := findUserIDInRequest(r, loc)
|
|
|
|
|
|
- if loc == URLParam {
|
|
|
- id, err = strconv.ParseUint(chi.URLParam(r, "id"), 0, 64)
|
|
|
- } else if loc == BodyParam {
|
|
|
- form := &bodyID{}
|
|
|
- body, _ := ioutil.ReadAll(r.Body)
|
|
|
- err = json.Unmarshal(body, form)
|
|
|
+ if err == nil && auth.doesSessionMatchID(r, uint(id)) {
|
|
|
+ next.ServeHTTP(w, r)
|
|
|
+ } else {
|
|
|
+ http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- id = form.UserID
|
|
|
+ return
|
|
|
+ })
|
|
|
+}
|
|
|
|
|
|
- // need to create a new stream for the body
|
|
|
- r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
+// DoesUserHaveProjectReadAccess looks for a project_id parameter and checks that the
|
|
|
+// user has access to read that project
|
|
|
+func (auth *Auth) DoesUserHaveProjectReadAccess(
|
|
|
+ next http.Handler,
|
|
|
+ userLoc IDLocation,
|
|
|
+ projLoc IDLocation,
|
|
|
+) http.Handler {
|
|
|
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ var err error
|
|
|
+ projID := uint(findProjIDInRequest(r, projLoc))
|
|
|
+
|
|
|
+ session, err := auth.store.Get(r, auth.cookieName)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
- if err == nil && auth.doesSessionMatchID(r, uint(id)) {
|
|
|
- next.ServeHTTP(w, r)
|
|
|
- } else {
|
|
|
+ userID, ok := session.Values["user_id"].(uint)
|
|
|
+
|
|
|
+ if !ok {
|
|
|
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ // get the project
|
|
|
+ proj, err := auth.repo.Project.ReadProject(projID)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // look for the user role in the project
|
|
|
+ for _, role := range proj.Roles {
|
|
|
+ if role.UserID == userID {
|
|
|
+ next.ServeHTTP(w, r)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
|
return
|
|
|
})
|
|
|
}
|
|
|
@@ -111,3 +151,41 @@ func (auth *Auth) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
|
|
|
}
|
|
|
return true
|
|
|
}
|
|
|
+
|
|
|
+func findUserIDInRequest(r *http.Request, userLoc IDLocation) uint64 {
|
|
|
+ var userID uint64
|
|
|
+
|
|
|
+ if userLoc == URLParam {
|
|
|
+ userID, _ = strconv.ParseUint(chi.URLParam(r, "id"), 0, 64)
|
|
|
+ } else if userLoc == BodyParam {
|
|
|
+ form := &bodyUserID{}
|
|
|
+ body, _ := ioutil.ReadAll(r.Body)
|
|
|
+ _ = json.Unmarshal(body, form)
|
|
|
+
|
|
|
+ userID = form.UserID
|
|
|
+
|
|
|
+ // need to create a new stream for the body
|
|
|
+ r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
+ }
|
|
|
+
|
|
|
+ return userID
|
|
|
+}
|
|
|
+
|
|
|
+func findProjIDInRequest(r *http.Request, projLoc IDLocation) uint64 {
|
|
|
+ var projID uint64
|
|
|
+
|
|
|
+ if projLoc == URLParam {
|
|
|
+ projID, _ = strconv.ParseUint(chi.URLParam(r, "id"), 0, 64)
|
|
|
+ } else if projLoc == BodyParam {
|
|
|
+ form := &bodyProjectID{}
|
|
|
+ body, _ := ioutil.ReadAll(r.Body)
|
|
|
+ _ = json.Unmarshal(body, form)
|
|
|
+
|
|
|
+ projID = form.ProjectID
|
|
|
+
|
|
|
+ // need to create a new stream for the body
|
|
|
+ r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
|
|
+ }
|
|
|
+
|
|
|
+ return projID
|
|
|
+}
|