Ver código fonte

simplified fixes for azure (#3176)

* simplified fixes for azure

* fixes

* uncomment feature flag

* refactor method left

* small edits

---------

Co-authored-by: David Townley <davidtownley@Davids-MacBook-Air.local>
d-g-town 2 anos atrás
pai
commit
be83d7603b

+ 2 - 1
api/client/registry.go

@@ -201,6 +201,7 @@ func (c *Client) GetGARAuthorizationToken(
 func (c *Client) GetACRAuthorizationToken(
 	ctx context.Context,
 	projectID uint,
+	req *types.GetRegistryACRTokenRequest,
 ) (*types.GetRegistryTokenResponse, error) {
 	resp := &types.GetRegistryTokenResponse{}
 
@@ -209,7 +210,7 @@ func (c *Client) GetACRAuthorizationToken(
 			"/projects/%d/registries/acr/token",
 			projectID,
 		),
-		nil,
+		req,
 		resp,
 	)
 

+ 35 - 6
api/server/handlers/project_integration/create_azure.go

@@ -3,6 +3,8 @@ package project_integration
 import (
 	"net/http"
 
+	"github.com/porter-dev/porter/internal/telemetry"
+
 	"github.com/bufbuild/connect-go"
 
 	porterv1 "github.com/porter-dev/api-contracts/generated/go/porter/v1"
@@ -31,12 +33,19 @@ func NewCreateAzureHandler(
 }
 
 func (p *CreateAzureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	user, _ := r.Context().Value(types.UserScope).(*models.User)
-	project, _ := r.Context().Value(types.ProjectScope).(*models.Project)
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-create-azure-connection")
+	defer span.End()
+
+	user, _ := ctx.Value(types.UserScope).(*models.User)
+	project, _ := ctx.Value(types.ProjectScope).(*models.Project)
+
+	telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "project-id", Value: project.ID})
 
 	request := &types.CreateAzureRequest{}
 
 	if ok := p.DecodeAndValidate(w, r, request); !ok {
+		err := telemetry.Error(ctx, span, nil, "error decoding and validating request")
+		p.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusBadRequest))
 		return
 	}
 
@@ -44,7 +53,8 @@ func (p *CreateAzureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 	az, err := p.Repo().AzureIntegration().CreateAzureIntegration(az)
 	if err != nil {
-		p.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		err = telemetry.Error(ctx, span, err, "error creating azure integration")
+		p.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
 		return
 	}
 
@@ -52,6 +62,12 @@ func (p *CreateAzureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		AzureIntegration: az.ToAzureIntegrationType(),
 	}
 
+	if p.Config().ClusterControlPlaneClient == nil {
+		err := telemetry.Error(ctx, span, nil, "cluster control plane client cannot be nil")
+		p.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+		return
+	}
+
 	req := connect.NewRequest(&porterv1.SaveAzureCredentialsRequest{
 		ProjectId:              int64(project.ID),
 		ClientId:               request.AzureClientID,
@@ -59,13 +75,26 @@ func (p *CreateAzureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		TenantId:               request.AzureTenantID,
 		ServicePrincipalSecret: []byte(request.ServicePrincipalKey),
 	})
-	_, err = p.Config().ClusterControlPlaneClient.SaveAzureCredentials(r.Context(), req)
-
+	resp, err := p.Config().ClusterControlPlaneClient.SaveAzureCredentials(ctx, req)
 	if err != nil {
-		p.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		err = telemetry.Error(ctx, span, err, "error saving azure credentials")
+		p.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
 		return
 	}
 
+	if resp.Msg == nil {
+		err = telemetry.Error(ctx, span, nil, "SaveAzureCredentials response message is nil")
+		p.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+		return
+	}
+	if resp.Msg.CredentialsIdentifier == "" {
+		err = telemetry.Error(ctx, span, nil, "SaveAzureCredentials response credentials identifier is empty")
+		p.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+		return
+	}
+
+	res.CloudProviderCredentialIdentifier = resp.Msg.CredentialsIdentifier
+
 	p.WriteResult(w, r, res)
 }
 

+ 93 - 16
api/server/handlers/registry/get_token.go

@@ -7,6 +7,9 @@ import (
 	"strings"
 	"time"
 
+	"github.com/porter-dev/porter/internal/telemetry"
+
+	"github.com/aws/aws-sdk-go/aws/arn"
 	"github.com/aws/aws-sdk-go/service/ecr"
 	"github.com/bufbuild/connect-go"
 	porterv1 "github.com/porter-dev/api-contracts/generated/go/porter/v1"
@@ -18,9 +21,6 @@ import (
 	"github.com/porter-dev/porter/internal/models"
 	"github.com/porter-dev/porter/internal/oauth"
 	"github.com/porter-dev/porter/internal/registry"
-	"github.com/porter-dev/porter/internal/telemetry"
-
-	"github.com/aws/aws-sdk-go/aws/arn"
 )
 
 type RegistryGetECRTokenHandler struct {
@@ -402,33 +402,110 @@ func NewRegistryGetACRTokenHandler(
 }
 
 func (c *RegistryGetACRTokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	proj, _ := r.Context().Value(types.ProjectScope).(*models.Project)
+	ctx, span := telemetry.NewSpan(r.Context(), "serve-acr-token")
+	defer span.End()
+
+	proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
+
+	telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "project-id", Value: proj.ID})
+
+	request := &types.GetRegistryACRTokenRequest{}
+
+	if ok := c.DecodeAndValidate(w, r, request); !ok {
+		err := telemetry.Error(ctx, span, nil, "error decoding request")
+		c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusBadRequest))
+		return
+	}
+
+	if request.ServerURL == "" {
+		err := telemetry.Error(ctx, span, nil, "missing server url")
+		c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusBadRequest))
+		return
+	}
+
+	serverUrl := strings.TrimSuffix(request.ServerURL, "/")
+	telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "server-url", Value: serverUrl})
 
 	// list registries and find one that matches the region
 	regs, err := c.Repo().Registry().ListRegistriesByProjectID(proj.ID)
 	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
