Explorar el Código

gcr support w/out tokencache

Alexander Belanger hace 5 años
padre
commit
3f0254a062

+ 45 - 0
cli/cmd/api/integration.go

@@ -56,3 +56,48 @@ func (c *Client) CreateAWSIntegration(
 
 	return bodyResp, nil
 }
+
+// CreateGCPIntegrationRequest represents the accepted fields for creating
+// a gcp integration
+type CreateGCPIntegrationRequest struct {
+	GCPKeyData string `json:"gcp_key_data"`
+}
+
+// CreateGCPIntegrationResponse is the resulting integration after creation
+type CreateGCPIntegrationResponse ints.GCPIntegrationExternal
+
+// CreateGCPIntegration creates a GCP integration with the given request options
+func (c *Client) CreateGCPIntegration(
+	ctx context.Context,
+	projectID uint,
+	createGCP *CreateGCPIntegrationRequest,
+) (*CreateGCPIntegrationResponse, error) {
+	data, err := json.Marshal(createGCP)
+
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest(
+		"POST",
+		fmt.Sprintf("%s/projects/%d/integrations/gcp", c.BaseURL, projectID),
+		strings.NewReader(string(data)),
+	)
+
+	if err != nil {
+		return nil, err
+	}
+
+	req = req.WithContext(ctx)
+	bodyResp := &CreateGCPIntegrationResponse{}
+
+	if httpErr, err := c.sendRequest(req, bodyResp, true); httpErr != nil || err != nil {
+		if httpErr != nil {
+			return nil, fmt.Errorf("code %d, errors %v", httpErr.Code, httpErr.Errors)
+		}
+
+		return nil, err
+	}
+
+	return bodyResp, nil
+}

+ 46 - 0
cli/cmd/api/registry.go

@@ -58,6 +58,52 @@ func (c *Client) CreateECR(
 	return bodyResp, nil
 }
 
+// CreateGCRRequest represents the accepted fields for creating
+// a GCR registry
+type CreateGCRRequest struct {
+	Name             string `json:"name"`
+	GCPIntegrationID uint   `json:"gcp_integration_id"`
+}
+
+// CreateGCRResponse is the resulting registry after creation
+type CreateGCRResponse models.RegistryExternal
+
+// CreateGCR creates an Google Container Registry integration
+func (c *Client) CreateGCR(
+	ctx context.Context,
+	projectID uint,
+	createGCR *CreateGCRRequest,
+) (*CreateGCRResponse, error) {
+	data, err := json.Marshal(createGCR)
+
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest(
+		"POST",
+		fmt.Sprintf("%s/projects/%d/registries", c.BaseURL, projectID),
+		strings.NewReader(string(data)),
+	)
+
+	if err != nil {
+		return nil, err
+	}
+
+	req = req.WithContext(ctx)
+	bodyResp := &CreateGCRResponse{}
+
+	if httpErr, err := c.sendRequest(req, bodyResp, true); httpErr != nil || err != nil {
+		if httpErr != nil {
+			return nil, fmt.Errorf("code %d, errors %v", httpErr.Code, httpErr.Errors)
+		}
+
+		return nil, err
+	}
+
+	return bodyResp, nil
+}
+
 // ListRegistryRepositoryResponse is the list of repositories in a registry
 type ListRegistryRepositoryResponse []registry.Repository
 

+ 33 - 1
cli/cmd/connect.go

@@ -43,6 +43,18 @@ var connectECRCmd = &cobra.Command{
 	},
 }
 
