Browse Source

Fix AWS config updates when spot bucket configured (#3575)

Signed-off-by: Eric Mrak <eric.mrak@fastly.com>
Eric Mrak 2 tháng trước cách đây
mục cha
commit
cdcff35b8d

+ 30 - 25
pkg/cloud/aws/provider.go

@@ -488,18 +488,18 @@ func (aws *AWS) GetAWSAccessKey() (*AWSAccessKey, error) {
 		return nil, fmt.Errorf("error configuring Cloud Provider %s", err)
 	}
 	// Look for service key values in env if not present in config
-	if config.ServiceKeyName == "" {
-		config.ServiceKeyName = env.GetAWSAccessKeyID()
+	if config.AwsServiceKeyName == "" {
+		config.AwsServiceKeyName = env.GetAWSAccessKeyID()
 	}
-	if config.ServiceKeySecret == "" {
-		config.ServiceKeySecret = env.GetAWSAccessKeySecret()
+	if config.AwsServiceKeySecret == "" {
+		config.AwsServiceKeySecret = env.GetAWSAccessKeySecret()
 	}
 
-	if config.ServiceKeyName == "" && config.ServiceKeySecret == "" {
+	if config.AwsServiceKeyName == "" && config.AwsServiceKeySecret == "" {
 		log.DedupedInfof(1, "missing service key values for AWS cloud integration attempting to use service account integration")
 	}
 
-	return &AWSAccessKey{AccessKeyID: config.ServiceKeyName, SecretAccessKey: config.ServiceKeySecret}, nil
+	return &AWSAccessKey{AccessKeyID: config.AwsServiceKeyName, SecretAccessKey: config.AwsServiceKeySecret}, nil
 }
 
 // GetAWSAthenaInfo generate an AWSAthenaInfo object from the config
@@ -533,27 +533,28 @@ func (aws *AWS) UpdateConfigFromConfigMap(cm map[string]string) (*models.CustomP
 	return aws.Config.UpdateFromMap(cm)
 }
 
-func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*models.CustomPricing, error) {
-	return aws.Config.Update(func(c *models.CustomPricing) error {
-		if updateType == SpotInfoUpdateType {
+func configUpdaterWithReaderAndType(r io.Reader, updateType string) func(c *models.CustomPricing) error {
+	return func(c *models.CustomPricing) error {
+		switch updateType {
+		case SpotInfoUpdateType:
 			asfi := AwsSpotFeedInfo{}
 			err := json.NewDecoder(r).Decode(&asfi)
 			if err != nil {
 				return err
 			}
 
-			c.ServiceKeyName = asfi.ServiceKeyName
+			c.AwsServiceKeyName = asfi.ServiceKeyName
 			if asfi.ServiceKeySecret != "" {
-				c.ServiceKeySecret = asfi.ServiceKeySecret
+				c.AwsServiceKeySecret = asfi.ServiceKeySecret
 			}
-			c.SpotDataPrefix = asfi.Prefix
-			c.SpotDataBucket = asfi.BucketName
+			c.AwsSpotDataPrefix = asfi.Prefix
+			c.AwsSpotDataBucket = asfi.BucketName
 			c.ProjectID = asfi.AccountID
-			c.SpotDataRegion = asfi.Region
+			c.AwsSpotDataRegion = asfi.Region
 			c.SpotLabel = asfi.SpotLabel
 			c.SpotLabelValue = asfi.SpotLabelValue
 
-		} else if updateType == AthenaInfoUpdateType {
+		case AthenaInfoUpdateType:
 			aai := AwsAthenaInfo{}
 			err := json.NewDecoder(r).Decode(&aai)
 			if err != nil {
@@ -566,9 +567,9 @@ func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*models.CustomPric
 			c.AthenaCatalog = aai.AthenaCatalog
 			c.AthenaTable = aai.AthenaTable
 			c.AthenaWorkgroup = aai.AthenaWorkgroup
-			c.ServiceKeyName = aai.ServiceKeyName
+			c.AwsServiceKeyName = aai.ServiceKeyName
 			if aai.ServiceKeySecret != "" {
-				c.ServiceKeySecret = aai.ServiceKeySecret
+				c.AwsServiceKeySecret = aai.ServiceKeySecret
 			}
 			if aai.MasterPayerARN != "" {
 				c.MasterPayerARN = aai.MasterPayerARN
@@ -577,8 +578,8 @@ func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*models.CustomPric
 			if aai.CURVersion != "" {
 				c.AthenaCURVersion = aai.CURVersion
 			}
-		} else {
-			a := make(map[string]interface{})
+		default:
+			a := make(map[string]any)
 			err := json.NewDecoder(r).Decode(&a)
 			if err != nil {
 				return err
@@ -604,7 +605,11 @@ func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*models.CustomPric
 			}
 		}
 		return nil
-	})
+	}
+}
+
+func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*models.CustomPricing, error) {
+	return aws.Config.Update(configUpdaterWithReaderAndType(r, updateType))
 }
 
 type awsKey struct {
@@ -870,10 +875,10 @@ func (aws *AWS) DownloadPricingData() error {
 	aws.BaseSpotGPUPrice = c.SpotGPU
 	aws.SpotLabelName = c.SpotLabel
 	aws.SpotLabelValue = c.SpotLabelValue
-	aws.SpotDataBucket = c.SpotDataBucket
-	aws.SpotDataPrefix = c.SpotDataPrefix
+	aws.SpotDataBucket = c.AwsSpotDataBucket
+	aws.SpotDataPrefix = c.AwsSpotDataPrefix
 	aws.ProjectID = c.ProjectID
-	aws.SpotDataRegion = c.SpotDataRegion
+	aws.SpotDataRegion = c.AwsSpotDataRegion
 
 	aws.ConfigureAuthWith(c) // load aws authentication from configuration or secret
 
@@ -1655,12 +1660,12 @@ func (aws *AWS) ConfigureAuthWith(config *models.CustomPricing) error {
 // Gets the aws key id and secret
 func (aws *AWS) getAWSAuth(forceReload bool, cp *models.CustomPricing) (string, string) {
 	// 1. Check config values first (set from frontend UI)
-	if cp.ServiceKeyName != "" && cp.ServiceKeySecret != "" {
+	if cp.AwsServiceKeyName != "" && cp.AwsServiceKeySecret != "" {
 		aws.ServiceAccountChecks.Set("hasKey", &models.ServiceAccountCheck{
 			Message: "AWS ServiceKey exists",
 			Status:  true,
 		})
-		return cp.ServiceKeyName, cp.ServiceKeySecret
+		return cp.AwsServiceKeyName, cp.AwsServiceKeySecret
 	}
 
 	// 2. Check for secret

+ 44 - 0
pkg/cloud/aws/provider_test.go

@@ -649,6 +649,50 @@ func TestGetPricingListURL(t *testing.T) {
 	}
 }
 
+func Test_configUpdaterWithReaderAndType_forSpotValues(t *testing.T) {
+	fixture, err := os.Open("testdata/aws-config.json")
+	if err != nil {
+		t.Fatalf("failed to load aws config fixture: %s", err)
+	}
+	defer fixture.Close()
+	c := &models.CustomPricing{}
+	callback := configUpdaterWithReaderAndType(fixture, "otherupdatetype")
+	err = callback(c)
+	if err != nil {
+		t.Fatalf("failed to load aws config: %s", err)
+	}
+	if c.AwsSpotDataBucket != "mybucket" {
+		t.Fatalf("Expected %s but got %s", "mybucket", c.AwsSpotDataBucket)
+	}
+	if c.AwsSpotDataPrefix != "myprefix" {
+		t.Fatalf("Expected %s but got %s", "myprefix", c.AwsSpotDataPrefix)
+	}
+	if c.AwsSpotDataRegion != "us-east-1" {
+		t.Fatalf("Expected %s but got %s", "us-east-1", c.AwsSpotDataRegion)
+	}
+
+	fixture2, err := os.Open("testdata/aws-config-empty.json")
+	if err != nil {
+		t.Fatalf("failed to load aws config fixture: %s", err)
+	}
+	defer fixture2.Close()
+	c = &models.CustomPricing{}
+	callback = configUpdaterWithReaderAndType(fixture2, "otherupdatetype")
+	err = callback(c)
+	if err != nil {
+		t.Fatalf("failed to load aws config: %s", err)
+	}
+	if c.AwsSpotDataBucket != "" {
+		t.Fatalf("Expected empty string but got %s", c.AwsSpotDataBucket)
+	}
+	if c.AwsSpotDataPrefix != "" {
+		t.Fatalf("Expected empty string but got %s", c.AwsSpotDataPrefix)
+	}
+	if c.AwsSpotDataRegion != "" {
+		t.Fatalf("Expected empty string but got %s", c.AwsSpotDataRegion)
+	}
+}
+
 // Mock cluster cache for testing
 type mockClusterCache struct {
 	pods []*clustercache.Pod

+ 2 - 0
pkg/cloud/aws/testdata/aws-config-empty.json

@@ -0,0 +1,2 @@
+{
+}

+ 5 - 0
pkg/cloud/aws/testdata/aws-config.json

@@ -0,0 +1,5 @@
+{
+  "awsSpotDataBucket": "mybucket",
+  "awsSpotDataPrefix": "myprefix",
+  "awsSpotDataRegion": "us-east-1"
+}

+ 2 - 2
pkg/cloud/config/watcher.go

@@ -194,8 +194,8 @@ func (cfw *ConfigFileWatcher) GetConfigs() []cloud.KeyedConfig {
 			AthenaDatabase:   customPricing.AthenaDatabase,
 			AthenaTable:      customPricing.AthenaTable,
 			AthenaWorkgroup:  customPricing.AthenaWorkgroup,
-			ServiceKeyName:   customPricing.ServiceKeyName,
-			ServiceKeySecret: customPricing.ServiceKeySecret,
+			ServiceKeyName:   customPricing.AwsServiceKeyName,
+			ServiceKeySecret: customPricing.AwsServiceKeySecret,
 			AccountID:        customPricing.AthenaProjectID,
 			MasterPayerARN:   customPricing.MasterPayerARN,
 		}

+ 2 - 2
pkg/cloud/gcp/provider.go

@@ -252,8 +252,8 @@ func (gcp *GCP) UpdateConfig(r io.Reader, updateType string) (*models.CustomPric
 			c.AthenaCatalog = a.AthenaCatalog
 			c.AthenaTable = a.AthenaTable
 			c.AthenaWorkgroup = a.AthenaWorkgroup
-			c.ServiceKeyName = a.ServiceKeyName
-			c.ServiceKeySecret = a.ServiceKeySecret
+			c.AwsServiceKeyName = a.ServiceKeyName
+			c.AwsServiceKeySecret = a.ServiceKeySecret
 			c.AthenaProjectID = a.AccountID
 		} else {
 			a := make(map[string]interface{})

+ 5 - 5
pkg/cloud/models/models.go

@@ -146,14 +146,14 @@ type CustomPricing struct {
 	SpotLabelValue               string `json:"spotLabelValue,omitempty"`
 	GpuLabel                     string `json:"gpuLabel,omitempty"`
 	GpuLabelValue                string `json:"gpuLabelValue,omitempty"`
-	ServiceKeyName               string `json:"awsServiceKeyName,omitempty"`
-	ServiceKeySecret             string `json:"awsServiceKeySecret,omitempty"`
+	AwsServiceKeyName            string `json:"awsServiceKeyName,omitempty"`
+	AwsServiceKeySecret          string `json:"awsServiceKeySecret,omitempty"`
 	AlibabaServiceKeyName        string `json:"alibabaServiceKeyName,omitempty"`
 	AlibabaServiceKeySecret      string `json:"alibabaServiceKeySecret,omitempty"`
 	AlibabaClusterRegion         string `json:"alibabaClusterRegion,omitempty"`
-	SpotDataRegion               string `json:"awsSpotDataRegion,omitempty"`
-	SpotDataBucket               string `json:"awsSpotDataBucket,omitempty"`
-	SpotDataPrefix               string `json:"awsSpotDataPrefix,omitempty"`
+	AwsSpotDataRegion            string `json:"awsSpotDataRegion,omitempty"`
+	AwsSpotDataBucket            string `json:"awsSpotDataBucket,omitempty"`
+	AwsSpotDataPrefix            string `json:"awsSpotDataPrefix,omitempty"`
 	ProjectID                    string `json:"projectID,omitempty"`
 	AthenaProjectID              string `json:"athenaProjectID,omitempty"`
 	AthenaBucketName             string `json:"athenaBucketName"`