+		err = telemetry.Error(ctx, span, err, "error getting registries by project id")
+		c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
 		return
 	}
 
 	var token string
 	var expiresAt *time.Time
 
+	var matchingReg *models.Registry
 	for _, reg := range regs {
-		if reg.AzureIntegrationID != 0 && strings.Contains(reg.URL, "azurecr.io") {
-			_reg := registry.Registry(*reg)
-			username, pw, err := _reg.GetACRCredentials(c.Repo())
-			if err != nil {
-				c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-				continue
-			}
+		if strings.Contains(reg.URL, serverUrl) {
+			matchingReg = reg
+		}
+	}
 
-			token = base64.StdEncoding.EncodeToString([]byte(string(username) + ":" + string(pw)))
+	if matchingReg == nil {
+		err := telemetry.Error(ctx, span, err, "no matching registry")
+		c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+		return
+	}
 
-			// we'll just set an arbitrary 30-day expiry time (this is not enforced)
-			timeExpires := time.Now().Add(30 * 24 * 3600 * time.Second)
-			expiresAt = &timeExpires
+	telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "registry-name", Value: matchingReg.Name})
+
+	if proj.CapiProvisionerEnabled {
+		telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "capi-provisioned", Value: true})
+
+		if c.Config().ClusterControlPlaneClient == nil {
+			err := telemetry.Error(ctx, span, nil, "cluster control plane client cannot be nil")
+			c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+			return
+		}
+
+		tokenReq := connect.NewRequest(&porterv1.TokenForRegistryRequest{
+			ProjectId:   int64(proj.ID),
+			RegistryUri: matchingReg.URL,
+		})
+		tokenResp, err := c.Config().ClusterControlPlaneClient.TokenForRegistry(ctx, tokenReq)
+		if err != nil {
+			err = telemetry.Error(ctx, span, err, "error getting token response from ccp")
+			c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+			return
+		}
+
+		if tokenResp.Msg == nil || tokenResp.Msg.Token == "" {
+			err := telemetry.Error(ctx, span, nil, "no token found in response")
+			c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+			return
+		}
+
+		token = tokenResp.Msg.Token
+
+		// we'll just set an arbitrary 30-day expiry time (this is not enforced)
+		timeExpires := time.Now().UTC().Add(30 * 24 * time.Hour)
+		expiresAt = &timeExpires
+	}
+
+	if matchingReg.AzureIntegrationID != 0 {
+		telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "capi-provisioned", Value: false})
+
+		_reg := registry.Registry(*matchingReg)
+		username, pw, err := _reg.GetACRCredentials(c.Repo())
+		if err != nil {
+			err = telemetry.Error(ctx, span, err, "error getting token response from ccp")
+			c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+			return
 		}
