Przeglądaj źródła

Merge pull request #2710 from kaelanspatel/kaelan-aws-web-identity

AWS IRSA authorizer for cloud integrations
Niko Kovacevic 2 lat temu
rodzic
commit
9f6f0cf1ae

+ 117 - 7
pkg/cloud/aws/authorizer.go

@@ -15,6 +15,7 @@ import (
 const AccessKeyAuthorizerType = "AWSAccessKey"
 const ServiceAccountAuthorizerType = "AWSServiceAccount"
 const AssumeRoleAuthorizerType = "AWSAssumeRole"
+const WebIdentityAuthorizerType = "WebIdentity"
 
 // Authorizer implementations provide aws.Config for AWS SDK calls
 type Authorizer interface {
@@ -31,6 +32,8 @@ func SelectAuthorizerByType(typeStr string) (Authorizer, error) {
 		return &ServiceAccount{}, nil
 	case AssumeRoleAuthorizerType:
 		return &AssumeRole{}, nil
+	case WebIdentityAuthorizerType:
+		return &WebIdentity{}, nil
 	default:
 		return nil, fmt.Errorf("AWS: provider authorizer type '%s' is not valid", typeStr)
 	}
@@ -129,11 +132,7 @@ func (sa *ServiceAccount) Equals(config cloud.Config) bool {
 		return false
 	}
 	_, ok := config.(*ServiceAccount)
-	if !ok {
-		return false
-	}
-
-	return true
+	return ok
 }
 
 func (sa *ServiceAccount) Sanitize() cloud.Config {
@@ -204,7 +203,7 @@ func (ara *AssumeRole) CreateAWSConfig(region string) (aws.Config, error) {
 
 func (ara *AssumeRole) Validate() error {
 	if ara.Authorizer == nil {
-		return fmt.Errorf("AssumeRole: misisng base Authorizer")
+		return fmt.Errorf("AssumeRole: missing base Authorizer")
 	}
 	err := ara.Authorizer.Validate()
 	if err != nil {
@@ -212,7 +211,7 @@ func (ara *AssumeRole) Validate() error {
 	}
 
 	if ara.RoleARN == "" {
-		return fmt.Errorf("AssumeRole: misisng RoleARN configuration")
+		return fmt.Errorf("AssumeRole: missing RoleARN configuration")
 	}
 
 	return nil
@@ -249,3 +248,114 @@ func (ara *AssumeRole) Sanitize() cloud.Config {
 		RoleARN:    ara.RoleARN,
 	}
 }
+
+type WebIdentity struct {
+	RoleARN          string           `json:"roleARN"`
+	IdentityProvider string           `json:"identityProvider"`
+	TokenRetriever   IDTokenRetriever `json:"tokenRetriever"`
+}
+
+func (wea *WebIdentity) CreateAWSConfig(region string) (aws.Config, error) {
+	cfg, err := awsconfig.LoadDefaultConfig(context.TODO(), awsconfig.WithRegion(region))
+	if err != nil {
+		return cfg, fmt.Errorf("failed to initialize AWS SDK config for region from annotation %s: %s", region, err)
+	}
+
+	stsSvc := sts.NewFromConfig(cfg)
+	creds := stscreds.NewWebIdentityRoleProvider(stsSvc, wea.RoleARN, wea.TokenRetriever)
+
+	cfg.Credentials = aws.NewCredentialsCache(creds)
+	return cfg, nil
+}
+
+func (wea *WebIdentity) MarshalJSON() ([]byte, error) {
+	fmap := make(map[string]any, 1)
+	fmap[cloud.AuthorizerTypeProperty] = WebIdentityAuthorizerType
+	fmap["roleARN"] = wea.RoleARN
+	fmap["identityProvider"] = wea.IdentityProvider
+	fmap["tokenRetriever"] = wea.TokenRetriever
+	return json.Marshal(fmap)
+}
+
+func (wea *WebIdentity) UnmarshalJSON(b []byte) error {
+	var f interface{}
+	err := json.Unmarshal(b, &f)
+	if err != nil {
+		return err
+	}
+
+	fmap := f.(map[string]interface{})
+
+	roleARN, err := cloud.GetInterfaceValue[string](fmap, "roleARN")
+	if err != nil {
+		return fmt.Errorf("WebIdentity: UnmarshalJSON: %s", err.Error())
+	}
+	wea.RoleARN = roleARN
+
+	idp, err := cloud.GetInterfaceValue[string](fmap, "identityProvider")
+	if err != nil {
+		return fmt.Errorf("WebIdentity: UnmarshalJSON: %s", err.Error())
+	}
+	wea.IdentityProvider = idp
+
+	var tr interface{}
+
+	tr, err = cloud.GetInterfaceValue[interface{}](fmap, "tokenRetriever")
+	if err != nil {
+		return fmt.Errorf("WebIdentity: UnmarshalJSON: %s", err.Error())
+	}
+
+	trb, err := json.Marshal(tr)
+	if err != nil {
+		return fmt.Errorf("WebIdentity: UnmarshalJSON: %s", err.Error())
+	}
+
+	var tokenRetriever IDTokenRetriever
+	switch idp {
+	case "Google":
+		tokenRetriever = &GoogleIDTokenRetriever{}
+
+	}
+
+	err = json.Unmarshal(trb, &tokenRetriever)
+	if err != nil {
+		return fmt.Errorf("WebIdentity: UnmarshalJSON: %s", err.Error())
+	}
+
+	wea.TokenRetriever = tokenRetriever
+
+	return nil
+}
+
+func (wea *WebIdentity) Validate() error {
+
+	if wea.RoleARN == "" {
+		return fmt.Errorf("WebIdentity: missing RoleARN configuration")
+	}
+
+	if wea.TokenRetriever == nil {
+		return fmt.Errorf("WebIdentity: missing TokenRetriever configuration")
+	}
+
+	return wea.TokenRetriever.Validate()
+}
+
+func (wea *WebIdentity) Equals(config cloud.Config) bool {
+	if config == nil {
+		return false
+	}
+	thatConfig, ok := config.(*WebIdentity)
+	if !ok {
+		return false
+	}
+
+	return wea.RoleARN == thatConfig.RoleARN && wea.IdentityProvider == thatConfig.IdentityProvider && wea.TokenRetriever.Equals(thatConfig.TokenRetriever)
+}
+
+func (wea *WebIdentity) Sanitize() cloud.Config {
+	return &WebIdentity{
+		RoleARN:          wea.RoleARN,
+		IdentityProvider: wea.IdentityProvider,
+		TokenRetriever:   wea.TokenRetriever.Sanitize(),
+	}
+}

+ 83 - 0
pkg/cloud/aws/authorizer_test.go

@@ -3,6 +3,7 @@ package aws
 import (
 	"testing"
 
+	"github.com/opencost/opencost/core/pkg/util/json"
 	"github.com/opencost/opencost/pkg/cloud"
 )
 
@@ -52,6 +53,22 @@ func TestAuthorizerJSON_Sanitize(t *testing.T) {
 				RoleARN:    "role arn",
 			},
 		},
+		"Google Web Identity": {
+			input: &WebIdentity{
+				RoleARN:          "role arn",
+				IdentityProvider: "Google",
+				TokenRetriever: &GoogleIDTokenRetriever{
+					Aud: "aud",
+				},
+			},
+			expected: &WebIdentity{
+				RoleARN:          "role arn",
+				IdentityProvider: "Google",
+				TokenRetriever: &GoogleIDTokenRetriever{
+					Aud: "aud",
+				},
+			},
+		},
 	}
 	for name, tc := range testCases {
 		t.Run(name, func(t *testing.T) {
@@ -65,3 +82,69 @@ func TestAuthorizerJSON_Sanitize(t *testing.T) {
 		})
 	}
 }
+
+func TestAuthorizerJSON_Encode(t *testing.T) {
+
+	testCases := map[string]struct {
+		authorizer Authorizer
+	}{
+		"Access Key": {
+			authorizer: &AccessKey{
+				ID:     "ID",
+				Secret: "Secret",
+			},
+		},
+		"Service Account": {
+			authorizer: &ServiceAccount{},
+		},
+		"Master Payer Access Key": {
+			authorizer: &AssumeRole{
+				Authorizer: &AccessKey{
+					ID:     "ID",
+					Secret: "Secret",
+				},
+				RoleARN: "role arn",
+			},
+		},
+		"Master Payer Service Account": {
+			authorizer: &AssumeRole{
+				Authorizer: &ServiceAccount{},
+				RoleARN:    "role arn",
+			},
+		},
+		"Google Web Identity": {
+			authorizer: &WebIdentity{
+				RoleARN:          "role arn",
+				IdentityProvider: "Google",
+				TokenRetriever: &GoogleIDTokenRetriever{
+					Aud: "aud",
+				},
+			},
+		},
+	}
+	for name, tc := range testCases {
+		t.Run(name, func(t *testing.T) {
+
+			b, err := tc.authorizer.MarshalJSON()
+			if err != nil {
+				t.Errorf("Failed to Marshal Authorizer: %s", err)
+			}
+
+			var f interface{}
+			err = json.Unmarshal(b, &f)
+			if err != nil {
+				t.Errorf("Failed to Unmarshal Authorizer: %s", err)
+			}
+
+			authorizer, err := cloud.AuthorizerFromInterface(f, SelectAuthorizerByType)
+			if err != nil {
+				t.Errorf("Failed to Unmarshal Authorizer: %s", err)
+			}
+
+			if !tc.authorizer.Equals(authorizer) {
+				t.Error("Authorizer was not as expected after Sanitization")
+			}
+
+		})
+	}
+}

+ 70 - 0
pkg/cloud/aws/webidentity.go

@@ -0,0 +1,70 @@
+package aws
+
+import (
+	"context"
+	"fmt"
+
+	"golang.org/x/oauth2/google"
+	"google.golang.org/api/idtoken"
+	"google.golang.org/api/option"
+)
+
+type IDTokenRetriever interface {
+	GetIdentityToken() ([]byte, error)
+	Validate() error
+	Sanitize() IDTokenRetriever
+	Equals(IDTokenRetriever) bool
+}
+
+type GoogleIDTokenRetriever struct {
+	Aud string `json:"aud"`
+}
+
+func (gitr GoogleIDTokenRetriever) GetIdentityToken() ([]byte, error) {
+	ctx := context.Background()
+	res := []byte{}
+
+	credentials, err := google.FindDefaultCredentials(ctx)
+	if err != nil {
+		return res, fmt.Errorf("failed to find default credentials: %v", err)
+	}
+
+	ts, err := idtoken.NewTokenSource(ctx, gitr.Aud, option.WithCredentials(credentials))
+	if err != nil {
+		return res, fmt.Errorf("failed to create ID token source: %w", err)
+	}
+
+	t, err := ts.Token()
+	if err != nil {
+		return res, fmt.Errorf("failed to receive ID token from metadata server: %w", err)
+	}
+
+	return []byte(t.AccessToken), nil
+}
+
+func (gitr GoogleIDTokenRetriever) Validate() error {
+	if gitr.Aud == "" {
+		return fmt.Errorf("GoogleIDTokenRetriever: missing audience configuration")
+	}
+
+	return nil
+}
+
+func (gitr GoogleIDTokenRetriever) Equals(other IDTokenRetriever) bool {
+	that, ok := other.(*GoogleIDTokenRetriever)
+	if !ok {
+		return false
+	}
+
+	if gitr.Aud != that.Aud {
+		return false
+	}
+
+	return true
+}
+
+func (gitr GoogleIDTokenRetriever) Sanitize() IDTokenRetriever {
+	return &GoogleIDTokenRetriever{
+		Aud: gitr.Aud,
+	}
+}