Просмотр исходного кода

Web identity generalization + tests

Signed-off-by: Kaelan Patel <32113845+kaelanspatel@users.noreply.github.com>
Kaelan Patel 2 лет назад
Родитель
Сommit
4676fc326a
3 измененных файлов с 240 добавлено и 57 удалено
  1. 74 49
      pkg/cloud/aws/authorizer.go
  2. 96 8
      pkg/cloud/aws/authorizer_test.go
  3. 70 0
      pkg/cloud/aws/webidentity.go

+ 74 - 49
pkg/cloud/aws/authorizer.go

@@ -10,15 +10,12 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/sts"
 	"github.com/opencost/opencost/core/pkg/util/json"
 	"github.com/opencost/opencost/pkg/cloud"
-	"golang.org/x/oauth2/google"
-	"google.golang.org/api/idtoken"
-	"google.golang.org/api/option"
 )
 
 const AccessKeyAuthorizerType = "AWSAccessKey"
 const ServiceAccountAuthorizerType = "AWSServiceAccount"
 const AssumeRoleAuthorizerType = "AWSAssumeRole"
-const GoogleWebIdentityAuthorizerType = "GoogleWebIdentity"
+const WebIdentityAuthorizerType = "WebIdentity"
 
 // Authorizer implementations provide aws.Config for AWS SDK calls
 type Authorizer interface {
@@ -35,6 +32,8 @@ func SelectAuthorizerByType(typeStr string) (Authorizer, error) {
 		return &ServiceAccount{}, nil
 	case AssumeRoleAuthorizerType:
 		return &AssumeRole{}, nil
+	case WebIdentityAuthorizerType:
+		return &WebIdentity{}, nil
 	default:
 		return nil, fmt.Errorf("AWS: provider authorizer type '%s' is not valid", typeStr)
 	}
@@ -250,87 +249,113 @@ func (ara *AssumeRole) Sanitize() cloud.Config {
 	}
 }
 
