Переглянути джерело

test repository for sessions

Alexander Belanger 5 роки тому
батько
коміт
3d59b98c16

+ 3 - 3
cmd/app/main.go

@@ -10,7 +10,7 @@ import (
 	"github.com/porter-dev/porter/server/api"
 
 	adapter "github.com/porter-dev/porter/internal/adapter"
-	sessionstore "github.com/porter-dev/porter/internal/auth/"
+	sessionstore "github.com/porter-dev/porter/internal/auth"
 	"github.com/porter-dev/porter/internal/config"
 	lr "github.com/porter-dev/porter/internal/logger"
 	vr "github.com/porter-dev/porter/internal/validator"
@@ -28,8 +28,8 @@ func main() {
 		return
 	}
 
-	key = []byte("secret") // TODO: change to os.Getenv("SESSION_KEY")
-	store, _ = sessionstore.NewStore(db, key)
+	key := []byte("secret") // TODO: change to os.Getenv("SESSION_KEY")
+	store, _ := sessionstore.NewStore(db, key)
 
 	validator := vr.New()
 	repo := gorm.NewRepository(db)

+ 30 - 33
internal/auth/sessionstore.go

@@ -9,14 +9,12 @@ import (
 	"strings"
 	"time"
 
-	"gorm.io/gorm"
-
 	"github.com/gorilla/securecookie"
 	"github.com/gorilla/sessions"
 	"github.com/pkg/errors"
 
 	"github.com/porter-dev/porter/internal/models"
-	rp "github.com/porter-dev/porter/internal/repository/gorm"
+	"github.com/porter-dev/porter/internal/repository"
 )
 
 // structs
@@ -26,7 +24,7 @@ type PGStore struct {
 	Codecs  []securecookie.Codec
 	Options *sessions.Options
 	Path    string
-	DbPool  *gorm.DB
+	Repo    *repository.Repository
 }
 
 // Helpers
@@ -35,8 +33,8 @@ type PGStore struct {
 // If l is 0 there is no limit to the size of a session, use with caution.
 // The default for a new PGStore is 4096. PostgreSQL allows for max
 // value sizes of up to 1GB (http://www.postgresql.org/docs/current/interactive/datatype-character.html)
-func (db *PGStore) MaxLength(l int) {
-	for _, c := range db.Codecs {
+func (store *PGStore) MaxLength(l int) {
+	for _, c := range store.Codecs {
 		if codec, ok := c.(*securecookie.SecureCookie); ok {
 			codec.MaxLength(l)
 		}
@@ -46,11 +44,11 @@ func (db *PGStore) MaxLength(l int) {
 // MaxAge sets the maximum age for the store and the underlying cookie
 // implementation. Individual sessions can be deleted by setting Options.MaxAge
 // = -1 for that session.
-func (db *PGStore) MaxAge(age int) {
-	db.Options.MaxAge = age
+func (store *PGStore) MaxAge(age int) {
+	store.Options.MaxAge = age
 
 	// Set the maxAge for each securecookie instance.
-	for _, codec := range db.Codecs {
+	for _, codec := range store.Codecs {
 		if sc, ok := codec.(*securecookie.SecureCookie); ok {
 			sc.MaxAge(age)
 		}
@@ -59,21 +57,20 @@ func (db *PGStore) MaxAge(age int) {
 
 // load fetches a session by ID from the database and decodes its content
 // into session.Values.
-func (db *PGStore) load(session *sessions.Session) error {
-	repo := rp.NewRepository(db.DbPool)
-	res, err := repo.Session.SelectSession(&models.Session{Key: session.ID})
+func (store *PGStore) load(session *sessions.Session) error {
+	res, err := store.Repo.Session.SelectSession(&models.Session{Key: session.ID})
 
 	if err != nil {
 		return err
 	}
 
-	return securecookie.DecodeMulti(session.Name(), string(res.Data), &session.Values, db.Codecs...)
+	return securecookie.DecodeMulti(session.Name(), string(res.Data), &session.Values, store.Codecs...)
 }
 
 // save writes encoded session.Values to a database record.
 // writes to http_sessions table by default.
-func (db *PGStore) save(session *sessions.Session) error {
-	encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, db.Codecs...)
+func (store *PGStore) save(session *sessions.Session) error {
+	encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, store.Codecs...)
 	if err != nil {
 		return err
 	}
@@ -91,34 +88,34 @@ func (db *PGStore) save(session *sessions.Session) error {
 		}
 	}
 
-	s := models.Session{
+	s := &models.Session{
 		Key:       session.ID,
 		Data:      []byte(encoded),
 		ExpiresAt: expiresOn,
 	}
 
-	repo := rp.NewRepository(db.DbPool)
+	repo := store.Repo
 
 	if session.IsNew {
-		_, createErr := repo.Session.CreateSession(&s)
+		_, createErr := repo.Session.CreateSession(s)
 		return createErr
 	}
 
-	_, updateErr := repo.Session.UpdateSession(&s)
+	_, updateErr := repo.Session.UpdateSession(s)
 	return updateErr
 }
 
 // Implementation of the interface (Get, New, Save)
 
 // NewStore takes an initialized db and session key pairs to create a session-store in postgres db.
-func NewStore(db *gorm.DB, keyPairs ...[]byte) (*PGStore, error) {
+func NewStore(repo *repository.Repository, keyPairs ...[]byte) (*PGStore, error) {
 	dbStore := &PGStore{
 		Codecs: securecookie.CodecsFromPairs(keyPairs...),
 		Options: &sessions.Options{
 			Path:   "/",
 			MaxAge: 86400 * 30,
 		},
-		DbPool: db,
+		Repo: repo,
 	}
 
 	return dbStore, nil
@@ -126,26 +123,26 @@ func NewStore(db *gorm.DB, keyPairs ...[]byte) (*PGStore, error) {
 
 // Get Fetches a session for a given name after it has been added to the
 // registry.
-func (db *PGStore) Get(r *http.Request, name string) (*sessions.Session, error) {
-	return sessions.GetRegistry(r).Get(db, name)
+func (store *PGStore) Get(r *http.Request, name string) (*sessions.Session, error) {
+	return sessions.GetRegistry(r).Get(store, name)
 }
 
 // New returns a new session for the given name without adding it to the registry.
-func (db *PGStore) New(r *http.Request, name string) (*sessions.Session, error) {
-	session := sessions.NewSession(db, name)
+func (store *PGStore) New(r *http.Request, name string) (*sessions.Session, error) {
+	session := sessions.NewSession(store, name)
 	if session == nil {
 		return nil, nil
 	}
 
-	opts := *db.Options
+	opts := *store.Options
 	session.Options = &(opts)
 	session.IsNew = true
 
 	var err error
 	if c, errCookie := r.Cookie(name); errCookie == nil {
-		err = securecookie.DecodeMulti(name, c.Value, &session.ID, db.Codecs...)
+		err = securecookie.DecodeMulti(name, c.Value, &session.ID, store.Codecs...)
 		if err == nil {
-			err = db.load(session)
+			err = store.load(session)
 			if err == nil {
 				session.IsNew = false
 			} else if errors.Cause(err) == sql.ErrNoRows {
@@ -154,14 +151,14 @@ func (db *PGStore) New(r *http.Request, name string) (*sessions.Session, error)
 		}
 	}
 
-	db.MaxAge(db.Options.MaxAge)
+	store.MaxAge(store.Options.MaxAge)
 
 	return session, err
 }
 
 // Save saves the given session into the database and deletes cookies if needed
-func (db *PGStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
-	repo := rp.NewRepository(db.DbPool)
+func (store *PGStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
+	repo := store.Repo
 
 	// Set delete if max-age is < 0
 	if session.Options.MaxAge < 0 {
@@ -180,12 +177,12 @@ func (db *PGStore) Save(r *http.Request, w http.ResponseWriter, session *session
 			), "=")
 	}
 
-	if err := db.save(session); err != nil {
+	if err := store.save(session); err != nil {
 		return err
 	}
 
 	// Keep the session ID key in a cookie so it can be looked up in DB later.
-	encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, db.Codecs...)
+	encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, store.Codecs...)
 	if err != nil {
 		return err
 	}

+ 5 - 6
internal/auth/sessionstore_test.go

@@ -7,8 +7,7 @@ import (
 
 	"github.com/gorilla/securecookie"
 	"github.com/gorilla/sessions"
-
-	dbConn "github.com/porter-dev/porter/internal/adapter"
+	"github.com/porter-dev/porter/internal/repository/test"
 )
 
 type headerOnlyResponseWriter http.Header
@@ -28,9 +27,9 @@ func (ho headerOnlyResponseWriter) WriteHeader(int) {
 var secret = "secret"
 
 func TestPGStore(t *testing.T) {
-	db, _ := dbConn.New()
+	repo := test.NewRepository(true)
 
-	ss, err := NewStore(db, []byte(secret))
+	ss, err := NewStore(repo, []byte(secret))
 	if err != nil {
 		t.Fatal("Failed to get store", err)
 	}
@@ -125,9 +124,9 @@ func TestPGStore(t *testing.T) {
 }
 
 func TestSessionOptionsAreUniquePerSession(t *testing.T) {
-	db, _ := dbConn.New()
+	repo := test.NewRepository(true)
 
-	ss, err := NewStore(db, []byte(secret))
+	ss, err := NewStore(repo, []byte(secret))
 	if err != nil {
 		t.Fatal("Failed to get store", err)
 	}

+ 1 - 1
internal/models/session.go

@@ -10,7 +10,7 @@ import (
 type Session struct {
 	gorm.Model
 	// Session ID
-	Key string
+	Key string `gorm:"unique"`
 	// encrypted cookie
 	Data []byte
 	// Time the session will expire

+ 7 - 9
internal/repository/gorm/session.go

@@ -6,20 +6,18 @@ import (
 	"gorm.io/gorm"
 )
 
-type sessionrepo struct {
+// SessionRepository uses gorm.DB for querying the database
+type SessionRepository struct {
 	db *gorm.DB
 }
 
 // NewSessionRepository returns pointer to repo along with the db
 func NewSessionRepository(db *gorm.DB) repository.SessionRepository {
-	return &sessionrepo{
-		db: db,
-	}
+	return &SessionRepository{db}
 }
 
 // CreateSession must take in Key, Data, and ExpiresAt as arguments.
-func (s *sessionrepo) CreateSession(session *models.Session) (*models.Session, error) {
-	// TODO: check for duplicate and return error
+func (s *SessionRepository) CreateSession(session *models.Session) (*models.Session, error) {
 	if err := s.db.Create(session).Error; err != nil {
 		return nil, err
 	}
@@ -27,7 +25,7 @@ func (s *sessionrepo) CreateSession(session *models.Session) (*models.Session, e
 }
 
 // UpdateSession updates only the Data field using Key as selector.
-func (s *sessionrepo) UpdateSession(session *models.Session) (*models.Session, error) {
+func (s *SessionRepository) UpdateSession(session *models.Session) (*models.Session, error) {
 	if err := s.db.Model(session).Where("Key = ?", session.Key).Updates(session).Error; err != nil {
 		return nil, err
 	}
@@ -35,7 +33,7 @@ func (s *sessionrepo) UpdateSession(session *models.Session) (*models.Session, e
 }
 
 // DeleteSession deletes a session by Key
-func (s *sessionrepo) DeleteSession(session *models.Session) (*models.Session, error) {
+func (s *SessionRepository) DeleteSession(session *models.Session) (*models.Session, error) {
 
 	if err := s.db.Where("Key = ?", session.Key).Delete(session).Error; err != nil {
 		return nil, err
@@ -45,7 +43,7 @@ func (s *sessionrepo) DeleteSession(session *models.Session) (*models.Session, e
 }
 
 // SelectSession returns a session with matching key
-func (s *sessionrepo) SelectSession(session *models.Session) (*models.Session, error) {
+func (s *SessionRepository) SelectSession(session *models.Session) (*models.Session, error) {
 
 	if err := s.db.Where("Key = ?", session.Key).First(session).Error; err != nil {
 		return nil, err

+ 2 - 1
internal/repository/test/repository.go

@@ -8,6 +8,7 @@ import (
 // and accepts a parameter that can trigger read/write errors
 func NewRepository(canQuery bool) *repository.Repository {
 	return &repository.Repository{
-		User: NewUserRepository(canQuery),
+		User:    NewUserRepository(canQuery),
+		Session: NewSessionRepository(canQuery),
 	}
 }

+ 95 - 0
internal/repository/test/session.go

@@ -0,0 +1,95 @@
+package test
+
+import (
+	"errors"
+
+	"github.com/porter-dev/porter/internal/models"
+	"github.com/porter-dev/porter/internal/repository"
+	"gorm.io/gorm"
+)
+
+// SessionRepository uses gorm.DB for querying the database
+type SessionRepository struct {
+	canQuery bool
+	sessions []*models.Session
+}
+
+// NewSessionRepository returns pointer to repo along with the db
+func NewSessionRepository(canQuery bool) repository.SessionRepository {
+	return &SessionRepository{canQuery, []*models.Session{}}
+}
+
+// CreateSession must take in Key, Data, and ExpiresAt as arguments.
+func (repo *SessionRepository) CreateSession(session *models.Session) (*models.Session, error) {
+	if !repo.canQuery {
+		return nil, errors.New("Cannot write database")
+	}
+
+	// make sure key doesn't exist
+	for _, s := range repo.sessions {
+		if s.Key == session.Key {
+			return nil, errors.New("Cannot write database")
+		}
+	}
+
+	sessions := repo.sessions
+	sessions = append(sessions, session)
+	repo.sessions = sessions
+	session.ID = uint(len(repo.sessions))
+
+	return session, nil
+}
+
+// UpdateSession updates only the Data field using Key as selector.
+func (repo *SessionRepository) UpdateSession(session *models.Session) (*models.Session, error) {
+	if !repo.canQuery {
+		return nil, errors.New("Cannot write database")
+	}
+
+	var oldSession *models.Session
+
+	for _, s := range repo.sessions {
+		if s.Key == session.Key {
+			oldSession = s
+		}
+	}
+
+	if oldSession != nil {
+		oldSession.Data = session.Data
+
+		return oldSession, nil
+	}
+
+	return nil, gorm.ErrRecordNotFound
+}
+
+// DeleteSession deletes a session by Key
+func (repo *SessionRepository) DeleteSession(session *models.Session) (*models.Session, error) {
+	if !repo.canQuery {
+		return nil, errors.New("Cannot write database")
+	}
+
+	if int(session.ID-1) >= len(repo.sessions) || repo.sessions[session.ID-1] == nil {
+		return nil, gorm.ErrRecordNotFound
+	}
+
+	index := int(session.ID - 1)
+	repo.sessions[index] = nil
+
+	return session, nil
+}
+
+// SelectSession returns a session with matching key
+func (repo *SessionRepository) SelectSession(session *models.Session) (*models.Session, error) {
+	if !repo.canQuery {
+		return nil, errors.New("Cannot write database")
+	}
+
+	for _, s := range repo.sessions {
+		if s.Key == session.Key {
+			return s, nil
+		}
+	}
+
+	return nil, gorm.ErrRecordNotFound
+}

+ 21 - 0
internal/repository/test/user.go

@@ -5,6 +5,7 @@ import (
 
 	"github.com/porter-dev/porter/internal/models"
 	"github.com/porter-dev/porter/internal/repository"
+	"golang.org/x/crypto/bcrypt"
 	"gorm.io/gorm"
 )
 
@@ -104,3 +105,23 @@ func (repo *UserRepository) DeleteUser(user *models.User) (*models.User, error)
 
 	return user, nil
 }
+
+// CheckPassword checks the input password is correct for the provided user id.
+func (repo *UserRepository) CheckPassword(id int, pwd string) (bool, error) {
+	if !repo.canQuery {
+		return false, errors.New("Cannot write database")
+	}
+
+	if int(id-1) >= len(repo.users) || repo.users[id-1] == nil {
+		return false, gorm.ErrRecordNotFound
+	}
+
+	index := int(id - 1)
+	user := *repo.users[index]
+
+	if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(pwd)); err != nil {
+		return false, err
+	}
+
+	return true, nil
+}

+ 5 - 1
server/api/user_handler_test.go

@@ -17,6 +17,7 @@ import (
 	"github.com/porter-dev/porter/server/api"
 	"github.com/porter-dev/porter/server/router"
 
+	sessionstore "github.com/porter-dev/porter/internal/auth"
 	lr "github.com/porter-dev/porter/internal/logger"
 	vr "github.com/porter-dev/porter/internal/validator"
 )
@@ -39,7 +40,10 @@ func initApi(canQuery bool) (*api.App, *repository.Repository) {
 
 	repo := test.NewRepository(canQuery)
 
-	return api.New(logger, repo, validator), repo
+	key := []byte("secret") // TODO: change to os.Getenv("SESSION_KEY")
+	store, _ := sessionstore.NewStore(db, key)
+
+	return api.New(logger, repo, validator, store), repo
 }
 
 func testUserRequest(t *testing.T, c userTest) {