+var connectGCRCmd = &cobra.Command{
+	Use:   "gcr",
+	Short: "Connects a GCR instance to a project",
+	Run: func(cmd *cobra.Command, args []string) {
+		err := checkLoginAndRun(args, runConnectGCR)
+
+		if err != nil {
+			os.Exit(1)
+		}
+	},
+}
+
 func init() {
 	rootCmd.AddCommand(connectCmd)
 
@@ -77,6 +89,7 @@ func init() {
 	)
 
 	connectCmd.AddCommand(connectECRCmd)
+	connectCmd.AddCommand(connectGCRCmd)
 }
 
 func runConnectKubeconfig(_ *api.AuthCheckResponse, client *api.Client, _ []string) error {
@@ -96,8 +109,27 @@ func runConnectKubeconfig(_ *api.AuthCheckResponse, client *api.Client, _ []stri
 }
 
 func runConnectECR(_ *api.AuthCheckResponse, client *api.Client, _ []string) error {
-	return connect.ECR(
+	regID, err := connect.ECR(
 		client,
 		getProjectID(),
 	)
+
+	if err != nil {
+		return err
+	}
+
+	return setRegistry(regID)
+}
+
+func runConnectGCR(_ *api.AuthCheckResponse, client *api.Client, _ []string) error {
+	regID, err := connect.GCR(
+		client,
+		getProjectID(),
+	)
+
+	if err != nil {
+		return err
+	}
+
+	return setRegistry(regID)
 }

+ 9 - 9
cli/cmd/connect/ecr.go

@@ -13,31 +13,31 @@ import (
 func ECR(
 	client *api.Client,
 	projectID uint,
-) error {
+) (uint, error) {
 	// if project ID is 0, ask the user to set the project ID or create a project
 	if projectID == 0 {
-		return fmt.Errorf("no project set, please run porter project set [id]")
+		return 0, fmt.Errorf("no project set, please run porter project set [id]")
 	}
 
 	// query for the access key id
 	accessKeyID, err := utils.PromptPlaintext(fmt.Sprintf(`AWS Access Key ID: `))
 
 	if err != nil {
-		return err
+		return 0, err
 	}
 
 	// query for the secret access key
 	secretKey, err := utils.PromptPlaintext(fmt.Sprintf(`AWS Secret Access Key: `))
 
 	if err != nil {
-		return err
+		return 0, err
 	}
 
 	// query for the region
 	region, err := utils.PromptPlaintext(fmt.Sprintf(`AWS Region: `))
 
 	if err != nil {
-		return err
+		return 0, err
 	}
 
 	// create the aws integration
@@ -52,7 +52,7 @@ func ECR(
 	)
 
 	if err != nil {
-		return err
+		return 0, err
 	}
 
 	color.New(color.FgGreen).Printf("created aws integration with id %d\n", integration.ID)
@@ -62,7 +62,7 @@ func ECR(
 	regName, err := utils.PromptPlaintext(fmt.Sprintf(`Give this registry a name: `))
 
 	if err != nil {
-		return err
+		return 0, err
 	}
 
 	reg, err := client.CreateECR(
@@ -75,10 +75,10 @@ func ECR(
 	)
 
 	if err != nil {
-		return err
+		return 0, err
 	}
 
 	color.New(color.FgGreen).Printf("created registry with id %d and name %s\n", reg.ID, reg.Name)
 
-	return nil
+	return reg.ID, nil
 }

+ 82 - 0
cli/cmd/connect/gcr.go

@@ -0,0 +1,82 @@
+package connect
+
+import (
+	"context"
+	"fmt"
+	"io/ioutil"
+	"os"
+
+	"github.com/fatih/color"
+	"github.com/porter-dev/porter/cli/cmd/api"
+	"github.com/porter-dev/porter/cli/cmd/utils"
+)
+
+// GCR creates a GCR integration
+func GCR(
+	client *api.Client,
+	projectID uint,
+) (uint, error) {
+	// if project ID is 0, ask the user to set the project ID or create a project
+	if projectID == 0 {
+		return 0, fmt.Errorf("no project set, please run porter project set [id]")
+	}
+
+	keyFileLocation, err := utils.PromptPlaintext(fmt.Sprintf(`Please provide the full path to a service account key file.
+Key file location: `))
+
+	if err != nil {
+		return 0, err
+	}
+
+	// attempt to read the key file location
+	if info, err := os.Stat(keyFileLocation); !os.IsNotExist(err) && !info.IsDir() {
+		// read the file
+		bytes, err := ioutil.ReadFile(keyFileLocation)
+
+		if err != nil {
+			return 0, err
+		}
+
+		// create the aws integration
+		integration, err := client.CreateGCPIntegration(
+			context.Background(),
+			projectID,
+			&api.CreateGCPIntegrationRequest{
+				GCPKeyData: string(bytes),
+			},
+		)
+
+		if err != nil {
+			return 0, err
+		}
+
+		color.New(color.FgGreen).Printf("created gcp integration with id %d\n", integration.ID)
+
+		// create the registry
+		// query for registry name
+		regName, err := utils.PromptPlaintext(fmt.Sprintf(`Give this registry a name: `))
+
+		if err != nil {
+			return 0, err
+		}
+
+		reg, err := client.CreateGCR(
+			context.Background(),
+			projectID,
+			&api.CreateGCRRequest{
+				Name:             regName,
+				GCPIntegrationID: integration.ID,
+			},
+		)
+
+		if err != nil {
+			return 0, err
+		}
+
+		color.New(color.FgGreen).Printf("created registry with id %d and name %s\n", reg.ID, reg.Name)
+
+		return reg.ID, nil
+	}
+
+	return 0, fmt.Errorf("could not read service account key file")
+}

+ 3 - 2
internal/kubernetes/config.go

@@ -303,8 +303,9 @@ func (conf *OutOfClusterConfig) getTokenCache() (tok *ints.TokenCache, err error
 func (conf *OutOfClusterConfig) setTokenCache(token string, expiry time.Time) error {
 	_, err := conf.Repo.Cluster.UpdateClusterTokenCache(
 		&ints.TokenCache{
-			Token:  []byte(token),
-			Expiry: expiry,
+			ClusterID: conf.Cluster.ID,
+			Token:     []byte(token),
+			Expiry:    expiry,
 		},
 	)
 

+ 4 - 2
internal/models/integrations/gcp.go

@@ -82,8 +82,10 @@ func (g *GCPIntegration) GetBearerToken(
 	cache, err := getTokenCache()
 
 	// check the token cache for a non-expired token
-	if tok := cache.Token; err == nil && !cache.IsExpired() && len(tok) > 0 {
-		return string(tok), nil
+	if cache != nil {
+		if tok := cache.Token; err == nil && !cache.IsExpired() && len(tok) > 0 {
+			return string(tok), nil
+		}
 	}
 
 	creds, err := google.CredentialsFromJSON(

+ 4 - 2
internal/models/integrations/token_cache.go

@@ -20,8 +20,10 @@ type SetTokenCacheFunc func(token string, expiry time.Time) error
 type TokenCache struct {
 	gorm.Model
 
-	ClusterID uint      `json:"cluster_id"`
-	Expiry    time.Time `json:"expiry,omitempty"`
+	ClusterID  uint `json:"cluster_id"`
+	RegistryID uint `json:"registry_id"`
+
+	Expiry time.Time `json:"expiry,omitempty"`
 
 	// ------------------------------------------------------------------
 	// All fields below this line are encrypted before storage

+ 3 - 0
internal/models/registry.go

@@ -22,6 +22,9 @@ type Registry struct {
 
 	GCPIntegrationID uint
 	AWSIntegrationID uint
+
+	// A token cache that can be used by an auth mechanism, if desired
+	TokenCache integrations.TokenCache `json:"token_cache"`
 }
 
 // RegistryExternal is an external Registry to be shared over REST

+ 114 - 3
internal/registry/registry.go

@@ -1,12 +1,16 @@
 package registry
 
 import (
+	"encoding/json"
 	"fmt"
+	"net/http"
 	"time"
 
 	"github.com/aws/aws-sdk-go/service/ecr"
 	"github.com/porter-dev/porter/internal/models"
 	"github.com/porter-dev/porter/internal/repository"
+
+	ints "github.com/porter-dev/porter/internal/models/integrations"
 )
 
 // Registry wraps the gorm Registry model
@@ -40,7 +44,11 @@ type Image struct {
 func (r *Registry) ListRepositories(repo repository.Repository) ([]*Repository, error) {
 	// switch on the auth mechanism to get a token
 	if r.AWSIntegrationID != 0 {
-		return r.listECRRepositories(repo.AWSIntegration)
+		return r.listECRRepositories(repo)
+	}
+
+	if r.GCPIntegrationID != 0 {
+		return r.listGCPRepositories(repo)
 	}
 
 	return nil, fmt.Errorf("error listing repositories")
@@ -54,8 +62,91 @@ func (r *Registry) ListImages(
 	return nil, nil
 }
 
-func (r *Registry) listECRRepositories(repo repository.AWSIntegrationRepository) ([]*Repository, error) {
-	aws, err := repo.ReadAWSIntegration(
+type gcrJWT struct {
+	AccessToken  string `json:"token"`
+	ExpiresInSec int    `json:"expires_in"`
+}
+
+type gcrRepositoryResp struct {
+	Repositories []string `json:"repositories"`
+}
+
+// TODO -- use a token cache for the JWT token as well
+func (r *Registry) listGCPRepositories(
+	repo repository.Repository,
+) ([]*Repository, error) {
+	gcp, err := repo.GCPIntegration.ReadGCPIntegration(
+		r.GCPIntegrationID,
+	)
+
+	if err != nil {
+		return nil, err
+	}
+
+	// get oauth2 access token
+	oauthTok, err := gcp.GetBearerToken(r.getTokenCache, r.setTokenCacheFunc(repo))
+
+	if err != nil {
+		return nil, err
+	}
+
+	// get jwt token
+	client := &http.Client{}
+
+	req, err := http.NewRequest(
+		"GET",
+		"https://gcr.io/v2/token?service=gcr.io&scope=registry:catalog:*",
+		nil,
+	)
+
+	req.SetBasicAuth("_token", oauthTok)
+
+	resp, err := client.Do(req)
+
+	if err != nil {
+		return nil, err
+	}
+
+	jwtSource := gcrJWT{}
+
+	if err := json.NewDecoder(resp.Body).Decode(&jwtSource); err != nil {
+		return nil, fmt.Errorf("Invalid token JSON from metadata: %v", err)
+	}
+
+	// use JWT token to request catalog
+	req, err = http.NewRequest(
+		"GET",
+		"https://gcr.io/v2/_catalog",
+		nil,
+	)
+
+	req.Header.Add("Authorization", "Bearer "+jwtSource.AccessToken)
+
+	resp, err = client.Do(req)
+
+	if err != nil {
+		return nil, err
+	}
+
+	gcrResp := gcrRepositoryResp{}
+
+	if err := json.NewDecoder(resp.Body).Decode(&gcrResp); err != nil {
+		return nil, fmt.Errorf("Could not read GCR repositories: %v", err)
+	}
+
+	res := make([]*Repository, 0)
+
+	for _, repo := range gcrResp.Repositories {
+		res = append(res, &Repository{
+			Name: repo,
+		})
+	}
+
+	return res, nil
+}
+
+func (r *Registry) listECRRepositories(repo repository.Repository) ([]*Repository, error) {
+	aws, err := repo.AWSIntegration.ReadAWSIntegration(
 		r.AWSIntegrationID,
 	)
 
@@ -88,3 +179,23 @@ func (r *Registry) listECRRepositories(repo repository.AWSIntegrationRepository)
 
 	return res, nil
 }
+
+func (r *Registry) getTokenCache() (tok *ints.TokenCache, err error) {
+	return &r.TokenCache, nil
+}
+
+func (r *Registry) setTokenCacheFunc(
+	repo repository.Repository,
+) ints.SetTokenCacheFunc {
+	return func(token string, expiry time.Time) error {
+		_, err := repo.Registry.UpdateRegistryTokenCache(
+			&ints.TokenCache{
+				RegistryID: r.ID,
+				Token:      []byte(token),
+				Expiry:     expiry,
+			},
+		)
+
+		return err
+	}
+}

+ 104 - 5
internal/repository/gorm/registry.go

@@ -2,23 +2,31 @@ package gorm
 
 import (
 	"github.com/porter-dev/porter/internal/models"
+	ints "github.com/porter-dev/porter/internal/models/integrations"
 	"github.com/porter-dev/porter/internal/repository"
 	"gorm.io/gorm"
 )
 
 // RegistryRepository uses gorm.DB for querying the database
 type RegistryRepository struct {
-	db *gorm.DB
+	db  *gorm.DB
+	key *[32]byte
 }
 
 // NewRegistryRepository returns a RegistryRepository which uses
 // gorm.DB for querying the database
-func NewRegistryRepository(db *gorm.DB) repository.RegistryRepository {
-	return &RegistryRepository{db}
+func NewRegistryRepository(db *gorm.DB, key *[32]byte) repository.RegistryRepository {
+	return &RegistryRepository{db, key}
 }
 
 // CreateRegistry creates a new registry
 func (repo *RegistryRepository) CreateRegistry(reg *models.Registry) (*models.Registry, error) {
+	err := repo.EncryptRegistryData(reg, repo.key)
+
+	if err != nil {
+		return nil, err
+	}
+
 	project := &models.Project{}
 
 	if err := repo.db.Where("id = ?", reg.ProjectID).First(&project).Error; err != nil {
@@ -35,6 +43,23 @@ func (repo *RegistryRepository) CreateRegistry(reg *models.Registry) (*models.Re
 		return nil, err
 	}
 
+	// create a token cache by default
+	assoc = repo.db.Model(reg).Association("TokenCache")
+
+	if assoc.Error != nil {
+		return nil, assoc.Error
+	}
+
+	if err := assoc.Append(&reg.TokenCache); err != nil {
+		return nil, err
+	}
+
+	err = repo.DecryptRegistryData(reg, repo.key)
+
+	if err != nil {
+		return nil, err
+	}
+
 	return reg, nil
 }
 
@@ -42,10 +67,12 @@ func (repo *RegistryRepository) CreateRegistry(reg *models.Registry) (*models.Re
 func (repo *RegistryRepository) ReadRegistry(id uint) (*models.Registry, error) {
 	reg := &models.Registry{}
 
-	if err := repo.db.Where("id = ?", id).First(&reg).Error; err != nil {
+	if err := repo.db.Preload("TokenCache").Where("id = ?", id).First(&reg).Error; err != nil {
 		return nil, err
 	}
 
+	repo.DecryptRegistryData(reg, repo.key)
+
 	return reg, nil
 }
 
@@ -56,9 +83,81 @@ func (repo *RegistryRepository) ListRegistriesByProjectID(
 ) ([]*models.Registry, error) {
 	regs := []*models.Registry{}
 
-	if err := repo.db.Where("project_id = ?", projectID).Find(&regs).Error; err != nil {
+	if err := repo.db.Preload("TokenCache").Where("project_id = ?", projectID).Find(&regs).Error; err != nil {
 		return nil, err
 	}
 
+	for _, reg := range regs {
+		repo.DecryptRegistryData(reg, repo.key)
+	}
+
 	return regs, nil
 }
+
+// UpdateRegistryTokenCache updates the token cache for a registry
+func (repo *RegistryRepository) UpdateRegistryTokenCache(
+	tokenCache *ints.TokenCache,
+) (*models.Registry, error) {
+	if tok := tokenCache.Token; len(tok) > 0 {
+		cipherData, err := repository.Encrypt(tok, repo.key)
+
+		if err != nil {
+			return nil, err
+		}
+
+		tokenCache.Token = cipherData
+	}
+
+	registry := &models.Registry{}
+
+	if err := repo.db.Where("id = ?", tokenCache.RegistryID).First(&registry).Error; err != nil {
+		return nil, err
+	}
+
+	registry.TokenCache.Token = tokenCache.Token
+	registry.TokenCache.Expiry = tokenCache.Expiry
+
+	if err := repo.db.Save(registry).Error; err != nil {
+		return nil, err
+	}
+
+	return registry, nil
+}
+
+// EncryptRegistryData will encrypt the user's registry data before writing
+// to the DB
+func (repo *RegistryRepository) EncryptRegistryData(
+	registry *models.Registry,
+	key *[32]byte,
+) error {
+	if tok := registry.TokenCache.Token; len(tok) > 0 {
+		cipherData, err := repository.Encrypt(tok, key)
+
+		if err != nil {
+			return err
+		}
+
+		registry.TokenCache.Token = cipherData
+	}
+
+	return nil
+}
+
+// DecryptRegistryData will decrypt the user's registry data before returning it
+// from the DB
+func (repo *RegistryRepository) DecryptRegistryData(
+	registry *models.Registry,
+	key *[32]byte,
+) error {
+	if tok := registry.TokenCache.Token; len(tok) > 0 {
+		plaintext, err := repository.Decrypt(tok, key)
+
+		if err != nil {
+			return err
+		}
+
+		registry.TokenCache.Token = plaintext
+	}
+
+	return nil
+}

+ 77 - 0
internal/repository/gorm/registry_test.go

@@ -2,9 +2,11 @@ package gorm_test
 
 import (
 	"testing"
+	"time"
 
 	"github.com/go-test/deep"
 	"github.com/porter-dev/porter/internal/models"
+	ints "github.com/porter-dev/porter/internal/models/integrations"
 	"gorm.io/gorm"
 )
 
@@ -82,3 +84,78 @@ func TestListRegistriesByProjectID(t *testing.T) {
 		t.Error(diff)
 	}
 }
+
+func TestUpdateRegistryToken(t *testing.T) {
+	tester := &tester{
+		dbFileName: "./porter_test_update_registry_token.db",
+	}
+
+	setupTestEnv(tester, t)
+	initProject(tester, t)
+	defer cleanup(tester, t)
+
+	reg := &models.Registry{
+		Name:      "registry-test",
+		ProjectID: tester.initProjects[0].Model.ID,
+		TokenCache: ints.TokenCache{
+			Token:  []byte("token-1"),
+			Expiry: time.Now().Add(-1 * time.Hour),
+		},
+	}
+
+	reg, err := tester.repo.Registry.CreateRegistry(reg)
+
+	if err != nil {
+		t.Fatalf("%v\n", err)
+	}
+
+	reg, err = tester.repo.Registry.ReadRegistry(reg.Model.ID)
+
+	if err != nil {
+		t.Fatalf("%v\n", err)
+	}
+
+	// make sure registry id of token is 1
+	if reg.TokenCache.RegistryID != 1 {
+		t.Fatalf("incorrect registry id in token cache: expected %d, got %d\n", 1, reg.TokenCache.RegistryID)
+	}
+
+	// make sure old token is expired
+	if isExpired := reg.TokenCache.IsExpired(); !isExpired {
+		t.Fatalf("token was not expired\n")
+	}
+
+	if string(reg.TokenCache.Token) != "token-1" {
+		t.Errorf("incorrect token in cache: expected %s, got %s\n", "token-2", reg.TokenCache.Token)
+	}
+
+	reg.TokenCache.Token = []byte("token-2")
+	reg.TokenCache.Expiry = time.Now().Add(24 * time.Hour)
+
+	reg, err = tester.repo.Registry.UpdateRegistryTokenCache(&reg.TokenCache)
+	if err != nil {
+		t.Fatalf("%v\n", err)
+	}
+	reg, err = tester.repo.Registry.ReadRegistry(reg.Model.ID)
+	if err != nil {
+		t.Fatalf("%v\n", err)
+	}
+
+	// make sure id is 1
+	if reg.Model.ID != 1 {
+		t.Errorf("incorrect registry ID: expected %d, got %d\n", 1, reg.Model.ID)
+	}
+
+	// make sure new token is correct and not expired
+	if reg.TokenCache.RegistryID != 1 {
+		t.Fatalf("incorrect registry ID in token cache: expected %d, got %d\n", 1, reg.TokenCache.RegistryID)
+	}
+
+	if isExpired := reg.TokenCache.IsExpired(); isExpired {
+		t.Fatalf("token was expired\n")
+	}
+
+	if string(reg.TokenCache.Token) != "token-2" {
+		t.Errorf("incorrect token in cache: expected %s, got %s\n", "token-2", reg.TokenCache.Token)
+	}
+}

+ 1 - 1
internal/repository/gorm/repository.go

@@ -14,7 +14,7 @@ func NewRepository(db *gorm.DB, key *[32]byte) *repository.Repository {
 		Project:          NewProjectRepository(db),
 		GitRepo:          NewGitRepoRepository(db, key),
 		Cluster:          NewClusterRepository(db, key),
-		Registry:         NewRegistryRepository(db),
+		Registry:         NewRegistryRepository(db, key),
 		KubeIntegration:  NewKubeIntegrationRepository(db, key),
 		OIDCIntegration:  NewOIDCIntegrationRepository(db, key),
 		OAuthIntegration: NewOAuthIntegrationRepository(db, key),

+ 5 - 1
internal/repository/registry.go

@@ -1,10 +1,14 @@
 package repository
 
-import "github.com/porter-dev/porter/internal/models"
+import (
+	"github.com/porter-dev/porter/internal/models"
+	ints "github.com/porter-dev/porter/internal/models/integrations"
+)
 
 // RegistryRepository represents the set of queries on the Registry model
 type RegistryRepository interface {
 	CreateRegistry(reg *models.Registry) (*models.Registry, error)
 	ReadRegistry(id uint) (*models.Registry, error)
 	ListRegistriesByProjectID(projectID uint) ([]*models.Registry, error)
+	UpdateRegistryTokenCache(tokenCache *ints.TokenCache) (*models.Registry, error)
 }

+ 16 - 0
internal/repository/test/registry.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 
 	"github.com/porter-dev/porter/internal/models"
+	ints "github.com/porter-dev/porter/internal/models/integrations"
 	"github.com/porter-dev/porter/internal/repository"
 	"gorm.io/gorm"
 )
@@ -71,3 +72,18 @@ func (repo *RegistryRepository) ListRegistriesByProjectID(
 
 	return res, nil
 }
+
+// UpdateRegistryTokenCache updates the token cache for a registry
+func (repo *RegistryRepository) UpdateRegistryTokenCache(
+	tokenCache *ints.TokenCache,
+) (*models.Registry, error) {
+	if !repo.canQuery {
+		return nil, errors.New("Cannot write database")
+	}
+
+	index := int(tokenCache.RegistryID - 1)
+	repo.registries[index].TokenCache.Token = tokenCache.Token
+	repo.registries[index].TokenCache.Expiry = tokenCache.Expiry
+
+	return repo.registries[index], nil
+}