Sfoglia il codice sorgente

Make service check thread safe

Sean Holcomb 4 anni fa
parent
commit
90250be1af
3 ha cambiato i file con 69 aggiunte e 56 eliminazioni
  1. 18 33
      pkg/cloud/awsprovider.go
  2. 7 15
      pkg/cloud/azureprovider.go
  3. 44 8
      pkg/cloud/provider.go

+ 18 - 33
pkg/cloud/awsprovider.go

@@ -153,7 +153,7 @@ type AWS struct {
 	ProjectID                   string
 	DownloadPricingDataLock     sync.RWMutex
 	Config                      *ProviderConfig
-	ServiceAccountChecks        map[string]*ServiceAccountCheck
+	serviceAccountChecks        *ServiceAccountChecks
 	clusterManagementPrice      float64
 	clusterAccountId            string
 	clusterRegion               string
@@ -743,9 +743,6 @@ func (aws *AWS) getRegionPricing(nodeList []*v1.Node) (*http.Response, string, e
 func (aws *AWS) DownloadPricingData() error {
 	aws.DownloadPricingDataLock.Lock()
 	defer aws.DownloadPricingDataLock.Unlock()
-	if aws.ServiceAccountChecks == nil {
-		aws.ServiceAccountChecks = make(map[string]*ServiceAccountCheck)
-	}
 	c, err := aws.Config.GetCustomPricingData()
 	if err != nil {
 		klog.V(1).Infof("Error downloading default pricing data: %s", err.Error())
@@ -1327,40 +1324,37 @@ func (aws *AWS) ConfigureAuthWith(config *CustomPricing) error {
 
 // Gets the aws key id and secret
 func (aws *AWS) getAWSAuth(forceReload bool, cp *CustomPricing) (string, string) {
-	if aws.ServiceAccountChecks == nil { // safety in case checks don't exist
-		aws.ServiceAccountChecks = make(map[string]*ServiceAccountCheck)
-	}
 
 	// 1. Check config values first (set from frontend UI)
 	if cp.ServiceKeyName != "" && cp.ServiceKeySecret != "" {
-		aws.ServiceAccountChecks["hasKey"] = &ServiceAccountCheck{
+		aws.serviceAccountChecks.set("hasKey", &ServiceAccountCheck{
 			Message: "AWS ServiceKey exists",
 			Status:  true,
-		}
+		})
 		return cp.ServiceKeyName, cp.ServiceKeySecret
 	}
 
 	// 2. Check for secret
 	s, _ := aws.loadAWSAuthSecret(forceReload)
 	if s != nil && s.AccessKeyID != "" && s.SecretAccessKey != "" {
-		aws.ServiceAccountChecks["hasKey"] = &ServiceAccountCheck{
+		aws.serviceAccountChecks.set("hasKey", &ServiceAccountCheck{
 			Message: "AWS ServiceKey exists",
 			Status:  true,
-		}
+		})
 		return s.AccessKeyID, s.SecretAccessKey
 	}
 
 	// 3. Fall back to env vars
 	if env.GetAWSAccessKeyID() == "" || env.GetAWSAccessKeyID() == "" {
-		aws.ServiceAccountChecks["hasKey"] = &ServiceAccountCheck{
+		aws.serviceAccountChecks.set("hasKey", &ServiceAccountCheck{
 			Message: "AWS ServiceKey exists",
 			Status:  false,
-		}
+		})
 	} else {
-		aws.ServiceAccountChecks["hasKey"] = &ServiceAccountCheck{
+		aws.serviceAccountChecks.set("hasKey", &ServiceAccountCheck{
 			Message: "AWS ServiceKey exists",
 			Status:  true,
-		}
+		})
 	}
 	return env.GetAWSAccessKeyID(), env.GetAWSAccessKeySecret()
 }
@@ -1873,9 +1867,6 @@ type spotInfo struct {
 }
 
 func (aws *AWS) parseSpotData(bucket string, prefix string, projectID string, region string) (map[string]*spotInfo, error) {
-	if aws.ServiceAccountChecks == nil { // Set up checks to store error/success states
-		aws.ServiceAccountChecks = make(map[string]*ServiceAccountCheck)
-	}
 
 	aws.ConfigureAuth() // configure aws api authentication by setting env vars
 
@@ -1908,17 +1899,17 @@ func (aws *AWS) parseSpotData(bucket string, prefix string, projectID string, re
 	}
 	lso, err := cli.ListObjects(context.TODO(), ls)
 	if err != nil {
-		aws.ServiceAccountChecks["bucketList"] = &ServiceAccountCheck{
+		aws.serviceAccountChecks.set("bucketList", &ServiceAccountCheck{
 			Message:        "Bucket List Permissions Available",
 			Status:         false,
 			AdditionalInfo: err.Error(),
-		}
+		})
 		return nil, err
 	} else {
-		aws.ServiceAccountChecks["bucketList"] = &ServiceAccountCheck{
+		aws.serviceAccountChecks.set("bucketList", &ServiceAccountCheck{
 			Message: "Bucket List Permissions Available",
 			Status:  true,
-		}
+		})
 	}
 	lsoLen := len(lso.Contents)
 	klog.V(2).Infof("Found %d spot data files from yesterday", lsoLen)
@@ -1961,17 +1952,17 @@ func (aws *AWS) parseSpotData(bucket string, prefix string, projectID string, re
 		buf := manager.NewWriteAtBuffer([]byte{})
 		_, err := downloader.Download(context.TODO(), buf, getObj)
 		if err != nil {
-			aws.ServiceAccountChecks["objectList"] = &ServiceAccountCheck{
+			aws.serviceAccountChecks.set("objectList", &ServiceAccountCheck{
 				Message:        "Object Get Permissions Available",
 				Status:         false,
 				AdditionalInfo: err.Error(),
-			}
+			})
 			return nil, err
 		} else {
-			aws.ServiceAccountChecks["objectList"] = &ServiceAccountCheck{
+			aws.serviceAccountChecks.set("objectList", &ServiceAccountCheck{
 				Message: "Object Get Permissions Available",
 				Status:  true,
-			}
+			})
 		}
 
 		r := bytes.NewReader(buf.Bytes())
@@ -2044,13 +2035,7 @@ func (aws *AWS) ApplyReservedInstancePricing(nodes map[string]*Node) {
 }
 
 func (aws *AWS) ServiceAccountStatus() *ServiceAccountStatus {
-	checks := []*ServiceAccountCheck{}
-	for _, v := range aws.ServiceAccountChecks {
-		checks = append(checks, v)
-	}
-	return &ServiceAccountStatus{
-		Checks: checks,
-	}
+	return aws.serviceAccountChecks.getStatus()
 }
 
 func (aws *AWS) CombinedDiscountForNode(instanceType string, isPreemptible bool, defaultDiscount, negotiatedDiscount float64) float64 {

+ 7 - 15
pkg/cloud/azureprovider.go

@@ -385,7 +385,7 @@ type Azure struct {
 	DownloadPricingDataLock        sync.RWMutex
 	Clientset                      clustercache.ClusterCache
 	Config                         *ProviderConfig
-	ServiceAccountChecks           map[string]*ServiceAccountCheck
+	serviceAccountChecks           *ServiceAccountChecks
 	RateCardPricingError           error
 	clusterAccountId               string
 	clusterRegion                  string
@@ -530,19 +530,17 @@ func (az *Azure) GetAzureStorageConfig(forceReload bool) (*AzureStorageConfig, e
 		defaultSubscriptionID = config.AzureSubscriptionID
 	}
 
-	if az.ServiceAccountChecks == nil {
-		az.ServiceAccountChecks = make(map[string]*ServiceAccountCheck)
-	}
+
 	// 1. Check for secret
 	s, err := az.loadAzureStorageConfig(forceReload)
 	if err != nil {
 		log.Errorf("Error, %s", err.Error())
 	}
 	if s != nil && s.AccessKey != "" && s.AccountName != "" && s.ContainerName != "" {
-		az.ServiceAccountChecks["hasStorage"] = &ServiceAccountCheck{
+		az.serviceAccountChecks.set("hasStorage", &ServiceAccountCheck{
 			Message: "Azure Storage Config exists",
 			Status:  true,
-		}
+		})
 
 		// To support already configured users, subscriptionID may not be set in secret in which case, the subscriptionID
 		// for the rate card API is used
@@ -551,10 +549,10 @@ func (az *Azure) GetAzureStorageConfig(forceReload bool) (*AzureStorageConfig, e
 		}
 		return s, nil
 	}
-	az.ServiceAccountChecks["hasStorage"] = &ServiceAccountCheck{
+	az.serviceAccountChecks.set("hasStorage", &ServiceAccountCheck{
 		Message: "Azure Storage Config exists",
 		Status:  false,
-	}
+	})
 	return nil, fmt.Errorf("azure storage config not found")
 
 }
@@ -1249,13 +1247,7 @@ func (az *Azure) GetLocalStorageQuery(window, offset time.Duration, rate bool, u
 }
 
 func (az *Azure) ServiceAccountStatus() *ServiceAccountStatus {
-	checks := []*ServiceAccountCheck{}
-	for _, v := range az.ServiceAccountChecks {
-		checks = append(checks, v)
-	}
-	return &ServiceAccountStatus{
-		Checks: checks,
-	}
+	return az.serviceAccountChecks.getStatus()
 }
 
 const rateCardPricingSource = "Rate Card API"

+ 44 - 8
pkg/cloud/provider.go

@@ -8,6 +8,7 @@ import (
 	"regexp"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/kubecost/cost-model/pkg/util"
@@ -212,6 +213,39 @@ type ServiceAccountStatus struct {
 	Checks []*ServiceAccountCheck `json:"checks"`
 }
 
+// ServiceAccountChecks is a thread safe map for holding ServiceAccountCheck objects
+type ServiceAccountChecks struct {
+	serviceAccountChecks map[string]*ServiceAccountCheck
+	lock                 *sync.RWMutex
+}
+
+// NewServiceAccountChecks initialize ServiceAccountChecks
+func NewServiceAccountChecks() *ServiceAccountChecks {
+	return &ServiceAccountChecks{
+		serviceAccountChecks: make(map[string]*ServiceAccountCheck),
+		lock: new(sync.RWMutex),
+	}
+}
+
+func (sac *ServiceAccountChecks) set(key string, check *ServiceAccountCheck) {
+	sac.lock.Lock()
+	defer sac.lock.Unlock()
+	sac.serviceAccountChecks[key] = check
+}
+
+// getStatus extracts ServiceAccountCheck objects into a slice and returns them in a ServiceAccountStatus
+func (sac *ServiceAccountChecks) getStatus() *ServiceAccountStatus {
+	sac.lock.Lock()
+	defer sac.lock.Unlock()
+	checks := []*ServiceAccountCheck{}
+	for _, v := range sac.serviceAccountChecks {
+		checks = append(checks, v)
+	}
+	return &ServiceAccountStatus{
+		Checks: checks,
+	}
+}
+
 type ServiceAccountCheck struct {
 	Message        string `json:"message"`
 	Status         bool   `json:"status"`
@@ -419,18 +453,20 @@ func NewProvider(cache clustercache.ClusterCache, apiKey string, config *config.
 	case "AWS":
 		klog.V(2).Info("Found ProviderID starting with \"aws\", using AWS Provider")
 		return &AWS{
-			Clientset:        cache,
-			Config:           NewProviderConfig(config, cp.configFileName),
-			clusterRegion:    cp.region,
-			clusterAccountId: cp.accountID,
+			Clientset:            cache,
+			Config:               NewProviderConfig(config, cp.configFileName),
+			clusterRegion:        cp.region,
+			clusterAccountId:     cp.accountID,
+			serviceAccountChecks: NewServiceAccountChecks(),
 		}, nil
 	case "AZURE":
 		klog.V(2).Info("Found ProviderID starting with \"azure\", using Azure Provider")
 		return &Azure{
-			Clientset:        cache,
-			Config:           NewProviderConfig(config, cp.configFileName),
-			clusterRegion:    cp.region,
-			clusterAccountId: cp.accountID,
+			Clientset:            cache,
+			Config:               NewProviderConfig(config, cp.configFileName),
+			clusterRegion:        cp.region,
+			clusterAccountId:     cp.accountID,
+			serviceAccountChecks: NewServiceAccountChecks(),
 		}, nil
 	default:
 		klog.V(2).Info("Unsupported provider, falling back to default")