Prechádzať zdrojové kódy

docker cred helper poc

Alexander Belanger 5 rokov pred
rodič
commit
76324a2d33

+ 7 - 7
cli/cmd/api/registry.go

@@ -174,16 +174,16 @@ type GetECRTokenResponse struct {
 func (c *Client) GetECRAuthorizationToken(
 	ctx context.Context,
 	projectID uint,
-	registryID uint,
-) error {
+	region string,
+) (*GetECRTokenResponse, error) {
 	req, err := http.NewRequest(
 		"GET",
-		fmt.Sprintf("%s/projects/%d/registries/%d/ecr/token", c.BaseURL, projectID, registryID),
+		fmt.Sprintf("%s/projects/%d/registries/ecr/%s/token", c.BaseURL, projectID, region),
 		nil,
 	)
 
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	bodyResp := &GetECRTokenResponse{}
@@ -191,13 +191,13 @@ func (c *Client) GetECRAuthorizationToken(
 
 	if httpErr, err := c.sendRequest(req, bodyResp, true); httpErr != nil || err != nil {
 		if httpErr != nil {
-			return fmt.Errorf("code %d, errors %v", httpErr.Code, httpErr.Errors)
+			return nil, fmt.Errorf("code %d, errors %v", httpErr.Code, httpErr.Errors)
 		}
 
-		return err
+		return nil, err
 	}
 
-	return nil
+	return bodyResp, nil
 }
 
 // ListRegistryRepositoryResponse is the list of repositories in a registry

+ 2 - 0
cli/cmd/docker/porter.go

@@ -132,6 +132,8 @@ func StartPorter(opts *PorterStartOpts) (agent *Agent, id string, err error) {
 		}...)
 	}
 
