Kaynağa Gözat

jwt token checked in auth

Alexander Belanger 5 yıl önce
ebeveyn
işleme
dabff26a1b

+ 9 - 9
internal/auth/token/token.go

@@ -27,7 +27,7 @@ type Token struct {
 	IAt       *time.Time
 }
 
-func (t *Token) GetTokenForUser(userID, projID uint) (*Token, error) {
+func GetTokenForUser(userID, projID uint) (*Token, error) {
 	if userID == 0 || projID == 0 {
 		return nil, fmt.Errorf("id cannot be 0")
 	}
@@ -43,7 +43,7 @@ func (t *Token) GetTokenForUser(userID, projID uint) (*Token, error) {
 	}, nil
 }
 
-func (t *Token) GetTokenForAPI(userID, projID uint) (*Token, error) {
+func GetTokenForAPI(userID, projID uint) (*Token, error) {
 	if userID == 0 || projID == 0 {
 		return nil, fmt.Errorf("id cannot be 0")
 	}
@@ -64,12 +64,12 @@ func (t *Token) EncodeToken(conf *TokenGeneratorConf) (string, error) {
 		"sub_kind":   t.SubKind,
 		"sub":        t.Sub,
 		"iby":        t.IBy,
-		"iat":        t.IAt.Unix(),
+		"iat":        fmt.Sprintf("%d", t.IAt.Unix()),
 		"project_id": t.ProjectID,
 	})
 
 	// Sign and get the complete encoded token as a string using the secret
-	return token.SignedString(conf.TokenSecret)
+	return token.SignedString([]byte(conf.TokenSecret))
 }
 
 func GetTokenFromEncoded(tokenString string, conf *TokenGeneratorConf) (*Token, error) {
@@ -78,30 +78,30 @@ func GetTokenFromEncoded(tokenString string, conf *TokenGeneratorConf) (*Token,
 			return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
 		}
 
-		return conf.TokenSecret, nil
+		return []byte(conf.TokenSecret), nil
 	})
 
 	if err != nil {
-		return nil, fmt.Errorf("could not parse token")
+		return nil, fmt.Errorf("could not parse token: %v", err)
 	}
 
 	if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
 		iby, err := strconv.ParseUint(fmt.Sprintf("%v", claims["iby"]), 10, 64)
 
 		if err != nil {
-			return nil, fmt.Errorf("invalid iby claim")
+			return nil, fmt.Errorf("invalid iby claim: %v", err)
 		}
 
 		projID, err := strconv.ParseUint(fmt.Sprintf("%v", claims["project_id"]), 10, 64)
 
 		if err != nil {
-			return nil, fmt.Errorf("invalid iby claim")
+			return nil, fmt.Errorf("invalid project_id claim: %v", err)
 		}
 
 		iatUnix, err := strconv.ParseInt(fmt.Sprintf("%v", claims["iat"]), 10, 64)
 
 		if err != nil {
-			return nil, fmt.Errorf("invalid iby claim")
+			return nil, fmt.Errorf("invalid iat claim: %v", err)
 		}
 
 		iat := time.Unix(iatUnix, 0)

+ 52 - 0
internal/auth/token/token_test.go