+
+		token = base64.StdEncoding.EncodeToString([]byte(string(username) + ":" + string(pw)))
+		// we'll just set an arbitrary 30-day expiry time (this is not enforced)
+		timeExpires := time.Now().UTC().Add(30 * 24 * time.Hour)
+		expiresAt = &timeExpires
+	}
+
+	if token == "" {
+		err := telemetry.Error(ctx, span, nil, "missing token")
+		c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
+		return
 	}
 
 	resp := &types.GetRegistryTokenResponse{

+ 1 - 0
api/types/project_integration.go

@@ -170,6 +170,7 @@ type CreateAzureRequest struct {
 
 type CreateAzureResponse struct {
 	*AzureIntegration
+	CloudProviderCredentialIdentifier string `json:"cloud_provider_credentials_id"`
 }
 
 type ListAzureResponse []*AzureIntegration

+ 4 - 0
api/types/registry.go

@@ -176,6 +176,10 @@ type GetRegistryTokenResponse struct {
 	ExpiresAt *time.Time `json:"expires_at"`
 }
 
+type GetRegistryACRTokenRequest struct {
+	ServerURL string `schema:"server_url"`
+}
+
 type GetRegistryGCRTokenRequest struct {
 	ServerURL string `schema:"server_url"`
 }

+ 2 - 2
cli/cmd/docker/auth.go

@@ -257,8 +257,8 @@ func (a *AuthGetter) GetACRCredentials(serverURL string, projID uint) (user stri
 	if cachedEntry != nil && cachedEntry.IsValid(time.Now()) {
 		token = cachedEntry.AuthorizationToken
 	} else {
-		// get a token from the server
-		tokenResp, err := a.Client.GetACRAuthorizationToken(context.Background(), projID)
+		req := &types.GetRegistryACRTokenRequest{ServerURL: serverURL}
+		tokenResp, err := a.Client.GetACRAuthorizationToken(context.Background(), projID, req)
 		if err != nil {
 			return "", "", err
 		}

+ 1 - 1
dashboard/src/components/AzureCredentialForm.tsx

@@ -52,7 +52,7 @@ const AzureCredentialForm: React.FC<Props> = ({ goBack, proceed }) => {
       )
       .then(({ data }) => {
         setIsLoading(false);
-        proceed(data.id);
+        proceed(data.cloud_provider_credentials_id);
       })
       .catch((err) => {
         console.error(err);

+ 2 - 1
dashboard/src/components/AzureProvisionerSettings.tsx

@@ -32,6 +32,7 @@ const locationOptions = [
   { value: "eastus", label: "East US" },
   { value: "westus2", label: "West US 2" },
   { value: "westus3", label: "West US 3" },
+  { value: "canadacentral", label: "Central Canada" },
 ];
 
 const machineTypeOptions = [
@@ -113,7 +114,7 @@ const AzureProvisionerSettings: React.FC<Props> = (props) => {
         projectId: currentProject.id,
         kind: EnumKubernetesKind.AKS,
         cloudProvider: EnumCloudProvider.AZURE,
-        cloudProviderCredentialsId: "",
+        cloudProviderCredentialsId: props.credentialId,
         kindValues: {
           case: "aksKind",
           value: new AKS({

+ 9 - 11
dashboard/src/components/ProvisionerFlow.tsx

@@ -17,7 +17,7 @@ const providers = ["aws", "gcp", "azure"];
 
 type Props = {};
 
-const ProvisionerFlow: React.FC<Props> = ({ }) => {
+const ProvisionerFlow: React.FC<Props> = ({}) => {
   const {
     usage,
     hasBillingEnabled,
@@ -40,11 +40,7 @@ const ProvisionerFlow: React.FC<Props> = ({ }) => {
 
   const markStepCostConsent = async (step: string, provider: string) => {
     try {
-      await api.updateOnboardingStep(
-        "<token>",
-        { step, provider },
-        {}
-      );
+      await api.updateOnboardingStep("<token>", { step, provider }, {});
     } catch (err) {
       console.log(err);
     }
@@ -80,7 +76,7 @@ const ProvisionerFlow: React.FC<Props> = ({ }) => {
                         provider === "gcp"
                       )
                     ) {
-                      openCostConsentModal(provider)
+                      openCostConsentModal(provider);
                     }
                   }}
                 >
@@ -101,7 +97,7 @@ const ProvisionerFlow: React.FC<Props> = ({ }) => {
               setShowCostConfirmModal={setShowCostConfirmModal}
               markCostConsentComplete={() => {
                 try {
-                  markStepCostConsent("cost-consent-complete", "aws")
+                  markStepCostConsent("cost-consent-complete", "aws");
                 } catch (err) {
                   console.log(err);
                 }
@@ -126,7 +122,7 @@ const ProvisionerFlow: React.FC<Props> = ({ }) => {
                 setShowCostConfirmModal={setShowCostConfirmModal}
                 markCostConsentComplete={() => {
                   try {
-                    markStepCostConsent("cost-consent-complete", "azure")
+                    markStepCostConsent("cost-consent-complete", "azure");
                   } catch (err) {
                     console.log(err);
                   }
@@ -141,7 +137,8 @@ const ProvisionerFlow: React.FC<Props> = ({ }) => {
                       console.log(err);
                     }
                   }
-                }} />
+                }}
+              />
             )))}
       </>
     );
@@ -169,7 +166,8 @@ const ProvisionerFlow: React.FC<Props> = ({ }) => {
       (selectedProvider === "azure" && (
         <AzureCredentialForm
           goBack={() => setCurrentStep("cloud")}
-          proceed={() => {
+          proceed={(id) => {
+            setCredentialId(id);
             setCurrentStep("cluster");
           }}
         />

+ 21 - 10
internal/registry/registry.go

@@ -155,24 +155,35 @@ func (r *Registry) ListRepositories(
 	}
 
 	if project.CapiProvisionerEnabled {
+		// TODO: Remove this conditional when AWS list repos is supported in CCP
 		if strings.Contains(r.URL, ".azurecr.") {
 			telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "auth-mechanism", Value: "capi-azure"})
-			creds, err := conf.Repo.AzureIntegration().ListAzureIntegrationsByProjectID(r.ProjectID)
+
+			req := connect.NewRequest(&porterv1.ListRepositoriesForRegistryRequest{
+				ProjectId:   int64(r.ProjectID),
+				RegistryUri: r.URL,
+			})
+
+			resp, err := conf.ClusterControlPlaneClient.ListRepositoriesForRegistry(ctx, req)
 			if err != nil {
-				return nil, telemetry.Error(ctx, span, err, "error getting azure credentials for capi cluster")
-			}
-			if len(creds) == 0 {
-				return nil, telemetry.Error(ctx, span, err, "no azure credentials for capi cluster")
+				return nil, telemetry.Error(ctx, span, err, "error listing ecr repositories")
 			}
-			r.AzureIntegrationID = creds[0].ID
-			telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "azure-integration-id", Value: r.AzureIntegrationID})
 
-			repos, err := r.listACRRepositories(ctx, repo)
+			res := make([]*ptypes.RegistryRepository, 0)
+
+			parsedURL, err := url.Parse("https://" + r.URL)
 			if err != nil {
-				return nil, telemetry.Error(ctx, span, err, "error listing acr repositories")
+				return nil, telemetry.Error(ctx, span, err, "error parsing url")
 			}
 
-			return repos, nil
+			for _, repo := range resp.Msg.Repositories {
+				res = append(res, &ptypes.RegistryRepository{
+					Name: repo.Name,
+					URI:  parsedURL.Host + "/" + repo.Name,
+				})
+			}
+
+			return res, nil
 		} else {
 			telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "auth-mechanism", Value: "capi-aws"})
 			uri := strings.TrimPrefix(r.URL, "https://")