+	opts.Env = append(opts.Env, "REDIS_ENABLED=false")
+
 	// create Porter container
 	startOpts := PorterServerStartOpts{
 		Name:          "porter_server_" + opts.ProcessID,

+ 15 - 12
cmd/app/main.go

@@ -30,14 +30,6 @@ func main() {
 		return
 	}
 
-	redis, err := adapter.NewRedisClient(&appConf.Redis)
-	prov.InitGlobalStream(redis)
-
-	if err != nil {
-		logger.Fatal().Err(err).Msg("")
-		return
-	}
-
 	err = db.AutoMigrate(
 		&models.Project{},
 		&models.Role{},
@@ -75,6 +67,21 @@ func main() {
 
 	repo := gorm.NewRepository(db, &key)
 
+	if appConf.Redis.Enabled {
+		redis, err := adapter.NewRedisClient(&appConf.Redis)
+
+		if err != nil {
+			logger.Fatal().Err(err).Msg("")
+			return
+		}
+
+		prov.InitGlobalStream(redis)
+
+		errorChan := make(chan error)
+
+		go prov.GlobalStreamListener(redis, *repo, errorChan)
+	}
+
 	a, _ := api.New(&api.AppConfig{
 		Logger:     logger,
 		Repository: repo,
@@ -96,10 +103,6 @@ func main() {
 		IdleTimeout:  appConf.Server.TimeoutIdle,
 	}
 
-	errorChan := make(chan error)
-
-	go prov.GlobalStreamListener(redis, *repo, errorChan)
-
 	if err := s.ListenAndServe(); err != nil && err != http.ErrServerClosed {
 		log.Fatal("Server startup failed", err)
 	}

+ 198 - 0
cmd/docker-credential-porter/helper/cache.go

@@ -0,0 +1,198 @@
+package helper
+
+import (
+	"crypto/md5"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"time"
+
+	"github.com/aws/aws-sdk-go/aws/credentials"
+	"github.com/sirupsen/logrus"
+	"k8s.io/client-go/util/homedir"
+)
+
+type CredentialsCache interface {
+	Get(registry string) *AuthEntry
+	Set(registry string, entry *AuthEntry)
+	List() []*AuthEntry
+	Clear()
+}
+
+type AuthEntry struct {
+	AuthorizationToken string
+	RequestedAt        time.Time
+	ExpiresAt          time.Time
+	ProxyEndpoint      string
+}
+
+// IsValid checks if AuthEntry is still valid at testTime. AuthEntries expire at 1/2 of their original
+// requested window.
+func (authEntry *AuthEntry) IsValid(testTime time.Time) bool {
+	validWindow := authEntry.ExpiresAt.Sub(authEntry.RequestedAt)
+	refreshTime := authEntry.ExpiresAt.Add(-1 * validWindow / time.Duration(2))
+	return testTime.Before(refreshTime)
+}
+
+func BuildCredentialsCache(region string) CredentialsCache {
+	home := homedir.HomeDir()
+	cacheDir := filepath.Join(home, ".porter")
+	cacheFilename := "cache.json"
+
+	return NewFileCredentialsCache(cacheDir, cacheFilename, region)
+}
+
+// Determine a key prefix for a credentials cache. Because auth tokens are scoped to an account and region, rely on provided
+// region, as well as hash of the access key.
+func credentialsCachePrefix(region string, credentials *credentials.Value) string {
+	return fmt.Sprintf("%s-%s-", region, checksum(credentials.AccessKeyID))
+}
+
+// Base64 encodes an MD5 checksum. Relied on for uniqueness, and not for cryptographic security.
+func checksum(text string) string {
+	hasher := md5.New()
+	data := hasher.Sum([]byte(text))
+	return base64.StdEncoding.EncodeToString(data)
+}
+
+const registryCacheVersion = "1.0"
+
+type RegistryCache struct {
+	Registries map[string]*AuthEntry
+	Version    string
+}
+
+type fileCredentialCache struct {
+	path           string
+	filename       string
+	cachePrefixKey string
+}
+
+func newRegistryCache() *RegistryCache {
+	return &RegistryCache{
+		Registries: make(map[string]*AuthEntry),
+		Version:    registryCacheVersion,
+	}
+}
+
+// NewFileCredentialsCache returns a new file credentials cache.
+//
+// path is used for temporary files during save, and filename should be a relative filename
+// in the same directory where the cache is serialized and deserialized.
+//
+// cachePrefixKey is used for scoping credentials for a given credential cache (i.e. region and
+// accessKey).
+func NewFileCredentialsCache(path string, filename string, cachePrefixKey string) CredentialsCache {
+	if _, err := os.Stat(path); err != nil {
+		os.MkdirAll(path, 0700)
+	}
+
+	return &fileCredentialCache{path: path, filename: filename, cachePrefixKey: cachePrefixKey}
+}
+
+func (f *fileCredentialCache) Get(registry string) *AuthEntry {
+	registryCache := f.init()
+
+	return registryCache.Registries[f.cachePrefixKey+registry]
+}
+
+func (f *fileCredentialCache) Set(registry string, entry *AuthEntry) {
+	registryCache := f.init()
+
+	registryCache.Registries[f.cachePrefixKey+registry] = entry
+
+	f.save(registryCache)
+}
+
+// List returns all of the available AuthEntries (regardless of prefix)
+func (f *fileCredentialCache) List() []*AuthEntry {
+	registryCache := f.init()
+
+	// optimize allocation for copy
+	entries := make([]*AuthEntry, 0, len(registryCache.Registries))
+
+	for _, entry := range registryCache.Registries {
+		entries = append(entries, entry)
+	}
+
+	return entries
+}
+
+func (f *fileCredentialCache) Clear() {
+	os.Remove(f.fullFilePath())
+}
+
+func (f *fileCredentialCache) fullFilePath() string {
+	return filepath.Join(f.path, f.filename)
+}
+
+// Saves credential cache to disk. This writes to a temporary file first, then moves the file to the config location.
+// This eliminates from reading partially written credential files, and reduces (but does not eliminate) concurrent
+// file access. There is not guarantee here for handling multiple writes at once since there is no out of process locking.
+func (f *fileCredentialCache) save(registryCache *RegistryCache) error {
+	file, err := ioutil.TempFile(f.path, ".config.json.tmp")
+	if err != nil {
+		return err
+	}
+
+	buff, err := json.MarshalIndent(registryCache, "", "  ")
+	if err != nil {
+		file.Close()
+		os.Remove(file.Name())
+		return err
+	}
+
+	_, err = file.Write(buff)
+
+	if err != nil {
+		file.Close()
+		os.Remove(file.Name())
+		return err
+	}
+
+	file.Close()
+	// note this is only atomic when relying on linux syscalls
+	os.Rename(file.Name(), f.fullFilePath())
+	return err
+}
+
+func (f *fileCredentialCache) init() *RegistryCache {
+	registryCache, err := f.load()
+	if err != nil {
+		logrus.WithError(err).Info("Could not load existing cache")
+		f.Clear()
+		registryCache = newRegistryCache()
+	}
+	return registryCache
+}
+
+// Loading a cache from disk will return errors for malformed or incompatible cache files.
+func (f *fileCredentialCache) load() (*RegistryCache, error) {
+	registryCache := newRegistryCache()
+
+	file, err := os.Open(f.fullFilePath())
+	if os.IsNotExist(err) {
+		return registryCache, nil
+	}
+
+	if err != nil {
+		return nil, err
+	}
+
+	defer file.Close()
+
+	if err = json.NewDecoder(file).Decode(&registryCache); err != nil {
+		return nil, err
+	}
+
+	if registryCache.Version != registryCacheVersion {
+		return nil, fmt.Errorf("ecr: Registry cache version %#v is not compatible with %#v, ignoring existing cache",
+			registryCache.Version,
+			registryCacheVersion)
+	}
+
+	return registryCache, nil
+}

+ 74 - 13
cmd/docker-credential-porter/helper/helper.go

@@ -2,10 +2,14 @@ package helper
 
 import (
 	"context"
-	"io/ioutil"
+	"encoding/base64"
+	"fmt"
 	"log"
 	"os"
 	"path/filepath"
+	"regexp"
+	"strings"
+	"time"
 
 	"github.com/docker/docker-credential-helpers/credentials"
 	"github.com/porter-dev/porter/cli/cmd"
@@ -30,6 +34,8 @@ func (p *PorterHelper) Delete(serverURL string) error {
 	return nil
 }
 
+var ecrPattern = regexp.MustCompile(`(^[a-zA-Z0-9][a-zA-Z0-9-_]*)\.dkr\.ecr(\-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.amazonaws\.com(\.cn)?`)
+
 // Get retrieves credentials from the store.
 // It returns username and secret as strings.
 func (p *PorterHelper) Get(serverURL string) (user string, secret string, err error) {
@@ -38,30 +44,85 @@ func (p *PorterHelper) Get(serverURL string) (user string, secret string, err er
 	file, _ := os.OpenFile(filepath.Join(home, ".porter", "logs.txt"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
 	log.SetOutput(file)
 
-	host := viper.GetString("host")
-	projID := viper.GetUint("project")
+	// parse the server url for region
+	matches := ecrPattern.FindStringSubmatch(serverURL)
+
+	if len(matches) == 0 {
+		return "", "", fmt.Errorf("docker-credential-porter can only be used with Amazon Elastic Container Registry.")
+	} else if len(matches) < 3 {
+		return "", "", fmt.Errorf(serverURL + "is not a valid repository URI for Amazon Elastic Container Registry.")
+	}
+
+	region := matches[3]
+
+	credCache := BuildCredentialsCache(region)
+	cachedEntry := credCache.Get(serverURL)
+
+	var token string
+
+	if cachedEntry != nil && cachedEntry.IsValid(time.Now()) {
+		token = cachedEntry.AuthorizationToken
+	} else {
+		host := viper.GetString("host")
+		projID := viper.GetUint("project")
+
+		client := api.NewClient(host+"/api", "cookie.json")
+
+		// get a token from the server
+		tokenResp, err := client.GetECRAuthorizationToken(context.Background(), projID, matches[3])
+
+		if err != nil {
+			return "", "", err
+		}
 
-	client := api.NewClient(host+"/api", "cookie.json")
+		token = tokenResp.Token
 
-	// list registries
-	reg, err := client.ListRegistries(context.Background(), projID)
+		// set the token in cache
+		credCache.Set(serverURL, &AuthEntry{
+			AuthorizationToken: token,
+			RequestedAt:        time.Now(),
+			ExpiresAt:          time.Now().Add(12 * time.Hour),
+			ProxyEndpoint:      serverURL,
+		})
+	}
 
-	log.Println("called regs", reg, err)
+	decodedToken, err := base64.StdEncoding.DecodeString(token)
 
 	if err != nil {
-		return "", "", err
+		return "", "", fmt.Errorf("Invalid token: %v", err)
 	}
 
-	log.Println(reg)
+	parts := strings.SplitN(string(decodedToken), ":", 2)
+
+	if len(parts) < 2 {
+		return "", "", fmt.Errorf("Invalid token: expected two parts, got %d", len(parts))
+	}
 
-	return "", "", nil
+	return parts[0], parts[1], nil
 }
 
 // List returns the stored serverURLs and their associated usernames.
 func (p *PorterHelper) List() (map[string]string, error) {
-	var home = homedir.HomeDir()
+	credCache := BuildCredentialsCache("")
+	entries := credCache.List()
+
+	res := make(map[string]string)
 
-	ioutil.WriteFile(filepath.Join(home, ".porter", "log.txt"), []byte("called list\n"), 0644)
+	for _, entry := range entries {
+		decodedToken, err := base64.StdEncoding.DecodeString(entry.AuthorizationToken)
+
+		if err != nil {
+			continue
+		}
+
+		parts := strings.SplitN(string(decodedToken), ":", 2)
+
+		if len(parts) < 2 {
+			continue
+		}
+
+		res[entry.ProxyEndpoint] = parts[0]
+	}
 
-	return nil, nil
+	return res, nil
 }

+ 2 - 2
go.mod

@@ -9,7 +9,7 @@ require (
 	github.com/DATA-DOG/go-sqlmock v1.5.0
 	github.com/Masterminds/semver v1.5.0 // indirect
 	github.com/aws/aws-sdk-go v1.35.4
-	github.com/awslabs/amazon-ecr-credential-helper/ecr-login v0.0.0-20201113001948-d77edb6d2e47 // indirect
+	github.com/awslabs/amazon-ecr-credential-helper/ecr-login v0.0.0-20201113001948-d77edb6d2e47
 	github.com/containerd/containerd v1.4.1 // indirect
 	github.com/coreos/rkt v1.30.0
 	github.com/docker/docker v1.4.2-0.20200203170920-46ec8731fbce
@@ -48,7 +48,7 @@ require (
 	github.com/pelletier/go-toml v1.8.1 // indirect
 	github.com/pkg/errors v0.9.1
 	github.com/rs/zerolog v1.20.0
-	github.com/sirupsen/logrus v1.7.0 // indirect
+	github.com/sirupsen/logrus v1.7.0
 	github.com/spf13/cobra v1.0.0
 	github.com/spf13/viper v1.4.0
 	github.com/stretchr/testify v1.6.1

+ 3 - 0
internal/config/redis.go

@@ -2,6 +2,9 @@ package config
 
 // RedisConf is the redis config required for the provisioner container
 type RedisConf struct {
+	// if redis should be used
+	Enabled bool `env:"REDIS_ENABLED,default=true"`
+
 	Host     string `env:"REDIS_HOST,default=redis"`
 	Port     string `env:"REDIS_PORT,default=6379"`
 	Username string `env:"REDIS_USER"`

+ 29 - 25
server/api/registry_handler.go

@@ -111,47 +111,51 @@ func (app *App) HandleGetProjectRegistryECRToken(w http.ResponseWriter, r *http.
 		return
 	}
 
-	registryID, err := strconv.ParseUint(chi.URLParam(r, "registry_id"), 0, 64)
+	region := chi.URLParam(r, "region")
 
-	if err != nil || registryID == 0 {
+	if region == "" {
 		app.handleErrorFormDecoding(err, ErrProjectDecode, w)
 		return
 	}
 
-	// handle write to the database
-	reg, err := app.Repo.Registry.ReadRegistry(uint(registryID))
+	// list registries and find one that matches the region
+	regs, err := app.Repo.Registry.ListRegistriesByProjectID(uint(projID))
+	var token string
 
-	if err != nil {
-		app.handleErrorDataRead(err, w)
-		return
-	}
+	for _, reg := range regs {
+		if reg.AWSIntegrationID != 0 {
+			awsInt, err := app.Repo.AWSIntegration.ReadAWSIntegration(reg.AWSIntegrationID)
 
-	// get the aws integration and session
-	awsInt, err := app.Repo.AWSIntegration.ReadAWSIntegration(reg.AWSIntegrationID)
+			if err != nil {
+				app.handleErrorDataRead(err, w)
+				return
+			}
 
-	if err != nil {
-		app.handleErrorDataRead(err, w)
-		return
-	}
+			if awsInt.AWSRegion == region {
+				// get the aws integration and session
+				sess, err := awsInt.GetSession()
 
-	sess, err := awsInt.GetSession()
+				if err != nil {
+					app.handleErrorDataRead(err, w)
+					return
+				}
 
-	if err != nil {
-		app.handleErrorDataRead(err, w)
-		return
-	}
+				ecrSvc := ecr.New(sess)
 
-	ecrSvc := ecr.New(sess)
+				output, err := ecrSvc.GetAuthorizationToken(&ecr.GetAuthorizationTokenInput{})
 
-	output, err := ecrSvc.GetAuthorizationToken(&ecr.GetAuthorizationTokenInput{})
+				if err != nil {
+					app.handleErrorDataRead(err, w)
+					return
+				}
 
-	if err != nil {
-		app.handleErrorDataRead(err, w)
-		return
+				token = *output.AuthorizationData[0].AuthorizationToken
+			}
+		}
 	}
 
 	resp := &ECRTokenResponse{
-		Token: *output.AuthorizationData[0].AuthorizationToken,
+		Token: token,
 	}
 
 	w.WriteHeader(http.StatusOK)

+ 2 - 6
server/router/router.go

@@ -422,13 +422,9 @@ func New(a *api.App) *chi.Mux {
 
 		r.Method(
 			"GET",
-			"/projects/{project_id}/registries/{registry_id}/ecr/token",
+			"/projects/{project_id}/registries/ecr/{region}/token",
 			auth.DoesUserHaveProjectAccess(
-				auth.DoesUserHaveRegistryAccess(
-					requestlog.NewHandler(a.HandleGetProjectRegistryECRToken, l),
-					mw.URLParam,
-					mw.URLParam,
-				),
+				requestlog.NewHandler(a.HandleGetProjectRegistryECRToken, l),
 				mw.URLParam,
 				mw.WriteAccess,
 			),