@@ -0,0 +1,52 @@
+package token_test
+
+import (
+	"testing"
+	"time"
+
+	"github.com/go-test/deep"
+	"github.com/porter-dev/porter/internal/auth/token"
+)
+
+func TestGetAndEncodeTokenForUser(t *testing.T) {
+	conf := &token.TokenGeneratorConf{
+		TokenSecret: "fakesecret",
+	}
+
+	tok, err := token.GetTokenForUser(1, 1)
+
+	if err != nil {
+		t.Fatalf("%v\n", err)
+	}
+
+	tokString, err := tok.EncodeToken(conf)
+
+	if err != nil {
+		t.Fatalf("%v\n", err)
+	}
+
+	// decode the token again and compare
+	expToken := &token.Token{
+		SubKind:   token.User,
+		Sub:       "1",
+		ProjectID: 1,
+		IBy:       1,
+	}
+
+	gotToken, err := token.GetTokenFromEncoded(tokString, conf)
+
+	if err != nil {
+		t.Fatalf("%v\n", err)
+	}
+
+	if now := time.Now(); now.Sub(*gotToken.IAt) < 5 && now.Sub(*gotToken.IAt) >= 0 {
+		t.Fatalf("time not within threshold: issued at %d, current time %d\n", gotToken.IAt.Unix(), now.Unix())
+	}
+
+	gotToken.IAt = nil
+
+	if diff := deep.Equal(expToken, gotToken); diff != nil {
+		t.Errorf("tokens not equal:")
+		t.Error(diff)
+	}
+}

+ 39 - 1
server/router/middleware/auth.go

@@ -9,9 +9,11 @@ import (
 	"net/http"
 	"net/url"
 	"strconv"
+	"strings"
 
 	"github.com/go-chi/chi"
 	"github.com/gorilla/sessions"
+	"github.com/porter-dev/porter/internal/auth/token"
 	"github.com/porter-dev/porter/internal/models"
 	"github.com/porter-dev/porter/internal/repository"
 )
@@ -20,6 +22,7 @@ import (
 type Auth struct {
 	store      sessions.Store
 	cookieName string
+	tokenConf  *token.TokenGeneratorConf
 	repo       *repository.Repository
 }
 
@@ -27,9 +30,10 @@ type Auth struct {
 func NewAuth(
 	store sessions.Store,
 	cookieName string,
+	tokenConf *token.TokenGeneratorConf,
 	repo *repository.Repository,
 ) *Auth {
-	return &Auth{store, cookieName, repo}
+	return &Auth{store, cookieName, tokenConf, repo}
 }
 
 // BasicAuthenticate just checks that a user is logged in
@@ -98,6 +102,14 @@ type bodyDOIntegrationID struct {
 // the one stored in the session
 func (auth *Auth) DoesUserIDMatch(next http.Handler, loc IDLocation) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// first check for token
+		tok := auth.getTokenFromRequest(r)
+
+		if tok != nil && tok.SubKind == token.User && auth.doesSessionMatchID(r, tok.IBy) {
+			next.ServeHTTP(w, r)
+			return
+		}
+
 		var err error
 		id, err := findUserIDInRequest(r, loc)
 
@@ -137,6 +149,14 @@ func (auth *Auth) DoesUserHaveProjectAccess(
 			return
 		}
 
+		// first check for token
+		tok := auth.getTokenFromRequest(r)
+
+		if tok != nil && tok.ProjectID == uint(projID) {
+			next.ServeHTTP(w, r)
+			return
+		}
+
 		session, err := auth.store.Get(r, auth.cookieName)
 
 		if err != nil {
@@ -556,6 +576,8 @@ func (auth *Auth) doesSessionMatchID(r *http.Request, id uint) bool {
 }
 
 func (auth *Auth) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
+	// first check for Bearer token
+
 	session, err := auth.store.Get(r, auth.cookieName)
 	if err != nil {
 		session.Values["authenticated"] = false
@@ -571,6 +593,22 @@ func (auth *Auth) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
 	return true
 }
 
+func (auth *Auth) getTokenFromRequest(r *http.Request) *token.Token {
+	reqToken := r.Header.Get("Authorization")
+
+	splitToken := strings.Split(reqToken, "Bearer")
+
+	if len(splitToken) != 2 {
+		return nil
+	}
+
+	reqToken = strings.TrimSpace(splitToken[1])
+
+	tok, _ := token.GetTokenFromEncoded(reqToken, auth.tokenConf)
+
+	return tok
+}
+
 func findUserIDInRequest(r *http.Request, userLoc IDLocation) (uint64, error) {
 	var userID uint64
 	var err error