-type GoogleWebIdentity struct {
-	RoleARN        string                 `json:"roleARN"`
-	TokenRetriever GoogleIDTokenRetriever `json:"tokenRetriever"`
+type WebIdentity struct {
+	RoleARN          string           `json:"roleARN"`
+	IdentityProvider string           `json:"identityProvider"`
+	TokenRetriever   IDTokenRetriever `json:"tokenRetriever"`
 }
 
-type GoogleIDTokenRetriever struct {
-	Aud string `json:"aud"`
+func (wea *WebIdentity) CreateAWSConfig(region string) (aws.Config, error) {
+	cfg, err := awsconfig.LoadDefaultConfig(context.TODO(), awsconfig.WithRegion(region))
+	if err != nil {
+		return cfg, fmt.Errorf("failed to initialize AWS SDK config for region from annotation %s: %s", region, err)
+	}
+
+	stsSvc := sts.NewFromConfig(cfg)
+	creds := stscreds.NewWebIdentityRoleProvider(stsSvc, wea.RoleARN, wea.TokenRetriever)
+
+	cfg.Credentials = aws.NewCredentialsCache(creds)
+	return cfg, nil
 }
 
-func (gitr *GoogleIDTokenRetriever) GetIdentityToken() ([]byte, error) {
-	ctx := context.Background()
-	res := []byte{}
+func (wea *WebIdentity) MarshalJSON() ([]byte, error) {
+	fmap := make(map[string]any, 1)
+	fmap[cloud.AuthorizerTypeProperty] = WebIdentityAuthorizerType
+	fmap["roleARN"] = wea.RoleARN
+	fmap["identityProvider"] = wea.IdentityProvider
+	fmap["tokenRetriever"] = wea.TokenRetriever
+	return json.Marshal(fmap)
+}
 
-	credentials, err := google.FindDefaultCredentials(ctx)
+func (wea *WebIdentity) UnmarshalJSON(b []byte) error {
+	var f interface{}
+	err := json.Unmarshal(b, &f)
 	if err != nil {
-		return res, fmt.Errorf("failed to find default credentials: %v", err)
+		return err
 	}
 
-	ts, err := idtoken.NewTokenSource(ctx, gitr.Aud, option.WithCredentials(credentials))
+	fmap := f.(map[string]interface{})
+
+	roleARN, err := cloud.GetInterfaceValue[string](fmap, "roleARN")
 	if err != nil {
-		return res, fmt.Errorf("failed to create ID token source: %w", err)
+		return fmt.Errorf("WebIdentity: UnmarshalJSON: %s", err.Error())
 	}
+	wea.RoleARN = roleARN
 
-	t, err := ts.Token()
+	idp, err := cloud.GetInterfaceValue[string](fmap, "identityProvider")
 	if err != nil {
-		return res, fmt.Errorf("failed to receive ID token from metadata server: %w", err)
+		return fmt.Errorf("WebIdentity: UnmarshalJSON: %s", err.Error())
 	}
+	wea.IdentityProvider = idp
 
-	return []byte(t.AccessToken), nil
-}
+	var tr interface{}
 
-func (wea *GoogleWebIdentity) CreateAWSConfig(region string) (aws.Config, error) {
-	cfg, err := awsconfig.LoadDefaultConfig(context.TODO(), awsconfig.WithRegion(region))
+	tr, err = cloud.GetInterfaceValue[interface{}](fmap, "tokenRetriever")
 	if err != nil {
-		return cfg, fmt.Errorf("failed to initialize AWS SDK config for region from annotation %s: %s", region, err)
+		return fmt.Errorf("WebIdentity: UnmarshalJSON: %s", err.Error())
 	}
 
-	stsSvc := sts.NewFromConfig(cfg)
-	creds := stscreds.NewWebIdentityRoleProvider(stsSvc, wea.RoleARN, &wea.TokenRetriever)
+	trb, err := json.Marshal(tr)
+	if err != nil {
+		return fmt.Errorf("WebIdentity: UnmarshalJSON: %s", err.Error())
+	}
 
-	cfg.Credentials = aws.NewCredentialsCache(creds)
-	return cfg, nil
-}
+	var tokenRetriever IDTokenRetriever
+	switch idp {
+	case "Google":
+		tokenRetriever = &GoogleIDTokenRetriever{}
 
-func (wea *GoogleWebIdentity) MarshalJSON() ([]byte, error) {
-	fmap := make(map[string]any, 1)
-	fmap[cloud.AuthorizerTypeProperty] = GoogleWebIdentityAuthorizerType
-	fmap["roleARN"] = wea.RoleARN
-	fmap["tokenRetriever"] = wea.TokenRetriever
-	return json.Marshal(fmap)
-}
+	}
 
-func (wea *GoogleWebIdentity) Validate() error {
-	if wea.TokenRetriever.Aud == "" {
-		return fmt.Errorf("GoogleWebIdenity: missing token retriver audience")
+	err = json.Unmarshal(trb, &tokenRetriever)
+	if err != nil {
+		return fmt.Errorf("WebIdentity: UnmarshalJSON: %s", err.Error())
 	}
 
+	wea.TokenRetriever = tokenRetriever
+
+	return nil
+}
+
+func (wea *WebIdentity) Validate() error {
+
 	if wea.RoleARN == "" {
-		return fmt.Errorf("GoogleWebIdenity: missing RoleARN configuration")
+		return fmt.Errorf("WebIdentity: missing RoleARN configuration")
 	}
 
-	return nil
+	if wea.TokenRetriever == nil {
+		return fmt.Errorf("WebIdentity: missing TokenRetriever configuration")
+	}
+
+	return wea.TokenRetriever.Validate()
 }
 
-func (wea *GoogleWebIdentity) Equals(config cloud.Config) bool {
+func (wea *WebIdentity) Equals(config cloud.Config) bool {
 	if config == nil {
 		return false
 	}
-	thatConfig, ok := config.(*GoogleWebIdentity)
+	thatConfig, ok := config.(*WebIdentity)
 	if !ok {
 		return false
 	}
 
-	return wea.RoleARN == thatConfig.RoleARN && wea.TokenRetriever.Aud == thatConfig.TokenRetriever.Aud
+	return wea.RoleARN == thatConfig.RoleARN && wea.IdentityProvider == thatConfig.IdentityProvider && wea.TokenRetriever.Equals(thatConfig.TokenRetriever)
 }
 
-func (wea *GoogleWebIdentity) Sanitize() cloud.Config {
-	return &GoogleWebIdentity{
-		RoleARN: wea.RoleARN,
-		TokenRetriever: GoogleIDTokenRetriever{
-			Aud: wea.TokenRetriever.Aud,
-		},
+func (wea *WebIdentity) Sanitize() cloud.Config {
+	return &WebIdentity{
+		RoleARN:          wea.RoleARN,
+		IdentityProvider: wea.IdentityProvider,
+		TokenRetriever:   wea.TokenRetriever.Sanitize(),
 	}
 }

+ 96 - 8
pkg/cloud/aws/authorizer_test.go

@@ -3,6 +3,7 @@ package aws
 import (
 	"testing"
 
+	"github.com/opencost/opencost/core/pkg/util/json"
 	"github.com/opencost/opencost/pkg/cloud"
 )
 
@@ -53,15 +54,17 @@ func TestAuthorizerJSON_Sanitize(t *testing.T) {
 			},
 		},
 		"Google Web Identity": {
-			input: &GoogleWebIdentity{
-				RoleARN: "role arn",
-				TokenRetriever: GoogleIDTokenRetriever{
+			input: &WebIdentity{
+				RoleARN:          "role arn",
+				IdentityProvider: "Google",
+				TokenRetriever: &GoogleIDTokenRetriever{
 					Aud: "aud",
 				},
 			},
-			expected: &GoogleWebIdentity{
-				RoleARN: "role arn",
-				TokenRetriever: GoogleIDTokenRetriever{
+			expected: &WebIdentity{
+				RoleARN:          "role arn",
+				IdentityProvider: "Google",
+				TokenRetriever: &GoogleIDTokenRetriever{
 					Aud: "aud",
 				},
 			},
@@ -69,10 +72,95 @@ func TestAuthorizerJSON_Sanitize(t *testing.T) {
 	}
 	for name, tc := range testCases {
 		t.Run(name, func(t *testing.T) {
+
+			b, err := tc.input.MarshalJSON()
+			if err != nil {
+				t.Errorf("Failed to Marshal Authorizer: %s", err)
+			}
+
+			var f interface{}
+			err = json.Unmarshal(b, &f)
+			if err != nil {
+				t.Errorf("Failed to Unmarshal Authorizer: %s", err)
+			}
+
+			authorizer, err := cloud.AuthorizerFromInterface(f, SelectAuthorizerByType)
+			if err != nil {
+				t.Errorf("Failed to Unmarshal Authorizer: %s", err)
+			}
+
 			// Convert to AuthorizerJSON for sanitization
-			sanitizedAuthorizer := tc.input.Sanitize()
+			if authorizer != nil {
+				sanitizedAuthorizer := authorizer.Sanitize()
+
+				if !tc.expected.Equals(sanitizedAuthorizer) {
+					t.Error("Authorizer was not as expected after Sanitization")
+				}
+			}
+
+		})
+	}
+}
+
+func TestAuthorizerJSON_Encode(t *testing.T) {
+
+	testCases := map[string]struct {
+		authorizer Authorizer
+	}{
+		"Access Key": {
+			authorizer: &AccessKey{
+				ID:     "ID",
+				Secret: "Secret",
+			},
+		},
+		"Service Account": {
+			authorizer: &ServiceAccount{},
+		},
+		"Master Payer Access Key": {
+			authorizer: &AssumeRole{
+				Authorizer: &AccessKey{
+					ID:     "ID",
+					Secret: "Secret",
+				},
+				RoleARN: "role arn",
+			},
+		},
+		"Master Payer Service Account": {
+			authorizer: &AssumeRole{
+				Authorizer: &ServiceAccount{},
+				RoleARN:    "role arn",
+			},
+		},
+		"Google Web Identity": {
+			authorizer: &WebIdentity{
+				RoleARN:          "role arn",
+				IdentityProvider: "Google",
+				TokenRetriever: &GoogleIDTokenRetriever{
+					Aud: "aud",
+				},
+			},
+		},
+	}
+	for name, tc := range testCases {
+		t.Run(name, func(t *testing.T) {
+
+			b, err := tc.authorizer.MarshalJSON()
+			if err != nil {
+				t.Errorf("Failed to Marshal Authorizer: %s", err)
+			}
+
+			var f interface{}
+			err = json.Unmarshal(b, &f)
+			if err != nil {
+				t.Errorf("Failed to Unmarshal Authorizer: %s", err)
+			}
+
+			authorizer, err := cloud.AuthorizerFromInterface(f, SelectAuthorizerByType)
+			if err != nil {
+				t.Errorf("Failed to Unmarshal Authorizer: %s", err)
+			}
 
-			if !tc.expected.Equals(sanitizedAuthorizer) {
+			if !tc.authorizer.Equals(authorizer) {
 				t.Error("Authorizer was not as expected after Sanitization")
 			}
 

+ 70 - 0
pkg/cloud/aws/webidentity.go

@@ -0,0 +1,70 @@
+package aws
+
+import (
+	"context"
+	"fmt"
+
+	"golang.org/x/oauth2/google"
+	"google.golang.org/api/idtoken"
+	"google.golang.org/api/option"
+)
+
+type IDTokenRetriever interface {
+	GetIdentityToken() ([]byte, error)
+	Validate() error
+	Sanitize() IDTokenRetriever
+	Equals(IDTokenRetriever) bool
+}
+
+type GoogleIDTokenRetriever struct {
+	Aud string `json:"aud"`
+}
+
+func (gitr GoogleIDTokenRetriever) GetIdentityToken() ([]byte, error) {
+	ctx := context.Background()
+	res := []byte{}
+
+	credentials, err := google.FindDefaultCredentials(ctx)
+	if err != nil {
+		return res, fmt.Errorf("failed to find default credentials: %v", err)
+	}
+
+	ts, err := idtoken.NewTokenSource(ctx, gitr.Aud, option.WithCredentials(credentials))
+	if err != nil {
+		return res, fmt.Errorf("failed to create ID token source: %w", err)
+	}
+
+	t, err := ts.Token()
+	if err != nil {
+		return res, fmt.Errorf("failed to receive ID token from metadata server: %w", err)
+	}
+
+	return []byte(t.AccessToken), nil
+}
+
+func (gitr GoogleIDTokenRetriever) Validate() error {
+	if gitr.Aud == "" {
+		return fmt.Errorf("GoogleIDTokenRetriever: missing audience configuration")
+	}
+
+	return nil
+}
+
+func (gitr GoogleIDTokenRetriever) Equals(other IDTokenRetriever) bool {
+	that, ok := other.(*GoogleIDTokenRetriever)
+	if !ok {
+		return false
+	}
+
+	if gitr.Aud != that.Aud {
+		return false
+	}
+
+	return true
+}
+
+func (gitr GoogleIDTokenRetriever) Sanitize() IDTokenRetriever {
+	return &GoogleIDTokenRetriever{
+		Aud: gitr.Aud,
+	}
+}